mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Sync] Update accelerator (#1122)
(cherry picked from commit 4beb6d9ab655d8a626971841b7acfd9fae9d438f) Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
This commit is contained in:
parent
a71122ee18
commit
19d7e630d6
@ -53,8 +53,8 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--accelerator',
|
'--accelerator',
|
||||||
help='Infer accelerator, support vllm and lmdeploy now.',
|
help='Infer accelerator, support vllm and lmdeploy now.',
|
||||||
choices=['vllm', 'lmdeploy', 'hg'],
|
choices=['vllm', 'lmdeploy', 'hf'],
|
||||||
default='hg',
|
default='hf',
|
||||||
type=str)
|
type=str)
|
||||||
parser.add_argument('-m',
|
parser.add_argument('-m',
|
||||||
'--mode',
|
'--mode',
|
||||||
|
@ -220,7 +220,7 @@ def change_accelerator(models, accelerator):
|
|||||||
if accelerator == 'lmdeploy':
|
if accelerator == 'lmdeploy':
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f'Transforming {model["abbr"]} to {accelerator}')
|
f'Transforming {model["abbr"]} to {accelerator}')
|
||||||
model = dict(
|
acc_model = dict(
|
||||||
type= # noqa E251
|
type= # noqa E251
|
||||||
f'{TurboMindModel.__module__}.{TurboMindModel.__name__}',
|
f'{TurboMindModel.__module__}.{TurboMindModel.__name__}',
|
||||||
abbr=model['abbr'].replace('hf', 'lmdeploy')
|
abbr=model['abbr'].replace('hf', 'lmdeploy')
|
||||||
@ -242,12 +242,12 @@ def change_accelerator(models, accelerator):
|
|||||||
)
|
)
|
||||||
for item in ['meta_template']:
|
for item in ['meta_template']:
|
||||||
if model.get(item) is not None:
|
if model.get(item) is not None:
|
||||||
model.update(item, model[item])
|
acc_model[item] = model[item]
|
||||||
elif accelerator == 'vllm':
|
elif accelerator == 'vllm':
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f'Transforming {model["abbr"]} to {accelerator}')
|
f'Transforming {model["abbr"]} to {accelerator}')
|
||||||
|
|
||||||
model = dict(
|
acc_model = dict(
|
||||||
type=f'{VLLM.__module__}.{VLLM.__name__}',
|
type=f'{VLLM.__module__}.{VLLM.__name__}',
|
||||||
abbr=model['abbr'].replace('hf', 'vllm')
|
abbr=model['abbr'].replace('hf', 'vllm')
|
||||||
if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
|
if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
|
||||||
@ -262,12 +262,10 @@ def change_accelerator(models, accelerator):
|
|||||||
)
|
)
|
||||||
for item in ['meta_template', 'end_str']:
|
for item in ['meta_template', 'end_str']:
|
||||||
if model.get(item) is not None:
|
if model.get(item) is not None:
|
||||||
model.update(item, model[item])
|
acc_model[item] = model[item]
|
||||||
generation_kwargs.update(
|
|
||||||
dict(temperature=gen_args['temperature']))
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported accelerator {accelerator}')
|
raise ValueError(f'Unsupported accelerator {accelerator}')
|
||||||
model_accels.append(model)
|
model_accels.append(acc_model)
|
||||||
return model_accels
|
return model_accels
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user