diff --git a/configs/eval_internlm_chat_turbomind.py b/configs/eval_internlm_chat_turbomind.py
index 34667249..0118ebc2 100644
--- a/configs/eval_internlm_chat_turbomind.py
+++ b/configs/eval_internlm_chat_turbomind.py
@@ -1,17 +1,16 @@
 from mmengine.config import read_base
 from opencompass.models.turbomind import TurboMindModel
 
-
 with read_base():
     # choose a list of datasets
     from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
-    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
-    # from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
-    # from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
-    # from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
-    # from .datasets.race.race_gen_69ee4f import race_datasets
-    # from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
+    from .datasets.race.race_gen_69ee4f import race_datasets
+    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
     # and output the results in a choosen format
     from .summarizers.medium import summarizer
 
@@ -24,56 +23,18 @@ internlm_meta_template = dict(round=[
 ],
                               eos_token_id=103028)
 
-llama2_meta_template = dict(
-    round=[
-        dict(role='HUMAN', begin='[INST] ', end=' [/INST]'),
-        dict(role='BOT', generate=True),
-    ],
-    eos_token_id=2)
-
-qwen_meta_template = dict(round=[
-    dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
-    dict(role='BOT',
-         begin='\n<|im_start|>assistant\n',
-         end='<|im_end|>',
-         generate=True)
-])
-
-baichuan2_meta_template = dict(round=[
-    dict(role='HUMAN', begin=''),
-    dict(role='BOT', begin='', generate=True)
-])
-
 # config for internlm-chat-7b
 internlm_chat_7b = dict(
     type=TurboMindModel,
     abbr='internlm-chat-7b-turbomind',
     path='internlm/internlm-chat-7b',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-internlm_chat_7b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-7b-w4-turbomind',
-    path='internlm/internlm-chat-7b-w4',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for internlm-chat-7b-w4kv8 model
-internlm_chat_7b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-7b-w4kv8-turbomind',
-    path='internlm/internlm-chat-7b-w4kv8',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=32,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=32,
@@ -87,6 +48,13 @@ internlm_chat_20b = dict(
     type=TurboMindModel,
     abbr='internlm-chat-20b-turbomind',
     path='internlm/internlm-chat-20b',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=8,
@@ -95,108 +63,4 @@ internlm_chat_20b = dict(
     run_cfg=dict(num_gpus=1, num_procs=1),
 )
 
-# config for internlm-chat-20b-w4 model
-internlm_chat_20b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-20b-w4-turbomind',
-    path='internlm/internlm-chat-20b-w4',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for internlm-chat-20b-w4kv8 model
-internlm_chat_20b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-20b-w4kv8-turbomind',
-    path='internlm/internlm-chat-20b-w4kv8',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-7b
-llama2_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-7b-turbomind',
-    path='meta-llama/Llama-2-7b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-13b
-llama2_chat_13b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-13b-turbomind',
-    path='meta-llama/Llama-2-13b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-70b
-llama2_chat_70b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-70b-turbomind',
-    path='meta-llama/Llama-2-70b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=8,
-    concurrency=8,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for qwen-chat-7b
-qwen_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='qwen-chat-7b-turbomind',
-    path='Qwen/Qwen-7B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=qwen_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for qwen-chat-7b
-qwen_chat_14b = dict(
-    type=TurboMindModel,
-    abbr='qwen-chat-14b-turbomind',
-    path='Qwen/Qwen-14B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=qwen_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for baichuan2-chat-7b
-baichuan2_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='baichuan2-chat-7b-turbomind',
-    path='baichuan-inc/Baichuan2-7B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=baichuan2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
 models = [internlm_chat_20b]
diff --git a/configs/eval_internlm_turbomind.py b/configs/eval_internlm_turbomind.py
index 8e62fa34..210d0b18 100644
--- a/configs/eval_internlm_turbomind.py
+++ b/configs/eval_internlm_turbomind.py
@@ -20,30 +20,13 @@ internlm_7b = dict(
     type=TurboMindModel,
     abbr='internlm-7b-turbomind',
     path="internlm/internlm-7b",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-
-# # config for internlm-7b-w4 model
-internlm_7b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-7b-w4-turbomind',
-    path="internlm/internlm-7b-w4",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-
-# # config for internlm-7b-w4kv8 model
-internlm_7b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-7b-w4kv8-turbomind',
-    path="internlm/internlm-7b-w4kv8",
+    engine_config=dict(session_len=2048,
+                       max_batch_size=32,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=32,
@@ -56,6 +39,12 @@ internlm_20b = dict(
     type=TurboMindModel,
     abbr='internlm-20b-turbomind',
     path="internlm/internlm-20b",
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=8,
@@ -63,29 +52,4 @@ internlm_20b = dict(
     run_cfg=dict(num_gpus=1, num_procs=1),
     )
 
-# config for internlm-20b-w4 model
-internlm_20b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-20b-w4-turbomind',
-    path="internlm/internlm-20b-w4",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-
-
-# config for internlm-20b-w4kv8 model
-internlm_20b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-20b-w4kv8-turbomind',
-    path="internlm/internlm-20b-w4kv8",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-
 models = [internlm_20b]
diff --git a/docs/en/advanced_guides/evaluation_turbomind.md b/docs/en/advanced_guides/evaluation_turbomind.md
index 0fa75fc3..6e9a5b1a 100644
--- a/docs/en/advanced_guides/evaluation_turbomind.md
+++ b/docs/en/advanced_guides/evaluation_turbomind.md
@@ -18,21 +18,50 @@ pip install lmdeploy
 
 ## Evaluation
 
-OpenCompass integrates both turbomind's python API and gRPC API for evaluation. And the former is highly recommended.
+OpenCompass integrates turbomind's python API for evaluation.
 
-We take the InternLM-20B as example. Please download it from huggingface:
+We take InternLM-20B as an example. First, we prepare the evaluation config `configs/eval_internlm_turbomind.py`:
 
-```shell
-# Download InternLM model(or use the cached model's checkpoint)
+```python
+from mmengine.config import read_base
+from opencompass.models.turbomind import TurboMindModel
 
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
-git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
+
+# config for internlm-20b model
+internlm_20b = dict(
+    type=TurboMindModel,
+    abbr='internlm-20b-turbomind',
+    path="internlm/internlm-20b",  # this path should be the same as on huggingface
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    concurrency=8,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+
+models = [internlm_20b]
 ```
 
-### Evaluation with Turbomind Python API (recommended)
-
-In the home folder of OpenCompass, start evaluation by the following command:
+Then, in the home folder of OpenCompass, start evaluation by the following command:
 
 ```shell
 python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@@ -42,42 +71,7 @@ You are expected to get the evaluation results after the inference and evaluatio
 
 **Note**:
 
+- If you want to pass more arguments for `engine_config` and `gen_config` in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#turbomindengineconfig)
+  and [EngineGenerationConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#generationconfig)
 - If you evaluate the InternLM Chat model, please use configuration file `eval_internlm_chat_turbomind.py`
 - If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by changing to the setting `models = [internlm_7b]` in the last line.
-- If you want to evaluate other chat models like Llama2, QWen-7B, Baichuan2-7B, you could change to the setting of `models` in `eval_internlm_chat_turbomind.py`.
-
-### Evaluation with Turbomind gPRC API (optional)
-
-Convert model to TurboMind format using lmdeploy
-
-```shell
-lmdeploy convert internlm /path/to/internlm-20b \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-**Note**:
-
-If evaluating the InternLM Chat model, make sure to pass `internlm-chat` as the model name instead of `internlm` when converting the model format. The specific command is:
-
-```shell
-lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-In the home folder of OpenCompass, launch the Triton Inference Server:
-
-```shell
-bash turbomind/service_docker_up.sh
-```
-
-And start evaluation by the following command:
-
-```shell
-python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
-```
-
-\*\*Note: \*\*
-
-- If the InternLM Chat model is requested to be evaluated, please use config file `eval_internlm_chat_turbomind_tis.py`
-- In `eval_internlm_turbomind_tis.py`, the configured Triton Inference Server (TIS) address is `tis_addr='0.0.0.0:33337'`. Please modify `tis_addr` to the IP address of the machine where the server is launched.
-- If evaluating the InternLM 7B model, please modify the `models` configuration in `eval_internlm_xxx_turbomind_tis.py`.
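For context on the note added above: `TurboMindModel` unpacks the `engine_config` and `gen_config` dicts from the eval config into lmdeploy dataclasses (see the `opencompass/models/turbomind.py` hunk at the end of this diff), so any extra key must be a valid field of the corresponding class. A minimal sketch, assuming lmdeploy is installed and using only the fields that already appear in this diff:

```python
# Illustrative sketch, not part of this diff: how the eval-config dicts are
# consumed. TurboMindModel forwards them to lmdeploy dataclasses via **kwargs,
# so unknown keys raise a TypeError instead of being silently ignored.
from lmdeploy.messages import EngineGenerationConfig, TurbomindEngineConfig

engine_config = dict(session_len=2048, max_batch_size=8, rope_scaling_factor=1.0)
gen_config = dict(top_k=1, top_p=0.8, temperature=1.0, max_new_tokens=100)

engine_cfg = TurbomindEngineConfig(**engine_config)
gen_cfg = EngineGenerationConfig(**gen_config)
print(engine_cfg)
print(gen_cfg)
```

Any additional field (for example `tp` for tensor parallelism, if your lmdeploy version supports it) should be checked against the linked TurbomindEngineConfig and EngineGenerationConfig documentation before being added to the config file.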
diff --git a/docs/zh_cn/advanced_guides/evaluation_turbomind.md b/docs/zh_cn/advanced_guides/evaluation_turbomind.md
index 240b2f38..be86e172 100644
--- a/docs/zh_cn/advanced_guides/evaluation_turbomind.md
+++ b/docs/zh_cn/advanced_guides/evaluation_turbomind.md
@@ -18,21 +18,50 @@ pip install lmdeploy
 
 ## 评测
 
-OpenCompass 支持分别通过 turbomind python API 和 gRPC API 评测数据集。我们强烈推荐使用前者进行评测。
+OpenCompass 支持通过 turbomind python API 评测数据集。
 
-下文以 InternLM-20B 模型为例,介绍如何评测。首先,从 huggingface 上下载 InternLM 模型:
+下文以 InternLM-20B 模型为例,介绍如何评测。首先,我们准备好评测配置文件 `configs/eval_internlm_turbomind.py`:
 
-```shell
-Download InternLM model(or use the cached model's checkpoint)
+```python
+from mmengine.config import read_base
+from opencompass.models.turbomind import TurboMindModel
 
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
-git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
+
+# config for internlm-20b model
+internlm_20b = dict(
+    type=TurboMindModel,
+    abbr='internlm-20b-turbomind',
+    path="internlm/internlm-20b",  # 注意路径与 huggingface 保持一致
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    concurrency=8,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+
+models = [internlm_20b]
 ```
 
-### 通过 TurboMind Python API 评测(推荐)
-
-在 OpenCompass 的项目目录下,执行如下命令可得到评测结果:
+然后,在 OpenCompass 的项目目录下,执行如下命令可得到评测结果:
 
 ```shell
 python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@@ -40,41 +69,6 @@ python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-2
 
 **注:**
 
+- 如果想在评测配置文件中通过 `engine_config` 和 `gen_config` 字段传递更多参数,请参考 [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#turbomindengineconfig) 和 [EngineGenerationConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#generationconfig)
 - 如果评测 InternLM Chat 模型,请使用配置文件 `eval_internlm_chat_turbomind.py`
 - 如果评测 InternLM 7B 模型,请修改 `eval_internlm_turbomind.py` 或者 `eval_internlm_chat_turbomind.py`。将`models`字段配置为`models = [internlm_7b]` 。
-- 如果评测其他模型如 Llama2, QWen-7B, Baichuan2-7B, 请修改`eval_internlm_chat_turbomind.py`中`models`字段 。
-
-### 通过 TurboMind gPRC API 评测(可选)
-
-首先需要将模型转换为 turbomind 格式
-
-```shell script
-lmdeploy convert internlm /path/to/internlm-20b \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-注意:如果评测 InternLM Chat 模型,那么在转换模型格式的时候,模型名字要填写 `internlm-chat`。具体命令是:
-
-```shell
-lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-在 OpenCompass 的项目目录下,启动 triton inference server:
-
-```shell
-bash turbomind/service_docker_up.sh
-```
-
-然后,执行如下命令进行评测:
-
-```shell
-python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
-``
-
-**注:**
-
-- 如果评测 InternLM Chat 模型,请使用配置文件 `eval_internlm_chat_turbomind_tis.py`
-- 在配置文件中,triton inference server(TIS) 地址是 `tis_addr='0.0.0.0:33337'`。请把配置中的`tis_addr`修改为server所在机器的ip地址。
-- 如果评测 InternLM 7B 模型,请修改 `eval_internlm_xxx_turbomind_tis.py`中`models`字段。
-```
diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py
index f435fe86..1a17a7ba 100644
--- a/opencompass/models/turbomind.py
+++ b/opencompass/models/turbomind.py
@@ -30,43 +30,49 @@ class TurboMindModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt template
             if needed, in case the requirement of injecting or wrapping of
             any meta instructions.
+        engine_config (Dict, optional): The engine config to set
+            arguments like session_len, max_batch_size for TurboMind.
+        gen_config (Dict, optional): Generation config to set
+            arguments like top_k, top_p, temperature.
     """
 
-    def __init__(
-        self,
-        path: str,
-        concurrency: int = 8,
-        max_seq_len: int = 2048,
-        meta_template: Optional[Dict] = None,
-    ):
-        from lmdeploy.turbomind import TurboMind
-
+    def __init__(self,
+                 path: str,
+                 concurrency: int = 8,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 engine_config: Optional[Dict] = None,
+                 gen_config: Optional[Dict] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
+        from lmdeploy.turbomind import TurboMind
+
+        if engine_config is not None:
+            from lmdeploy.messages import TurbomindEngineConfig
+            engine_config = TurbomindEngineConfig(**engine_config)
+        if gen_config is not None:
+            from lmdeploy.messages import EngineGenerationConfig
+            gen_config = EngineGenerationConfig(**gen_config)
         self.logger = get_logger()
-        tm_model = TurboMind.from_pretrained(path)
+        tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
         self.tokenizer = tm_model.tokenizer
         self.generators = [
             tm_model.create_instance() for i in range(concurrency)
         ]
         self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.gen_config = gen_config
 
     def generate(
         self,
         inputs: List[str],
         max_out_len: int = 512,
-        temperature: float = 1.0,
     ) -> List[str]:
         """Generate results given a list of inputs.
 
         Args:
             inputs (List[str]): A list of prompts
             max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic. Defaults to 1.0.
 
         Returns:
             List[str]: A list of generated strings.
@@ -88,7 +94,7 @@ class TurboMindModel(BaseModel):
                              self.generators[:len(batch_input)],
                              self.generator_ids[:len(batch_input)],
                              batch_input, [max_out_len] * len(batch_input),
-                             [temperature] * len(batch_input)))
+                             [self.gen_config] * len(batch_input)))
             results += _results
         return results
 
@@ -103,8 +109,12 @@ class TurboMindModel(BaseModel):
         """
        return self.token_bucket.get_token()
 
-    def _generate(self, generator, session_id, prompt: str or PromptList,
-                  max_out_len: int, temperature: float) -> str:
+    def _generate(self,
+                  generator,
+                  session_id,
+                  prompt: str or PromptList,
+                  max_out_len: int,
+                  gen_config=None) -> str:
         """Generate results given a list of inputs.
 
         Args:
@@ -112,10 +122,8 @@ class TurboMindModel(BaseModel):
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic.
+            gen_config (EngineGenerationConfig, optional): Generation
+                config to set arguments like top_k, top_p, temperature.
 
         Returns:
             str: The generated string.
@@ -127,11 +135,10 @@ class TurboMindModel(BaseModel):
         for outputs in generator.stream_infer(session_id=session_id,
                                               input_ids=[input_ids],
+                                              gen_config=gen_config,
                                               request_output_len=max_out_len,
                                               sequence_start=True,
                                               sequence_end=True,
-                                              top_k=1,
-                                              top_p=0.8,
                                               step=0,
                                               stream_output=False):
             _, output_ids, _ = outputs
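Taken together, the `opencompass/models/turbomind.py` changes mean a single `gen_config` now drives decoding instead of the old per-call `temperature`/`top_k`/`top_p` arguments. The following self-contained sketch (not part of the diff) strings together the same lmdeploy calls the updated wrapper uses, namely `TurboMind.from_pretrained(..., engine_config=...)`, `create_instance()`, and `stream_infer(..., gen_config=...)`; the model path and prompt are placeholders, and the final `decode` call assumes the tokenizer API used elsewhere in this file:

```python
# Standalone sanity-check sketch mirroring the updated TurboMindModel internals.
# Assumes lmdeploy is installed and the model weights are reachable.
from lmdeploy.messages import EngineGenerationConfig, TurbomindEngineConfig
from lmdeploy.turbomind import TurboMind

engine_cfg = TurbomindEngineConfig(session_len=2048,
                                   max_batch_size=8,
                                   rope_scaling_factor=1.0)
gen_cfg = EngineGenerationConfig(top_k=1, top_p=0.8, temperature=1.0,
                                 max_new_tokens=100)

# Build the engine once, then create a generator instance (one per concurrent slot).
tm_model = TurboMind.from_pretrained('internlm/internlm-20b',  # placeholder path
                                     engine_config=engine_cfg)
generator = tm_model.create_instance()

input_ids = tm_model.tokenizer.encode('The capital of France is')  # placeholder prompt
for outputs in generator.stream_infer(session_id=1,
                                      input_ids=[input_ids],
                                      gen_config=gen_cfg,
                                      request_output_len=32,
                                      sequence_start=True,
                                      sequence_end=True,
                                      step=0,
                                      stream_output=False):
    _, output_ids, _ = outputs  # same unpacking as in _generate

print(tm_model.tokenizer.decode(output_ids))
```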