diff --git a/configs/eval_internlm_chat_lmdeploy_pytorch.py b/configs/eval_internlm_chat_lmdeploy_pytorch.py
new file mode 100644
index 00000000..b829e759
--- /dev/null
+++ b/configs/eval_internlm_chat_lmdeploy_pytorch.py
@@ -0,0 +1,69 @@
+from mmengine.config import read_base
+from opencompass.models import LmdeployPytorchModel
+
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.race.race_gen_69ee4f import race_datasets
+    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+
+datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
+
+
+meta_template = dict(
+    round=[
+        dict(role='HUMAN', begin='<|User|>:', end='\n'),
+        dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True),
+    ],
+    eos_token_id=103028)
+
+# config for internlm-chat-7b
+internlm_chat_7b = dict(
+    type=LmdeployPytorchModel,
+    abbr='internlm-chat-7b-pytorch',
+    path='internlm/internlm-chat-7b',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=16),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=16,
+    concurrency=16,
+    meta_template=meta_template,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    end_str='',
+)
+
+# config for internlm-chat-20b
+internlm_chat_20b = dict(
+    type=LmdeployPytorchModel,
+    abbr='internlm-chat-20b-pytorch',
+    path='internlm/internlm-chat-20b',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    concurrency=8,
+    meta_template=meta_template,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    end_str='',
+)
+
+models = [internlm_chat_20b]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 6790d652..8f3f26e5 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -14,6 +14,7 @@ from .huggingface import HuggingFaceChatGLM3  # noqa: F401, F403
 from .intern_model import InternLM  # noqa: F401, F403
 from .lightllm_api import LightllmAPI  # noqa: F401
 from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
+from .lmdeploy_pytorch import LmdeployPytorchModel  # noqa: F401
 from .minimax_api import MiniMax  # noqa: F401
 from .mixtral import Mixtral  # noqa: F401
 from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
diff --git a/opencompass/models/lmdeploy_pytorch.py b/opencompass/models/lmdeploy_pytorch.py
new file mode 100644
index 00000000..b47ab419
--- /dev/null
+++ b/opencompass/models/lmdeploy_pytorch.py
@@ -0,0 +1,157 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.models.base import BaseModel
+from opencompass.utils.logging import get_logger
+from opencompass.utils.prompt import PromptList
+
+PromptType = Union[PromptList, str]
+
+
+def valid_str(string, coding='utf-8'):
+    """Remove invalid (replacement) characters from the decoded text."""
+    invalid_chars = [b'\xef\xbf\xbd']
+    bstr = bytes(string, coding)
+    for invalid_char in invalid_chars:
+        bstr = bstr.replace(invalid_char, b'')
+    ret = bstr.decode(encoding=coding, errors='ignore')
+    return ret
+
+
+class LmdeployPytorchModel(BaseModel):
+    """Model wrapper for lmdeploy pytorch engine through python API.
+
+    Args:
+        path (str): path of the supported pytorch model.
+        max_seq_len (int): The maximum allowed sequence length of a model.
+            Note that the length of prompt + generated tokens shall not exceed
+            this value. Defaults to 2048.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        engine_config (Dict, optional): The engine config to set
+            arguments like session_len, max_batch_size for the pytorch engine.
+        gen_config (Dict, optional): Generation config to set
+            arguments like top_k, top_p, temperature.
+        end_str (str, optional): Trim generated strings at the first
+            occurrence of end_str, for models whose special ending strings
+            are not handled well. Defaults to None.
+    """
+
+    def __init__(self,
+                 path: str,
+                 concurrency: int = 8,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 engine_config: Optional[Dict] = None,
+                 gen_config: Optional[Dict] = None,
+                 end_str: Optional[str] = None):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         meta_template=meta_template)
+        from lmdeploy.pytorch import engine as tm
+
+        if engine_config is not None:
+            from lmdeploy.messages import PytorchEngineConfig
+            engine_config = PytorchEngineConfig(**engine_config)
+        if gen_config is not None:
+            from lmdeploy.messages import EngineGenerationConfig
+            gen_config = EngineGenerationConfig(**gen_config)
+
+        self.logger = get_logger()
+        tm_model = tm.Engine(path, engine_config)
+        self.tokenizer = tm_model.tokenizer
+        self.generators = [
+            tm_model.create_instance() for i in range(concurrency)
+        ]
+        self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.gen_config = gen_config
+        self.end_str = end_str
+
+    def generate(
+        self,
+        inputs: List[str],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of prompts.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        assert isinstance(
+            inputs, List), f'List(str) is expected, but got {type(inputs)}'
+
+        # split inputs into batches
+        batch_size = len(self.generators)
+        batch_inputs = [
+            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
+        ]
+
+        results = []
+        for batch_input in batch_inputs:
+            with ThreadPoolExecutor() as executor:
+                _results = list(
+                    executor.map(
+                        self._generate,
+                        self.generators[:len(batch_input)],
+                        self.generator_ids[:len(batch_input)],
+                        batch_input,
+                        [self.gen_config] * len(batch_input),
+                        [self.end_str] * len(batch_input),
+                    ))
+                results += _results
+        return results
+
+    def get_token_len(self, prompt: str) -> int:
+        input_ids = self.tokenizer.encode(prompt)
+        return len(input_ids)
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
+
+    def _generate(self,
+                  generator,
+                  session_id,
+                  prompt: PromptType,
+                  gen_config=None,
+                  end_str: Optional[str] = None) -> str:
+        """Generate results given a prompt.
+
+        Args:
+            prompt (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            gen_config (EngineGenerationConfig, optional): Generation
+                config to set arguments like top_k, top_p, temperature.
+            end_str (str, optional): Trim the generated string at the first
+                occurrence of end_str if the model has special ending
+                strings that are not handled well.
+                Defaults to None.
+        Returns:
+            str: The generated string.
+        """
+        assert type(
+            prompt) is str, 'We only support string for the pytorch engine API'
+        input_ids = self.tokenizer.encode(prompt)
+        _, output_ids, _ = generator.infer(session_id,
+                                           input_ids,
+                                           gen_config=gen_config)
+        # stop engine
+        if hasattr(generator, 'end'):
+            generator.end(session_id)
+        # decode output
+        response_all = self.tokenizer.decode(output_ids)
+        # trim output
+        if end_str:
+            response_all = response_all.split(end_str)[0]
+        # remove invalid characters
+        response_all = valid_str(response_all)
+        return response_all
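Note (not part of the diff): below is a minimal sketch of driving the new LmdeployPytorchModel wrapper directly, assuming lmdeploy is installed, a CUDA-capable GPU is available, and the internlm/internlm-chat-7b weights can be fetched. The end_str value is an assumption based on InternLM-chat's <eoa> end-of-answer marker; adjust it for other models.

# Illustrative usage sketch only; mirrors the constructor and generate() added above.
from opencompass.models import LmdeployPytorchModel

model = LmdeployPytorchModel(
    path='internlm/internlm-chat-7b',
    concurrency=4,                # engine instances mapped over a thread pool per batch
    max_seq_len=2048,
    engine_config=dict(session_len=2048, max_batch_size=4),
    gen_config=dict(top_k=1, top_p=0.8, temperature=1.0, max_new_tokens=100),
    end_str='<eoa>',              # assumed InternLM-chat end-of-answer marker
)
# Output length is governed by gen_config.max_new_tokens; the max_out_len
# argument is accepted for interface compatibility but not forwarded.
outputs = model.generate(['Briefly introduce the InternLM model.'], max_out_len=100)
print(outputs[0])

Inside OpenCompass, the config above is the intended entry point, e.g. python run.py configs/eval_internlm_chat_lmdeploy_pytorch.py; edit the models list (for instance models = [internlm_chat_7b, internlm_chat_20b]) to choose which checkpoints are evaluated.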