Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[BUG] Fix model_kwargs pass logic for vllm (#1958)
This commit is contained in:
parent 0b7f76e193
commit 5d2d253d83
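Before this change, converting a HuggingFace model config to vLLM rebuilt model_kwargs from scratch (tensor_parallel_size and max_model_len only), so any model_kwargs already set on the model were silently dropped. The diff below builds those defaults first and then overlays the model's own model_kwargs, letting user-supplied values pass through and override the defaults; a short sketch of the merge follows the diff.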
@@ -276,13 +276,15 @@ def change_accelerator(models, accelerator):
                     if model.get(item) is not None:
                         acc_model[item] = model[item]
             elif accelerator == 'vllm':
+                model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None))
+                model_kwargs.update(model.get('model_kwargs'))
                 logger.info(f'Transforming {model["abbr"]} to {accelerator}')
 
                 acc_model = dict(
                     type=f'{VLLM.__module__}.{VLLM.__name__}',
                     abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
                     path=model['path'],
-                    model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None)),
+                    model_kwargs=model_kwargs,
                     max_out_len=model['max_out_len'],
                     max_seq_len=model.get('max_seq_len', None),
                     batch_size=model['batch_size'],
@@ -296,12 +298,14 @@ def change_accelerator(models, accelerator):
             raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
         elif model['type'] in [HuggingFacewithChatTemplate, f'{HuggingFacewithChatTemplate.__module__}.{HuggingFacewithChatTemplate.__name__}']:
             if accelerator == 'vllm':
+                model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None))
+                model_kwargs.update(model.get('model_kwargs'))
                 mod = VLLMwithChatTemplate
                 acc_model = dict(
                     type=f'{mod.__module__}.{mod.__name__}',
                     abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
                     path=model['path'],
-                    model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None)),
+                    model_kwargs=model_kwargs,
                     max_seq_len=model.get('max_seq_len', None),
                     max_out_len=model['max_out_len'],
                     batch_size=16,
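For illustration, a minimal sketch of what the merged kwargs look like after this fix. The config entry below is hypothetical (not from the commit), and the surrounding change_accelerator plumbing is omitted:

    # Hypothetical HF config entry; its model_kwargs used to be lost on conversion.
    hf_model = dict(
        abbr='example-7b-chat-hf',
        path='org/example-7b-chat',
        model_kwargs=dict(gpu_memory_utilization=0.8),
        max_seq_len=4096,
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
    )

    # Merge now performed before the vLLM entry is built: defaults first,
    # then the model's own kwargs, which win on conflicting keys.
    model_kwargs = dict(tensor_parallel_size=hf_model['run_cfg']['num_gpus'],
                        max_model_len=hf_model.get('max_seq_len', None))
    model_kwargs.update(hf_model.get('model_kwargs'))
    # -> {'tensor_parallel_size': 1, 'max_model_len': 4096, 'gpu_memory_utilization': 0.8}

Note that model_kwargs.update(model.get('model_kwargs')) assumes the config actually carries a model_kwargs key; a config without one passes None to dict.update(), which raises a TypeError, so model.get('model_kwargs', {}) would be the safer default.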