From 8ea2c404d756a709a885f59bd052b1ada5fbe76d Mon Sep 17 00:00:00 2001
From: Fengzhe Zhou
Date: Wed, 15 May 2024 21:51:07 +0800
Subject: [PATCH] [Feat] enable HuggingFacewithChatTemplate with --accelerator via cli (#1163)

* enable HuggingFacewithChatTemplate with --accelerator via cli

* rm vllm_internlm2_chat_7b
---
 .../hf_internlm/lmdeploy_internlm2_chat_7b.py |  33 +--
 .../hf_internlm/vllm_internlm2_chat_7b.py     |  13 ++
 configs/models/qwen/vllm_qwen1_5_7b_chat.py   |  13 ++
 opencompass/cli/main.py                       |   4 +-
 opencompass/models/__init__.py                |   3 +
 opencompass/models/huggingface_above_v4_33.py |   2 +-
 .../models/turbomind_with_tf_above_v4_33.py   | 195 ++++++++++++++++++
 opencompass/models/vllm.py                    |  15 --
 .../models/vllm_with_tf_above_v4_33.py        | 127 ++++++++++++
 opencompass/utils/run.py                      |  99 +++++----
 10 files changed, 418 insertions(+), 86 deletions(-)
 create mode 100644 configs/models/hf_internlm/vllm_internlm2_chat_7b.py
 create mode 100644 configs/models/qwen/vllm_qwen1_5_7b_chat.py
 create mode 100644 opencompass/models/turbomind_with_tf_above_v4_33.py
 create mode 100644 opencompass/models/vllm_with_tf_above_v4_33.py

diff --git a/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py b/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py
index 65f498e1..b604a04c 100644
--- a/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py
+++ b/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py
@@ -1,36 +1,23 @@
-from opencompass.models.turbomind import TurboMindModel
-
-
-_meta_template = dict(
-    round=[
-        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
-        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
-    ],
-)
+from opencompass.models import TurboMindModelwithChatTemplate
 
 models = [
     dict(
-        type=TurboMindModel,
+        type=TurboMindModelwithChatTemplate,
         abbr='internlm2-chat-7b-turbomind',
         path='internlm/internlm2-chat-7b',
-        meta_template=_meta_template,
         engine_config=dict(
-            session_len=32768,
-            max_batch_size=32,
-            model_name='internlm2-chat-7b',
+            max_batch_size=16,
             tp=1,
-            stop_words=[2, 92542],
         ),
         gen_config=dict(
             top_k=1,
-            top_p=0.8,
-            temperature=1.0,
-            max_new_tokens=2000,
+            temperature=1e-6,
+            top_p=0.9,
         ),
-        max_out_len=2000,
-        max_seq_len=32768,
-        batch_size=32,
-        concurrency=8,
-        run_cfg=dict(num_gpus=1, num_procs=1),
+        max_seq_len=2048,
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+        stop_words=['</s>', '<|im_end|>'],
     )
 ]
diff --git a/configs/models/hf_internlm/vllm_internlm2_chat_7b.py b/configs/models/hf_internlm/vllm_internlm2_chat_7b.py
new file mode 100644
index 00000000..50a413ea
--- /dev/null
+++ b/configs/models/hf_internlm/vllm_internlm2_chat_7b.py
@@ -0,0 +1,13 @@
+from opencompass.models import VLLMwithChatTemplate
+
+models = [
+    dict(
+        type=VLLMwithChatTemplate,
+        abbr='internlm2-chat-7b-vllm',
+        path='internlm/internlm2-chat-7b',
+        model_kwargs=dict(tensor_parallel_size=1),
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+    )
+]
diff --git a/configs/models/qwen/vllm_qwen1_5_7b_chat.py b/configs/models/qwen/vllm_qwen1_5_7b_chat.py
new file mode 100644
index 00000000..f97c716f
--- /dev/null
+++ b/configs/models/qwen/vllm_qwen1_5_7b_chat.py
@@ -0,0 +1,13 @@
+from opencompass.models import VLLMwithChatTemplate
+
+models = [
+    dict(
+        type=VLLMwithChatTemplate,
+        abbr='qwen1.5-7b-chat-vllm',
+        path='Qwen/Qwen1.5-7B-Chat',
+        model_kwargs=dict(tensor_parallel_size=1),
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+    )
+]
diff --git a/opencompass/cli/main.py
b/opencompass/cli/main.py index c2682ab2..7defe785 100644 --- a/opencompass/cli/main.py +++ b/opencompass/cli/main.py @@ -55,8 +55,8 @@ def parse_args(): parser.add_argument( '-a', '--accelerator', help='Infer accelerator, support vllm and lmdeploy now.', - choices=['vllm', 'lmdeploy', 'hf'], - default='hf', + choices=['vllm', 'lmdeploy', None], + default=None, type=str) parser.add_argument('-m', '--mode', diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index d7f7c063..8a9375a7 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -35,8 +35,11 @@ from .sensetime_api import SenseTime # noqa: F401 from .stepfun_api import StepFun # noqa: F401 from .turbomind import TurboMindModel # noqa: F401 from .turbomind_tis import TurboMindTisModel # noqa: F401 +from .turbomind_with_tf_above_v4_33 import \ + TurboMindModelwithChatTemplate # noqa: F401 from .unigpt_api import UniGPT # noqa: F401 from .vllm import VLLM # noqa: F401 +from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate # noqa: F401 from .xunfei_api import XunFei, XunFeiSpark # noqa: F401 from .yayi_api import Yayi # noqa: F401 from .zhipuai_api import ZhiPuAI # noqa: F401 diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py index 41356341..c88d2c4c 100644 --- a/opencompass/models/huggingface_above_v4_33.py +++ b/opencompass/models/huggingface_above_v4_33.py @@ -64,7 +64,7 @@ def _convert_chat_messages(inputs): for _input in inputs: messages = [] if isinstance(_input, str): - messages.append({'role': 'HUMAN', 'prompt': _input}) + messages.append({'role': 'HUMAN', 'content': _input}) else: for item in _input: role = { diff --git a/opencompass/models/turbomind_with_tf_above_v4_33.py b/opencompass/models/turbomind_with_tf_above_v4_33.py new file mode 100644 index 00000000..56aa5430 --- /dev/null +++ b/opencompass/models/turbomind_with_tf_above_v4_33.py @@ -0,0 +1,195 @@ +# flake8: noqa +# yapf: disable +import copy +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +from opencompass.models.base import BaseModel +from opencompass.utils.logging import get_logger +from opencompass.utils.prompt import PromptList + +from .huggingface_above_v4_33 import (_convert_chat_messages, + _format_with_fast_chat_template, + _get_meta_template, + _get_possible_max_seq_len) + +PromptType = Union[PromptList, str] + + +def valid_str(string, coding='utf-8'): + """decode text according to its encoding type.""" + invalid_chars = [b'\xef\xbf\xbd'] + bstr = bytes(string, coding) + for invalid_char in invalid_chars: + bstr = bstr.replace(invalid_char, b'') + ret = bstr.decode(encoding=coding, errors='ignore') + return ret + + +class TurboMindModelwithChatTemplate(BaseModel): + def __init__( + self, + path: str, + tokenizer_only: bool = False, + engine_config: Dict = {}, + gen_config: Dict = {}, + concurrency: int = 8, + max_seq_len: int = None, + meta_template: Optional[Dict] = None, + fastchat_template: Optional[str] = None, + stop_words: List[str] = [], + ): + from lmdeploy.messages import TurbomindEngineConfig + from lmdeploy.turbomind import TurboMind + from lmdeploy.version import version_info + from transformers import AutoTokenizer + + self.logger = get_logger() + self.path = path + self.tokenizer_only = tokenizer_only + self.template_parser = _get_meta_template(meta_template) + self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path) + + self.origin_tokenizer = 
AutoTokenizer.from_pretrained(path, trust_remote_code=True) + if not tokenizer_only: + DEFAULT_ENGING_CONFIG = {'session_len': self.max_seq_len} + _engine_config = DEFAULT_ENGING_CONFIG.copy() + _engine_config.update(engine_config) + engine_config = TurbomindEngineConfig(**_engine_config) + tm_model = TurboMind.from_pretrained(path, engine_config=engine_config) + self.tokenizer = tm_model.tokenizer + self.generators = [tm_model.create_instance() for i in range(concurrency)] + self.generator_ids = [i + 1 for i in range(concurrency)] + self.concurrency = concurrency + self.gen_config = gen_config + self.version_info = version_info + self.fastchat_template = fastchat_template + self.stop_words = list(set(stop_words + self._get_potential_stop_words(path))) + self.logger.info(f'using stop words: {self.stop_words}') + + def _get_potential_stop_words(self, path: Optional[str]): + from transformers import GenerationConfig + potential_stop_words = [] + try: + generation_config = GenerationConfig.from_pretrained(path) + for token_id in generation_config.eos_token_id: + potential_stop_words.append(self.origin_tokenizer.decode(token_id)) + except: + pass + potential_stop_words.append(self.origin_tokenizer.eos_token) + potential_stop_words = list(set(potential_stop_words)) + return potential_stop_words + + def generate(self, + inputs: List[str], + max_out_len: int = 512, + stopping_criteria: List[str] = [], + do_sample: Optional[bool] = None, + temperature: int = 1, + **kwargs) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str]): A list of prompts + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + assert isinstance(inputs, List), f'List(str) is expected, but got {type(inputs)}' + + messages = _convert_chat_messages(inputs) + if self.fastchat_template: + messages = _format_with_fast_chat_template(messages, self.fastchat_template) + else: + messages = [self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages] + + # split messages into batches + batch_messages = [messages[i:i + self.concurrency] for i in range(0, len(messages), self.concurrency)] + + stop_words = list(set(self.stop_words + stopping_criteria)) + DEFAULT_GEN_CONFIG = { + 'max_new_tokens': max_out_len, + 'min_new_tokens': 1, + 'top_k': 1, + 'stop_words': stop_words, + } + gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG) + gen_config.update(self.gen_config) + if do_sample: + gen_config['top_k'] = 1000 + gen_config['temperature'] = temperature + # if stopping_criteria: + # stop_words = gen_config.get('stop_words', []) + # for t in stopping_criteria: + # t = self.tokenizer.encode(t, add_bos=False) + # stop_words.append(t[0]) + # gen_config['stop_words'] = list(set(stop_words)) + from lmdeploy.messages import EngineGenerationConfig, GenerationConfig + gen_config = GenerationConfig(**gen_config) + gen_config = EngineGenerationConfig.From(gen_config, self.tokenizer) + + results = [] + for batch_message in batch_messages: + n = len(batch_message) + with ThreadPoolExecutor() as executor: + _results = list( + executor.map( + self._generate, + self.generators[:n], + self.generator_ids[:n], + batch_message, + [gen_config] * n, + )) + results += _results + + for s in stop_words: + results = [r.split(s)[0] for r in results] + return results + + def _generate(self, + generator, + session_id, + prompt: PromptType, + gen_config=None) -> str: + """Generate results given a list of inputs. 
+ + Args: + prompt (PromptType): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + gen_config (EngineGenerationConfig, optional): Generation + config to set arguments like top_k, top_p, temperature. + Returns: + str: The generated string. + """ + assert type(prompt) is str, 'We only support string for TurboMind Python API' + + input_ids = self.tokenizer.encode(prompt) + for outputs in generator.stream_infer(session_id=session_id, + input_ids=[input_ids], + gen_config=gen_config, + sequence_start=True, + sequence_end=True, + step=0, + stream_output=False): + if self.version_info >= (0, 4, 0): + output_ids = outputs.token_ids + else: + _, output_ids, _ = outputs + response = self.tokenizer.decode(output_ids) + response = valid_str(response) + return response + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + m = _convert_chat_messages([prompt])[0] + t = self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True) + return len(t['input_ids']) diff --git a/opencompass/models/vllm.py b/opencompass/models/vllm.py index e204c0c4..3919e2e1 100644 --- a/opencompass/models/vllm.py +++ b/opencompass/models/vllm.py @@ -130,21 +130,6 @@ class VLLM(BaseModel): ce_loss.append(loss) return np.array(ce_loss) - def prompts_preproccess(self, inputs: List[str]): - if self.use_fastchat_template: - try: - from fastchat.model import get_conversation_template - except ModuleNotFoundError: - raise ModuleNotFoundError( - 'Fastchat is not implemented. You can use ' - "'pip install \"fschat[model_worker,webui]\"' " - 'to implement fastchat.') - conv = get_conversation_template('vicuna') - conv.append_message(conv.roles[0], inputs[0]) - conv.append_message(conv.roles[1], None) - inputs = [conv.get_prompt()] - return inputs - def get_token_len(self, prompt: str) -> int: """Get lengths of the tokenized strings. diff --git a/opencompass/models/vllm_with_tf_above_v4_33.py b/opencompass/models/vllm_with_tf_above_v4_33.py new file mode 100644 index 00000000..87421ace --- /dev/null +++ b/opencompass/models/vllm_with_tf_above_v4_33.py @@ -0,0 +1,127 @@ +# flake8: noqa +# yapf: disable +from typing import Dict, List, Optional + +import numpy as np + +from opencompass.models.base import BaseModel +from opencompass.utils import get_logger + +from .huggingface_above_v4_33 import (_convert_chat_messages, + _format_with_fast_chat_template, + _get_meta_template, + _get_possible_max_seq_len) + +try: + from vllm import LLM, SamplingParams +except ImportError: + LLM, SamplingParams = None, None + + +class VLLMwithChatTemplate(BaseModel): + + def __init__( + self, + path: str, + model_kwargs: dict = dict(), + tokenizer_only: bool = False, + generation_kwargs: dict = dict(), + max_seq_len: int = None, + meta_template: Optional[Dict] = None, + fastchat_template: Optional[str] = None, + stop_words: List[str] = [], + ): + assert LLM, ('Please install VLLM with `pip install vllm`. 
note: torch==2.1.2 is required.') + + self.logger = get_logger() + self.path = path + self.tokenizer_only = tokenizer_only + self.template_parser = _get_meta_template(meta_template) + self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path) + if tokenizer_only: + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + else: + self._load_model(path, model_kwargs) + self.tokenizer = self.model.get_tokenizer() + + self.generation_kwargs = generation_kwargs + self.generation_kwargs.pop('do_sample', None) + self.fastchat_template = fastchat_template + self.stop_words = list(set(stop_words + self._get_potential_stop_words(path))) + + def _load_model(self, path: str, added_model_kwargs: dict = dict()): + import ray + + if ray.is_initialized(): + self.logger.info('shutdown ray instance to avoid "Calling ray.init() again" error.') + ray.shutdown() + + DEFAULT_MODEL_KWARGS = dict(trust_remote_code=True) + model_kwargs = DEFAULT_MODEL_KWARGS.copy() + model_kwargs.update(added_model_kwargs) + self.model = LLM(path, **model_kwargs) + + def _get_potential_stop_words(self, path: Optional[str]): + from transformers import GenerationConfig + potential_stop_words = [] + try: + generation_config = GenerationConfig.from_pretrained(path) + for token_id in generation_config.eos_token_id: + potential_stop_words.append(self.tokenizer.decode(token_id)) + except: + pass + potential_stop_words.append(self.tokenizer.eos_token) + potential_stop_words = list(set(potential_stop_words)) + return potential_stop_words + + def generate(self, inputs: List[str], max_out_len: int, stopping_criteria: List[str] = [], **kwargs) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str]): A list of strings. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + messages = _convert_chat_messages(inputs) + if self.fastchat_template: + messages = _format_with_fast_chat_template(messages, self.fastchat_template) + else: + messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages] + + DEFAULT_GENERATION_KWARGS = { + 'temperature': 0, + 'max_tokens': max_out_len, + 'stop': list(set(self.stop_words + stopping_criteria)) + } + sampling_kwargs = DEFAULT_GENERATION_KWARGS.copy() + sampling_kwargs.update(self.generation_kwargs) + sampling_kwargs.update(kwargs) + sampling_kwargs = SamplingParams(**sampling_kwargs) + + outputs = self.model.generate(messages, sampling_kwargs) + + prompt_list, output_strs = [], [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + prompt_list.append(prompt) + output_strs.append(generated_text) + + return output_strs + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. 
+ + Returns: + int: Length of the input tokens + """ + m = _convert_chat_messages([prompt])[0] + t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True) + return len(t['input_ids']) diff --git a/opencompass/utils/run.py b/opencompass/utils/run.py index 9d7c81bb..816c1655 100644 --- a/opencompass/utils/run.py +++ b/opencompass/utils/run.py @@ -1,3 +1,5 @@ +# flake8: noqa +# yapf: disable import os from typing import List, Tuple, Union @@ -7,7 +9,9 @@ from mmengine.config import Config from opencompass.datasets.custom import make_custom_dataset_config from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel, HuggingFaceCausalLM, HuggingFaceChatGLM3, - HuggingFacewithChatTemplate, TurboMindModel) + HuggingFacewithChatTemplate, TurboMindModel, + TurboMindModelwithChatTemplate, + VLLMwithChatTemplate) from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask @@ -79,25 +83,16 @@ def get_config_from_arg(args) -> Config: config = try_fill_in_custom_cfgs(config) # set infer accelerator if needed if args.accelerator in ['vllm', 'lmdeploy']: - config['models'] = change_accelerator(config['models'], - args.accelerator) + config['models'] = change_accelerator(config['models'], args.accelerator) if 'eval' in config and 'partitioner' in config['eval']: - if 'models' in config['eval']['partitioner']: - config['eval']['partitioner'][ - 'models'] = change_accelerator( - config['eval']['partitioner']['models'], - args.accelerator) - if 'judge_models' in config['eval']['partitioner']: - config['eval']['partitioner'][ - 'judge_models'] = change_accelerator( - config['eval']['partitioner']['judge_models'], - args.accelerator) + config['eval']['partitioner']['models'] = change_accelerator(config['eval']['partitioner']['models'], args.accelerator) + if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None: + config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator) return config + # parse dataset args if not args.datasets and not args.custom_dataset_path: - raise ValueError('You must specify "--datasets" or ' - '"--custom-dataset-path" if you do not specify a ' - 'config file path.') + raise ValueError('You must specify "--datasets" or "--custom-dataset-path" if you do not specify a config file path.') datasets = [] if args.datasets: datasets_dir = os.path.join(args.config_dir, 'datasets') @@ -110,7 +105,7 @@ def get_config_from_arg(args) -> Config: dataset_key_suffix = '_datasets' for dataset in match_cfg_file(datasets_dir, [dataset_name]): - get_logger().info(f'Loading {dataset[0]}: {dataset[1]}') + logger.info(f'Loading {dataset[0]}: {dataset[1]}') cfg = Config.fromfile(dataset[1]) for k in cfg.keys(): if k.endswith(dataset_key_suffix): @@ -128,19 +123,15 @@ def get_config_from_arg(args) -> Config: # parse model args if not args.models and not args.hf_path: - raise ValueError('You must specify a config file path, ' - 'or specify --models and --datasets, or ' - 'specify HuggingFace model parameters and ' - '--datasets.') + raise ValueError('You must specify a config file path, or specify --models and --datasets, or specify HuggingFace model parameters and --datasets.') models = [] if args.models: model_dir = os.path.join(args.config_dir, 'models') for model in match_cfg_file(model_dir, args.models): - 
get_logger().info(f'Loading {model[0]}: {model[1]}') + logger.info(f'Loading {model[0]}: {model[1]}') cfg = Config.fromfile(model[1]) if 'models' not in cfg: - raise ValueError( - f'Config file {model[1]} does not contain "models" field') + raise ValueError(f'Config file {model[1]} does not contain "models" field') models += cfg['models'] else: if args.hf_type == 'chat': @@ -167,8 +158,7 @@ def get_config_from_arg(args) -> Config: if args.accelerator in ['vllm', 'lmdeploy']: models = change_accelerator(models, args.accelerator) # parse summarizer args - summarizer_arg = args.summarizer if args.summarizer is not None \ - else 'example' + summarizer_arg = args.summarizer if args.summarizer is not None else 'example' summarizers_dir = os.path.join(args.config_dir, 'summarizers') # Check if summarizer_arg contains '/' @@ -188,9 +178,7 @@ def get_config_from_arg(args) -> Config: # from the configuration file summarizer = cfg[summarizer_key] - return Config(dict(models=models, datasets=datasets, - summarizer=summarizer), - format_python_code=False) + return Config(dict(models=models, datasets=datasets, summarizer=summarizer), format_python_code=False) def change_accelerator(models, accelerator): @@ -200,20 +188,15 @@ def change_accelerator(models, accelerator): for model in models: logger.info(f'Transforming {model["abbr"]} to {accelerator}') # change HuggingFace model to VLLM or TurboMindModel - if model['type'] in [ - HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3 - ]: + if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3]: gen_args = dict() if model.get('generation_kwargs') is not None: generation_kwargs = model['generation_kwargs'].copy() - gen_args['temperature'] = generation_kwargs.get( - 'temperature', 0.001) + gen_args['temperature'] = generation_kwargs.get('temperature', 0.001) gen_args['top_k'] = generation_kwargs.get('top_k', 1) gen_args['top_p'] = generation_kwargs.get('top_p', 0.9) - gen_args['stop_token_ids'] = generation_kwargs.get( - 'eos_token_id', None) - generation_kwargs['stop_token_ids'] = generation_kwargs.get( - 'eos_token_id', None) + gen_args['stop_token_ids'] = generation_kwargs.get('eos_token_id', None) + generation_kwargs['stop_token_ids'] = generation_kwargs.get('eos_token_id', None) generation_kwargs.pop('eos_token_id') else: # if generation_kwargs is not provided, set default values @@ -228,8 +211,7 @@ def change_accelerator(models, accelerator): mod = TurboMindModel acc_model = dict( type=f'{mod.__module__}.{mod.__name__}', - abbr=model['abbr'].replace('hf', 'lmdeploy') - if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy', + abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy', path=model['path'], engine_config=dict(session_len=model['max_seq_len'], max_batch_size=model['batch_size'], @@ -253,11 +235,9 @@ def change_accelerator(models, accelerator): acc_model = dict( type=f'{VLLM.__module__}.{VLLM.__name__}', - abbr=model['abbr'].replace('hf', 'vllm') - if '-hf' in model['abbr'] else model['abbr'] + '-vllm', + abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm', path=model['path'], - model_kwargs=dict( - tensor_parallel_size=model['run_cfg']['num_gpus']), + model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']), max_out_len=model['max_out_len'], max_seq_len=model['max_seq_len'], batch_size=model['batch_size'], @@ -268,7 +248,36 @@ def change_accelerator(models, accelerator): if model.get(item) is not None: 
acc_model[item] = model[item] else: - raise ValueError(f'Unsupported accelerator {accelerator}') + raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}') + elif model['type'] in [HuggingFacewithChatTemplate]: + if accelerator == 'vllm': + mod = VLLMwithChatTemplate + acc_model = dict( + type=f'{mod.__module__}.{mod.__name__}', + abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-vllm', + path=model['path'], + model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']), + max_out_len=model['max_out_len'], + batch_size=32768, + run_cfg=model['run_cfg'], + stop_words=model.get('stop_words', []), + ) + elif accelerator == 'lmdeploy': + mod = TurboMindModelwithChatTemplate + acc_model = dict( + type=f'{mod.__module__}.{mod.__name__}', + abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-turbomind', + path=model['path'], + engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']), + gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9), + max_seq_len=model.get('max_seq_len', 2048), + max_out_len=model['max_out_len'], + batch_size=32768, + run_cfg=model['run_cfg'], + stop_words=model.get('stop_words', []), + ) + else: + raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}') else: raise ValueError(f'Unsupported model type {model["type"]}') model_accels.append(acc_model)
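
Illustrative usage (not part of the patch): the new HuggingFacewithChatTemplate branch of change_accelerator above rewrites a chat-template HF model entry into its lmdeploy- or vLLM-backed equivalent when `-a lmdeploy` or `-a vllm` is passed on the CLI. A minimal sketch of that transformation follows; the abbr and path are assumptions chosen for illustration, and any entry whose abbr ends in '-hf' is handled the same way.

from opencompass.models import HuggingFacewithChatTemplate
from opencompass.utils.run import change_accelerator

# A chat-template HF model entry as it might appear in a config file.
hf_model = dict(
    type=HuggingFacewithChatTemplate,
    abbr='internlm2-chat-7b-hf',      # illustrative abbreviation
    path='internlm/internlm2-chat-7b',
    max_out_len=1024,
    batch_size=8,
    run_cfg=dict(num_gpus=1),
)

accelerated = change_accelerator([hf_model], 'lmdeploy')[0]
# accelerated['type'] is now the dotted path to TurboMindModelwithChatTemplate,
# accelerated['abbr'] becomes 'internlm2-chat-7b-turbomind', and
# engine_config/gen_config are filled with the defaults from the hunk above
# (max_batch_size taken from the original batch_size, tp from
# run_cfg['num_gpus'], top_k=1, temperature=1e-6, top_p=0.9).

This is the same rewrite get_config_from_arg applies to config['models'] (and to any judge_models under the eval partitioner) when the --accelerator flag is set.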