[Sync] Update accelerator (#1122)

(cherry picked from commit 4beb6d9ab655d8a626971841b7acfd9fae9d438f)

Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
This commit is contained in:
Fengzhe Zhou 2024-05-09 14:32:31 +08:00 committed by GitHub
parent a71122ee18
commit 19d7e630d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 9 deletions

View File

@@ -53,8 +53,8 @@ def parse_args():
 parser.add_argument(
 '--accelerator',
 help='Infer accelerator, support vllm and lmdeploy now.',
-choices=['vllm', 'lmdeploy', 'hg'],
-default='hg',
+choices=['vllm', 'lmdeploy', 'hf'],
+default='hf',
 type=str)
 parser.add_argument('-m',
 '--mode',

View File

@@ -220,7 +220,7 @@ def change_accelerator(models, accelerator):
 if accelerator == 'lmdeploy':
 get_logger().info(
 f'Transforming {model["abbr"]} to {accelerator}')
-model = dict(
+acc_model = dict(
 type= # noqa E251
 f'{TurboMindModel.__module__}.{TurboMindModel.__name__}',
 abbr=model['abbr'].replace('hf', 'lmdeploy')
@@ -242,12 +242,12 @@ def change_accelerator(models, accelerator):
 )
 for item in ['meta_template']:
 if model.get(item) is not None:
-model.update(item, model[item])
+acc_model[item] = model[item]
 elif accelerator == 'vllm':
 get_logger().info(
 f'Transforming {model["abbr"]} to {accelerator}')
-model = dict(
+acc_model = dict(
 type=f'{VLLM.__module__}.{VLLM.__name__}',
 abbr=model['abbr'].replace('hf', 'vllm')
 if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
@@ -262,12 +262,10 @@ def change_accelerator(models, accelerator):
 )
 for item in ['meta_template', 'end_str']:
 if model.get(item) is not None:
-model.update(item, model[item])
-generation_kwargs.update(
-dict(temperature=gen_args['temperature']))
+acc_model[item] = model[item]
 else:
 raise ValueError(f'Unsupported accelerator {accelerator}')
-model_accels.append(model)
+model_accels.append(acc_model)
 return model_accels