[Sync] Update accelerator (#1122)

(cherry picked from commit 4beb6d9ab655d8a626971841b7acfd9fae9d438f)

Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
This commit is contained in:
Fengzhe Zhou 2024-05-09 14:32:31 +08:00 committed by GitHub
parent a71122ee18
commit 19d7e630d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 9 deletions

View File

@ -53,8 +53,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--accelerator', '--accelerator',
help='Infer accelerator, support vllm and lmdeploy now.', help='Infer accelerator, support vllm and lmdeploy now.',
choices=['vllm', 'lmdeploy', 'hg'], choices=['vllm', 'lmdeploy', 'hf'],
default='hg', default='hf',
type=str) type=str)
parser.add_argument('-m', parser.add_argument('-m',
'--mode', '--mode',

View File

@ -220,7 +220,7 @@ def change_accelerator(models, accelerator):
if accelerator == 'lmdeploy': if accelerator == 'lmdeploy':
get_logger().info( get_logger().info(
f'Transforming {model["abbr"]} to {accelerator}') f'Transforming {model["abbr"]} to {accelerator}')
model = dict( acc_model = dict(
type= # noqa E251 type= # noqa E251
f'{TurboMindModel.__module__}.{TurboMindModel.__name__}', f'{TurboMindModel.__module__}.{TurboMindModel.__name__}',
abbr=model['abbr'].replace('hf', 'lmdeploy') abbr=model['abbr'].replace('hf', 'lmdeploy')
@ -242,12 +242,12 @@ def change_accelerator(models, accelerator):
) )
for item in ['meta_template']: for item in ['meta_template']:
if model.get(item) is not None: if model.get(item) is not None:
model.update(item, model[item]) acc_model[item] = model[item]
elif accelerator == 'vllm': elif accelerator == 'vllm':
get_logger().info( get_logger().info(
f'Transforming {model["abbr"]} to {accelerator}') f'Transforming {model["abbr"]} to {accelerator}')
model = dict( acc_model = dict(
type=f'{VLLM.__module__}.{VLLM.__name__}', type=f'{VLLM.__module__}.{VLLM.__name__}',
abbr=model['abbr'].replace('hf', 'vllm') abbr=model['abbr'].replace('hf', 'vllm')
if '-hf' in model['abbr'] else model['abbr'] + '-vllm', if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
@ -262,12 +262,10 @@ def change_accelerator(models, accelerator):
) )
for item in ['meta_template', 'end_str']: for item in ['meta_template', 'end_str']:
if model.get(item) is not None: if model.get(item) is not None:
model.update(item, model[item]) acc_model[item] = model[item]
generation_kwargs.update(
dict(temperature=gen_args['temperature']))
else: else:
raise ValueError(f'Unsupported accelerator {accelerator}') raise ValueError(f'Unsupported accelerator {accelerator}')
model_accels.append(model) model_accels.append(acc_model)
return model_accels return model_accels