Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[BUG] Fix model_kwargs pass logic for vllm (#1958)
Parent: 0b7f76e193
Commit: 5d2d253d83
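
Summary (as read from the diff): when converting a HuggingFace model config to its vLLM equivalent, model_kwargs used to be rebuilt inline from tensor_parallel_size and max_model_len only, so any model_kwargs set on the original config were silently dropped. The fix builds those defaults first and then merges the user-supplied model_kwargs on top before passing the result into the vLLM config.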
@@ -276,13 +276,15 @@ def change_accelerator(models, accelerator):
                     if model.get(item) is not None:
                         acc_model[item] = model[item]
             elif accelerator == 'vllm':
+                model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None))
+                model_kwargs.update(model.get('model_kwargs'))
                 logger.info(f'Transforming {model["abbr"]} to {accelerator}')
 
                 acc_model = dict(
                     type=f'{VLLM.__module__}.{VLLM.__name__}',
                     abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
                     path=model['path'],
-                    model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None)),
+                    model_kwargs=model_kwargs,
                     max_out_len=model['max_out_len'],
                     max_seq_len=model.get('max_seq_len', None),
                     batch_size=model['batch_size'],
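
For reference, a minimal standalone sketch of the merge this hunk introduces: the vLLM defaults (tensor_parallel_size, max_model_len) are built first, and the model_kwargs carried by the original HF config are folded in on top instead of being discarded. The example model dict, its field values (including gpu_memory_utilization), and the `or {}` guard are illustrative only and not part of the patch.

    # Illustrative sketch of the merge logic added in this hunk.
    # `model` stands in for an OpenCompass HF model config dict (hypothetical values).
    model = {
        'abbr': 'qwen2-7b-instruct-hf',
        'path': 'Qwen/Qwen2-7B-Instruct',
        'max_seq_len': 4096,
        'max_out_len': 1024,
        'batch_size': 8,
        'run_cfg': {'num_gpus': 2},
        # User-specified kwargs that the old code dropped when switching to vLLM.
        'model_kwargs': {'gpu_memory_utilization': 0.8},
    }

    # Defaults derived from the config, as in the patch.
    model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'],
                        max_model_len=model.get('max_seq_len', None))
    # Merge the user's kwargs on top (the patch calls .update() directly;
    # `or {}` only guards this standalone demo against a missing key).
    model_kwargs.update(model.get('model_kwargs') or {})

    print(model_kwargs)
    # {'tensor_parallel_size': 2, 'max_model_len': 4096, 'gpu_memory_utilization': 0.8}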
@@ -296,12 +298,14 @@ def change_accelerator(models, accelerator):
                 raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
         elif model['type'] in [HuggingFacewithChatTemplate, f'{HuggingFacewithChatTemplate.__module__}.{HuggingFacewithChatTemplate.__name__}']:
             if accelerator == 'vllm':
+                model_kwargs = dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None))
+                model_kwargs.update(model.get('model_kwargs'))
                 mod = VLLMwithChatTemplate
                 acc_model = dict(
                     type=f'{mod.__module__}.{mod.__name__}',
                     abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
                     path=model['path'],
-                    model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus'], max_model_len=model.get('max_seq_len', None)),
+                    model_kwargs=model_kwargs,
                     max_seq_len=model.get('max_seq_len', None),
                     max_out_len=model['max_out_len'],
                     batch_size=16,
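
The second hunk makes the same change in the HuggingFacewithChatTemplate branch. Continuing the illustrative config above, the transformed vLLM entry would look roughly like this; the field values are hypothetical and the type string is a placeholder for f'{mod.__module__}.{mod.__name__}'.

    # Rough shape of the resulting config for the chat-template branch
    # (illustrative; mirrors the dict built in the hunk above).
    acc_model = dict(
        type='opencompass.models.VLLMwithChatTemplate',  # placeholder for f'{mod.__module__}.{mod.__name__}'
        abbr='qwen2-7b-instruct-vllm',                   # '-hf' suffix rewritten to '-vllm'
        path='Qwen/Qwen2-7B-Instruct',
        model_kwargs={'tensor_parallel_size': 2,
                      'max_model_len': 4096,
                      'gpu_memory_utilization': 0.8},    # defaults plus user kwargs, no longer dropped
        max_seq_len=4096,
        max_out_len=1024,
        batch_size=16,
    )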