This commit is contained in:
Zhangzefeng 2025-05-28 11:28:48 +02:00 committed by GitHub
commit d7ad028276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ from typing import List, Tuple, Union
import tabulate import tabulate
from mmengine.config import Config from mmengine.config import Config
import opencompass
from opencompass.datasets.custom import make_custom_dataset_config from opencompass.datasets.custom import make_custom_dataset_config
from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel, from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel,
HuggingFaceCausalLM, HuggingFaceChatGLM3, HuggingFaceCausalLM, HuggingFaceChatGLM3,
@ -240,7 +241,7 @@ def change_accelerator(models, accelerator):
for model in models: for model in models:
logger.info(f'Transforming {model["abbr"]} to {accelerator}') logger.info(f'Transforming {model["abbr"]} to {accelerator}')
# change HuggingFace model to VLLM or LMDeploy # change HuggingFace model to VLLM or LMDeploy
if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3, f'{HuggingFaceBaseModel.__module__}.{HuggingFaceBaseModel.__name__}']: if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3, eval(f'{HuggingFaceBaseModel.__module__}.{HuggingFaceBaseModel.__name__}')]:
gen_args = dict() gen_args = dict()
if model.get('generation_kwargs') is not None: if model.get('generation_kwargs') is not None:
generation_kwargs = model['generation_kwargs'].copy() generation_kwargs = model['generation_kwargs'].copy()
@ -302,7 +303,7 @@ def change_accelerator(models, accelerator):
acc_model[item] = model[item] acc_model[item] = model[item]
else: else:
raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}') raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
elif model['type'] in [HuggingFacewithChatTemplate, f'{HuggingFacewithChatTemplate.__module__}.{HuggingFacewithChatTemplate.__name__}']: elif model['type'] in [HuggingFacewithChatTemplate, eval(f'{HuggingFacewithChatTemplate.__module__}.{HuggingFacewithChatTemplate.__name__}')]:
if accelerator == 'vllm': if accelerator == 'vllm':
model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None)) model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None))
model_kwargs.update(model.get('model_kwargs')) model_kwargs.update(model.get('model_kwargs'))