[Fix] Fix turbomind and update docs (#808)

* update

* update docs

* add engine_config and gen_config in eval_config

* update

* fix

* fix

* fix

* fix docstr

* fix url
RunningLeon 2024-01-18 14:41:35 +08:00 committed by GitHub
parent 9e5746d3d8
commit 61fe873c89
5 changed files with 145 additions and 322 deletions

View File

@ -1,17 +1,16 @@
from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
# from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
# from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
# from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
# from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
# from .datasets.race.race_gen_69ee4f import race_datasets
# from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
from .datasets.race.race_gen_69ee4f import race_datasets
from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
@ -24,56 +23,18 @@ internlm_meta_template = dict(round=[
],
eos_token_id=103028)
llama2_meta_template = dict(
round=[
dict(role='HUMAN', begin='[INST] ', end=' [/INST]'),
dict(role='BOT', generate=True),
],
eos_token_id=2)
qwen_meta_template = dict(round=[
dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
dict(role='BOT',
begin='\n<|im_start|>assistant\n',
end='<|im_end|>',
generate=True)
])
baichuan2_meta_template = dict(round=[
dict(role='HUMAN', begin='<reserved_106>'),
dict(role='BOT', begin='<reserved_107>', generate=True)
])
# config for internlm-chat-7b
internlm_chat_7b = dict(
type=TurboMindModel,
abbr='internlm-chat-7b-turbomind',
path='internlm/internlm-chat-7b',
max_out_len=100,
max_seq_len=2048,
batch_size=32,
concurrency=32,
meta_template=internlm_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
internlm_chat_7b_w4 = dict(
type=TurboMindModel,
abbr='internlm-chat-7b-w4-turbomind',
path='internlm/internlm-chat-7b-w4',
max_out_len=100,
max_seq_len=2048,
batch_size=32,
concurrency=32,
meta_template=internlm_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-chat-7b-w4kv8 model
internlm_chat_7b_w4kv8 = dict(
type=TurboMindModel,
abbr='internlm-chat-7b-w4kv8-turbomind',
path='internlm/internlm-chat-7b-w4kv8',
engine_config=dict(session_len=2048,
max_batch_size=32,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=32,
@ -87,6 +48,13 @@ internlm_chat_20b = dict(
type=TurboMindModel,
abbr='internlm-chat-20b-turbomind',
path='internlm/internlm-chat-20b',
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
@ -95,108 +63,4 @@ internlm_chat_20b = dict(
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-chat-20b-w4 model
internlm_chat_20b_w4 = dict(
type=TurboMindModel,
abbr='internlm-chat-20b-w4-turbomind',
path='internlm/internlm-chat-20b-w4',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
meta_template=internlm_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-chat-20b-w4kv8 model
internlm_chat_20b_w4kv8 = dict(
type=TurboMindModel,
abbr='internlm-chat-20b-w4kv8-turbomind',
path='internlm/internlm-chat-20b-w4kv8',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
meta_template=internlm_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for llama2-chat-7b
llama2_chat_7b = dict(
type=TurboMindModel,
abbr='llama2-chat-7b-turbomind',
path='meta-llama/Llama-2-7b-chat-hf',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=32,
meta_template=llama2_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for llama2-chat-13b
llama2_chat_13b = dict(
type=TurboMindModel,
abbr='llama2-chat-13b-turbomind',
path='meta-llama/Llama-2-13b-chat-hf',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
meta_template=llama2_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for llama2-chat-70b
llama2_chat_70b = dict(
type=TurboMindModel,
abbr='llama2-chat-70b-turbomind',
path='meta-llama/Llama-2-70b-chat-hf',
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
meta_template=llama2_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for qwen-chat-7b
qwen_chat_7b = dict(
type=TurboMindModel,
abbr='qwen-chat-7b-turbomind',
path='Qwen/Qwen-7B-Chat',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=32,
meta_template=qwen_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for qwen-chat-14b
qwen_chat_14b = dict(
type=TurboMindModel,
abbr='qwen-chat-14b-turbomind',
path='Qwen/Qwen-14B-Chat',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=32,
meta_template=qwen_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for baichuan2-chat-7b
baichuan2_chat_7b = dict(
type=TurboMindModel,
abbr='baichuan2-chat-7b-turbomind',
path='baichuan-inc/Baichuan2-7B-Chat',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=32,
meta_template=baichuan2_meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
models = [internlm_chat_20b]

View File

@ -20,30 +20,13 @@ internlm_7b = dict(
type=TurboMindModel,
abbr='internlm-7b-turbomind',
path="internlm/internlm-7b",
max_out_len=100,
max_seq_len=2048,
batch_size=32,
concurrency=32,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-7b-w4 model
internlm_7b_w4 = dict(
type=TurboMindModel,
abbr='internlm-7b-w4-turbomind',
path="internlm/internlm-7b-w4",
max_out_len=100,
max_seq_len=2048,
batch_size=32,
concurrency=32,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-7b-w4kv8 model
internlm_7b_w4kv8 = dict(
type=TurboMindModel,
abbr='internlm-7b-w4kv8-turbomind',
path="internlm/internlm-7b-w4kv8",
engine_config=dict(session_len=2048,
max_batch_size=32,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=32,
@ -56,6 +39,12 @@ internlm_20b = dict(
type=TurboMindModel,
abbr='internlm-20b-turbomind',
path="internlm/internlm-20b",
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1, top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
@ -63,29 +52,4 @@ internlm_20b = dict(
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-20b-w4 model
internlm_20b_w4 = dict(
type=TurboMindModel,
abbr='internlm-20b-w4-turbomind',
path="internlm/internlm-20b-w4",
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
run_cfg=dict(num_gpus=1, num_procs=1),
)
# config for internlm-20b-w4kv8 model
internlm_20b_w4kv8 = dict(
type=TurboMindModel,
abbr='internlm-20b-w4kv8-turbomind',
path="internlm/internlm-20b-w4kv8",
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
run_cfg=dict(num_gpus=1, num_procs=1),
)
models = [internlm_20b]

View File

@ -18,21 +18,50 @@ pip install lmdeploy
## Evaluation
OpenCompass integrates both turbomind's Python API and gRPC API for evaluation, and the former is highly recommended.
OpenCompass integrates turbomind's Python API for evaluation.
We take the InternLM-20B as an example. Please download it from huggingface:
We take the InternLM-20B as an example. First, we prepare the evaluation config `configs/eval_internlm_turbomind.py`:
```shell
# Download the InternLM model (or use the cached model's checkpoint)
```python
from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# config for internlm-20b model
internlm_20b = dict(
type=TurboMindModel,
abbr='internlm-20b-turbomind',
path="internlm/internlm-20b", # this path should be same as in huggingface
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1, top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
run_cfg=dict(num_gpus=1, num_procs=1),
)
models = [internlm_20b]
```
### Evaluation with Turbomind Python API (recommended)
In the home folder of OpenCompass, start the evaluation with the following command:
Then, in the home folder of OpenCompass, start the evaluation with the following command:
```shell
python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@ -42,42 +71,7 @@ You are expected to get the evaluation results after the inference and evaluation.
**Note**:
- If you want to pass more arguments to `engine_config` and `gen_config` in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#turbomindengineconfig)
and [EngineGenerationConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#generationconfig); a hedged sketch follows after these notes
- If you evaluate the InternLM Chat model, please use the configuration file `eval_internlm_chat_turbomind.py`
- If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by changing the last line to `models = [internlm_7b]`
- If you want to evaluate other chat models such as Llama2, QWen-7B, or Baichuan2-7B, modify the `models` setting in `eval_internlm_chat_turbomind.py`
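For example, here is a hedged sketch of a modified `internlm_20b` entry that passes extra fields. The `tp` field, `num_gpus=2`, and the sampling values are illustrative assumptions only; confirm the exact field names and semantics against the lmdeploy documents linked above.

```python
# Illustrative sketch only -- not part of this commit.
# Fields not shown elsewhere on this page (e.g. `tp`) are assumptions
# to be verified against the lmdeploy docs linked above.
from opencompass.models.turbomind import TurboMindModel

internlm_20b = dict(
    type=TurboMindModel,
    abbr='internlm-20b-turbomind',
    path='internlm/internlm-20b',
    engine_config=dict(session_len=2048,
                       max_batch_size=8,
                       rope_scaling_factor=1.0,
                       tp=2),               # assumed: tensor-parallel degree
    gen_config=dict(top_k=40,               # sample instead of greedy decoding
                    top_p=0.8,
                    temperature=0.7,
                    max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=8,
    concurrency=8,
    run_cfg=dict(num_gpus=2, num_procs=1),  # kept in line with the assumed tp=2
)
```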
### Evaluation with Turbomind gRPC API (optional)
Convert the model to TurboMind format using lmdeploy:
```shell
lmdeploy convert internlm /path/to/internlm-20b \
--dst-path {/home/folder/of/opencompass}/turbomind
```
**Note**:
If evaluating the InternLM Chat model, make sure to pass `internlm-chat` as the model name instead of `internlm` when converting the model format. The specific command is:
```shell
lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
--dst-path {/home/folder/of/opencompass}/turbomind
```
In the home folder of OpenCompass, launch the Triton Inference Server:
```shell
bash turbomind/service_docker_up.sh
```
Then start the evaluation with the following command:
```shell
python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
```
**Note:**
- If you evaluate the InternLM Chat model, please use the config file `eval_internlm_chat_turbomind_tis.py`
- In `eval_internlm_turbomind_tis.py`, the configured Triton Inference Server (TIS) address is `tis_addr='0.0.0.0:33337'`. Please modify `tis_addr` to the IP address of the machine where the server is launched.
- If evaluating the InternLM 7B model, please modify the `models` configuration in `eval_internlm_xxx_turbomind_tis.py`.

View File

@ -18,21 +18,50 @@ pip install lmdeploy
## Evaluation
OpenCompass supports evaluating datasets through both the turbomind Python API and the gRPC API. The former is strongly recommended.
OpenCompass supports evaluating datasets through the turbomind Python API.
The following takes the InternLM-20B model as an example to show how to run the evaluation. First, download the InternLM model from huggingface:
The following takes the InternLM-20B model as an example to show how to run the evaluation. First, we prepare the evaluation config `configs/eval_internlm_turbomind.py`:
```shell
# Download the InternLM model (or use the cached model's checkpoint)
```python
from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# config for internlm-20b model
internlm_20b = dict(
type=TurboMindModel,
abbr='internlm-20b-turbomind',
path="internlm/internlm-20b", # 注意路径与huggingface保持一致
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1, top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
run_cfg=dict(num_gpus=1, num_procs=1),
)
models = [internlm_20b]
```
### Evaluation with the TurboMind Python API (recommended)
In the OpenCompass project directory, run the following command to get the evaluation results:
Then, in the OpenCompass project directory, run the following command to get the evaluation results:
```shell
python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@ -40,41 +69,6 @@ python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
**Note:**
- If you want to pass more arguments to the `engine_config` and `gen_config` fields in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#turbomindengineconfig) and [EngineGenerationConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#generationconfig)
- If you evaluate the InternLM Chat model, please use the config file `eval_internlm_chat_turbomind.py`
- If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by setting the `models` field to `models = [internlm_7b]` (see the sketch after these notes)
- If you evaluate other models such as Llama2, QWen-7B, or Baichuan2-7B, please modify the `models` field in `eval_internlm_chat_turbomind.py`
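A minimal sketch of the 7B switch described in the notes above, reusing the `internlm_7b` values from `configs/eval_internlm_turbomind.py` as changed by this commit:

```python
# Sketch only: evaluate internlm-7b instead of internlm-20b.
# All field values below come from the config files in this commit.
from opencompass.models.turbomind import TurboMindModel

internlm_7b = dict(
    type=TurboMindModel,
    abbr='internlm-7b-turbomind',
    path='internlm/internlm-7b',
    engine_config=dict(session_len=2048,
                       max_batch_size=32,
                       rope_scaling_factor=1.0),
    gen_config=dict(top_k=1,
                    top_p=0.8,
                    temperature=1.0,
                    max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=32,
    concurrency=32,
    run_cfg=dict(num_gpus=1, num_procs=1),
)

models = [internlm_7b]
```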
### Evaluation with the TurboMind gRPC API (optional)
First, convert the model to the turbomind format:
```shell
lmdeploy convert internlm /path/to/internlm-20b \
--dst-path {/home/folder/of/opencompass}/turbomind
```
Note: if you evaluate the InternLM Chat model, pass `internlm-chat` as the model name when converting the model format. The specific command is:
```shell
lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
--dst-path {/home/folder/of/opencompass}/turbomind
```
In the OpenCompass project directory, launch the Triton Inference Server:
```shell
bash turbomind/service_docker_up.sh
```
Then, run the following command to start the evaluation:
```shell
python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
```
**Note:**
- If you evaluate the InternLM Chat model, please use the config file `eval_internlm_chat_turbomind_tis.py`
- In the config file, the Triton Inference Server (TIS) address is `tis_addr='0.0.0.0:33337'`. Please change `tis_addr` to the IP address of the machine where the server is launched.
- If you evaluate the InternLM 7B model, please modify the `models` field in `eval_internlm_xxx_turbomind_tis.py`.

View File

@ -30,43 +30,49 @@ class TurboMindModel(BaseModel):
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
engine_config (Dict, optional): The engine config to set
arguments like session_len, max_batch_size for TurboMind.
gen_config (Dict, optional): Generation config to set
arguments like top_k, top_p, temperature.
"""
def __init__(
self,
def __init__(self,
path: str,
concurrency: int = 8,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
):
from lmdeploy.turbomind import TurboMind
engine_config: Optional[Dict] = None,
gen_config: Optional[Dict] = None):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.turbomind import TurboMind
if engine_config is not None:
from lmdeploy.messages import TurbomindEngineConfig
engine_config = TurbomindEngineConfig(**engine_config)
if gen_config is not None:
from lmdeploy.messages import EngineGenerationConfig
gen_config = EngineGenerationConfig(**gen_config)
self.logger = get_logger()
tm_model = TurboMind.from_pretrained(path)
tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
self.tokenizer = tm_model.tokenizer
self.generators = [
tm_model.create_instance() for i in range(concurrency)
]
self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config
def generate(
self,
inputs: List[str],
max_out_len: int = 512,
temperature: float = 1.0,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of prompts
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 1.0.
Returns:
List[str]: A list of generated strings.
@ -88,7 +94,7 @@ class TurboMindModel(BaseModel):
self.generators[:len(batch_input)],
self.generator_ids[:len(batch_input)],
batch_input, [max_out_len] * len(batch_input),
[temperature] * len(batch_input)))
[self.gen_config] * len(batch_input)))
results += _results
return results
@ -103,8 +109,12 @@ class TurboMindModel(BaseModel):
"""
return self.token_bucket.get_token()
def _generate(self, generator, session_id, prompt: str or PromptList,
max_out_len: int, temperature: float) -> str:
def _generate(self,
generator,
session_id,
prompt: str or PromptList,
max_out_len: int,
gen_config=None) -> str:
"""Generate results given a list of inputs.
Args:
@ -112,10 +122,8 @@ class TurboMindModel(BaseModel):
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
gen_config (EngineGenerationConfig, optional): Generation
config to set arguments like top_k, top_p, temperature.
Returns:
str: The generated string.
@ -127,11 +135,10 @@ class TurboMindModel(BaseModel):
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[input_ids],
gen_config=gen_config,
request_output_len=max_out_len,
sequence_start=True,
sequence_end=True,
top_k=1,
top_p=0.8,
step=0,
stream_output=False):
_, output_ids, _ = outputs
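To tie the class change back to the configs above, here is a hedged sketch (not part of the commit) of constructing the updated `TurboMindModel` directly. It assumes lmdeploy is installed and that the internlm-20b weights are reachable via `path`; the constructor converts the dict-style `engine_config`/`gen_config` into lmdeploy's `TurbomindEngineConfig` and `EngineGenerationConfig` as shown in the diff above.

```python
# Sketch only: direct construction of the updated TurboMindModel.
from opencompass.models.turbomind import TurboMindModel

model = TurboMindModel(
    path='internlm/internlm-20b',          # assumed to be downloadable/cached
    concurrency=8,
    max_seq_len=2048,
    engine_config=dict(session_len=2048,
                       max_batch_size=8,
                       rope_scaling_factor=1.0),
    gen_config=dict(top_k=1, top_p=0.8, temperature=1.0, max_new_tokens=100),
)

# generate() splits the prompts into batches of at most `concurrency` items,
# runs each batch on its own TurboMind generator instance, and forwards
# self.gen_config to stream_infer(), as shown in the diff above.
outputs = model.generate(['The capital of France is'], max_out_len=32)
print(outputs)
```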