[Feature] Integrate lmdeploy pipeline api (#1198)

* integrate lmdeploy's pipeline api

* fix linting

* update user guide

* rename

* update

* update

* update

* rollback class name

* update

* remove unused code

* update

* update

* fix ci check

* compatibility

* remove concurrency

* Update configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py

* Update docs/zh_cn/advanced_guides/evaluation_lmdeploy.md

* [Bug] fix lint

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
Co-authored-by: tonysy <sy.zhangbuaa@gmail.com>
Lyu Han 2024-10-09 22:58:06 +08:00 committed by GitHub
parent d2ab51abbd
commit b52ba65c26
16 changed files with 249 additions and 955 deletions

@ -1,69 +0,0 @@
from mmengine.config import read_base
from opencompass.models import LmdeployPytorchModel
with read_base():
# choose a list of datasets
from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets
from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
# and output the results in a chosen format
from opencompass.configs.summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
meta_template = dict(
round=[
dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
],
eos_token_id=103028)
# config for internlm-chat-7b
internlm_chat_7b = dict(
type=LmdeployPytorchModel,
abbr='internlm-chat-7b-pytorch',
path='internlm/internlm-chat-7b',
engine_config=dict(session_len=2048,
max_batch_size=16),
gen_config=dict(top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=16,
concurrency=16,
meta_template=meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
end_str='<eoa>',
)
# config for internlm-chat-20b
internlm_chat_20b = dict(
type=LmdeployPytorchModel,
abbr='internlm-chat-20b-pytorch',
path='internlm/internlm-chat-20b',
engine_config=dict(session_len=2048,
max_batch_size=8),
gen_config=dict(top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
meta_template=meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
end_str='<eoa>',
)
models = [internlm_chat_20b]

@ -1,41 +0,0 @@
from mmengine.config import read_base
from opencompass.models.lmdeploy_tis import LmdeployTisModel
with read_base():
# choose a list of datasets
from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets
from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
# and output the results in a chosen format
from opencompass.configs.summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
meta_template = dict(
round=[
dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
],
eos_token_id=92542
)
models = [
dict(
type=LmdeployTisModel,
abbr='internlm-chat-20b-lmdeploy-tis',
path='internlm/internlm-chat-20b',
tis_addr='0.0.0.0:33337',
max_out_len=100,
max_seq_len=2048,
batch_size=8,
meta_template=meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
end_str='<|im_end|>',
)
]

@ -1,40 +0,0 @@
from mmengine.config import read_base
from opencompass.models.turbomind_tis import TurboMindTisModel
with read_base():
# choose a list of datasets
from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets
from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
# and output the results in a chosen format
from opencompass.configs.summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
meta_template = dict(
round=[
dict(role='HUMAN', begin='<|User|>:', end='\n'),
dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
],
eos_token_id=103028)
models = [
dict(
type=TurboMindTisModel,
abbr='internlm-chat-20b-turbomind',
path='internlm',
tis_addr='0.0.0.0:33337',
max_out_len=100,
max_seq_len=2048,
batch_size=8,
meta_template=meta_template,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]

@ -1,28 +0,0 @@
from mmengine.config import read_base
from opencompass.models.turbomind_tis import TurboMindTisModel
with read_base():
# choose a list of datasets
from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
# and output the results in a chosen format
from opencompass.configs.summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
models = [
dict(
type=TurboMindTisModel,
abbr='internlm-chat-20b-turbomind',
path='internlm',
tis_addr='0.0.0.0:33337',
max_out_len=100,
max_seq_len=2048,
batch_size=8,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]

@ -1,15 +1,24 @@
from opencompass.models import TurboMindModelwithChatTemplate
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='internlm2-chat-7b-turbomind',
abbr=f'internlm2-chat-7b-lmdeploy',
path='internlm/internlm2-chat-7b',
engine_config=dict(session_len=8192, max_batch_size=16, tp=1),
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=4096),
# inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
# If the model is not supported by 'turbomind', it will fallback to
# 'pytorch'
backend='turbomind',
# For the detailed engine config and generation config, please refer to
# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
engine_config=dict(tp=1),
gen_config=dict(do_sample=False),
max_seq_len=8192,
max_out_len=4096,
batch_size=16,
# the max number of prompts that LMDeploy receives
# in `generate` function
batch_size=5000,
run_cfg=dict(num_gpus=1),
)
]
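For comparison, a hypothetical variant of this entry that uses LMDeploy's PyTorch engine instead of TurboMind might look like the sketch below. It is not part of this diff and the `abbr` is made up; engine options are filtered against `PytorchEngineConfig`, so unsupported keys are simply dropped.
```python
from opencompass.models import TurboMindModelwithChatTemplate

models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='internlm2-chat-7b-lmdeploy-pytorch',  # hypothetical abbr
        path='internlm/internlm2-chat-7b',
        # switch the LMDeploy backend; keys in engine_config that
        # PytorchEngineConfig does not define are filtered out
        backend='pytorch',
        engine_config=dict(tp=1),
        gen_config=dict(do_sample=False),
        max_seq_len=8192,
        max_out_len=4096,
        batch_size=5000,
        run_cfg=dict(num_gpus=1),
    )
]
```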

@ -0,0 +1,88 @@
# Evaluation with LMDeploy
We now support evaluating models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit for compressing, deploying, and serving LLMs, and it delivers remarkable inference performance. This guide illustrates how to evaluate a model with LMDeploy in OpenCompass.
## Setup
### Install OpenCompass
Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install OpenCompass and prepare the evaluation datasets.
### Install LMDeploy
Install LMDeploy via pip (Python 3.8+):
```shell
pip install lmdeploy
```
The default prebuilt package is compiled against CUDA 12. If you need to run on CUDA 11+, install LMDeploy with:
```shell
export LMDEPLOY_VERSION=0.6.0
export PYTHON_VERSION=310
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
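Optionally, a quick sanity check (a minimal sketch, assuming the wheel installed without errors) confirms that the package imports and reports the expected version:
```python
# Verify the LMDeploy installation and print its version.
import lmdeploy

print(lmdeploy.__version__)
```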
## Evaluation
When evaluating a model, it is necessary to prepare an evaluation configuration that specifies information such as the evaluation dataset, the model, and inference parameters.
Taking [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) as an example, the evaluation config is as follows:
```python
# configure the dataset
from mmengine.config import read_base
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \
gsm8k_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# configure lmdeploy
from opencompass.models import TurboMindModelwithChatTemplate
# configure the model
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr=f'internlm2-chat-7b-lmdeploy',
# model path, which can be the address of a model repository on the Hugging Face Hub or a local path
path='internlm/internlm2-chat-7b',
# inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
# If the model is not supported by 'turbomind', it will fallback to
# 'pytorch'
backend='turbomind',
# For the detailed engine config and generation config, please refer to
# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
engine_config=dict(tp=1),
gen_config=dict(do_sample=False),
# the max size of the context window
max_seq_len=7168,
# the max number of new tokens
max_out_len=1024,
# the max number of prompts that LMDeploy receives
# in `generate` function
batch_size=5000,
run_cfg=dict(num_gpus=1),
)
]
```
Place the above configuration in a file, such as "configs/eval_internlm2_lmdeploy.py". Then, from the root directory of OpenCompass, start the evaluation with the following command:
```shell
python run.py configs/eval_internlm2_lmdeploy.py -w outputs
```
You will get the evaluation results once inference and evaluation are complete.
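If the results look off (for example, empty or truncated predictions), it can help to smoke-test the model with LMDeploy's pipeline API directly, outside OpenCompass. Below is a minimal sketch, assuming lmdeploy >= 0.6.0 and the same model path as in the config above:
```python
# Standalone check of the LMDeploy pipeline that OpenCompass drives internally:
# build a TurboMind-backed pipeline and run a single prompt.
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

pipe = pipeline('internlm/internlm2-chat-7b',
                backend_config=TurbomindEngineConfig(tp=1))
gen_config = GenerationConfig(max_new_tokens=128, do_sample=False)
responses = pipe(['Briefly introduce yourself.'], gen_config=gen_config)
print(responses[0].text)
```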

@ -1,78 +0,0 @@
# Evaluation with LMDeploy
We now support evaluation of models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit designed for compressing, deploying, and serving LLMs. **TurboMind** is an efficient inference engine proposed by LMDeploy. OpenCompass is compatible with TurboMind. We now illustrate how to evaluate a model with the support of TurboMind in OpenCompass.
## Setup
### Install OpenCompass
Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install OpenCompass and prepare the evaluation datasets.
### Install LMDeploy
Install lmdeploy via pip (python 3.8+)
```shell
pip install lmdeploy
```
## Evaluation
OpenCompass integrates TurboMind's Python API for evaluation.
We take InternLM-20B as an example. First, we prepare the evaluation config `configs/eval_internlm_turbomind.py`:
```python
from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# config for internlm-20b model
internlm_20b = dict(
type=TurboMindModel,
abbr='internlm-20b-turbomind',
path="internlm/internlm-20b", # this path should be same as in huggingface
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1, top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
run_cfg=dict(num_gpus=1, num_procs=1),
end_str='<eoa>'
)
models = [internlm_20b]
```
Then, in the home folder of OpenCompass, start evaluation by the following command:
```shell
python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
```
You are expected to get the evaluation results after the inference and evaluation.
**Note**:
- If you want to pass more arguments to `engine_config` and `gen_config` in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#turbomindengineconfig)
and [GenerationConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#generationconfig)
- If you evaluate the InternLM Chat model, please use configuration file `eval_internlm_chat_turbomind.py`
- If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by setting `models = [internlm_7b]` in the last line.

@ -0,0 +1,86 @@
# Accelerating Evaluation with LMDeploy
We support using [LMDeploy](https://github.com/InternLM/lmdeploy) as an inference acceleration engine when evaluating large language models. LMDeploy is a full-stack toolkit for compressing, deploying, and serving LLMs and VLMs, with excellent inference performance. This tutorial describes how to accelerate model evaluation with LMDeploy.
## Setup
### Install OpenCompass
Please follow the OpenCompass [installation guide](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the library and prepare the datasets.
### Install LMDeploy
Install LMDeploy via pip (Python 3.8+):
```shell
pip install lmdeploy
```
The prebuilt LMDeploy packages are compiled against CUDA 12 by default. To install LMDeploy under CUDA 11+, run the following commands:
```shell
export LMDEPLOY_VERSION=0.6.0
export PYTHON_VERSION=310
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
## Evaluation
When evaluating a model, you need to prepare an evaluation config that specifies the datasets, the model, and the inference parameters.
Taking [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) as an example, the config is as follows:
```python
# configure the dataset
from mmengine.config import read_base
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \
gsm8k_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# configure lmdeploy
from opencompass.models import TurboMindModelwithChatTemplate
# configure the model
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr=f'internlm2-chat-7b-lmdeploy',
# model path, which can be the address of a model repository on the Hugging Face Hub or a local path
path='internlm/internlm2-chat-7b',
# inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
# If the model is not supported by 'turbomind', it will fallback to
# 'pytorch'
backend='turbomind',
# For the detailed engine config and generation config, please refer to
# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
engine_config=dict(tp=1),
gen_config=dict(do_sample=False),
# the max size of the context window
max_seq_len=7168,
# the max number of new tokens
max_out_len=1024,
# the max number of prompts that LMDeploy receives
# in `generate` function
batch_size=32,
run_cfg=dict(num_gpus=1),
)
]
```
Put the above configuration in a file, such as "configs/eval_internlm2_lmdeploy.py". Then, from the root directory of OpenCompass, run the following command to get the evaluation results:
```shell
python run.py configs/eval_internlm2_lmdeploy.py -w outputs
```

@ -1,75 +0,0 @@
# Evaluating LMDeploy Models
We support evaluating large language models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy, jointly developed by the MMDeploy and MMRazor teams, is a full-stack toolkit for compressing, deploying, and serving LLMs. **TurboMind** is the efficient inference engine provided by LMDeploy. OpenCompass is compatible with TurboMind; this tutorial describes how to use OpenCompass to evaluate models accelerated by TurboMind.
## Setup
### Install OpenCompass
Please follow the OpenCompass [installation guide](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the library and prepare the datasets.
### Install LMDeploy
Install LMDeploy via pip (Python 3.8+):
```shell
pip install lmdeploy
```
## Evaluation
OpenCompass supports evaluating datasets through TurboMind's Python API.
The following takes InternLM-20B as an example. First, prepare the evaluation config `configs/eval_internlm_turbomind.py`:
```python
from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel
with read_base():
# choose a list of datasets
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
# and output the results in a chosen format
from .summarizers.medium import summarizer
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# config for internlm-20b model
internlm_20b = dict(
type=TurboMindModel,
abbr='internlm-20b-turbomind',
path="internlm/internlm-20b", # this path should be consistent with the one on Hugging Face
engine_config=dict(session_len=2048,
max_batch_size=8,
rope_scaling_factor=1.0),
gen_config=dict(top_k=1, top_p=0.8,
temperature=1.0,
max_new_tokens=100),
max_out_len=100,
max_seq_len=2048,
batch_size=8,
concurrency=8,
run_cfg=dict(num_gpus=1, num_procs=1),
end_str='<eoa>'
)
models = [internlm_20b]
```
Then, from the root directory of OpenCompass, run the following command to get the evaluation results:
```shell
python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
```
**Note:**
- To pass more arguments to the `engine_config` and `gen_config` fields in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#turbomindengineconfig) and [GenerationConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#generationconfig)
- To evaluate the InternLM Chat models, please use the config file `eval_internlm_chat_turbomind.py`
- To evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by setting `models = [internlm_7b]` in the last line.

@ -1,15 +1,24 @@
from opencompass.models import TurboMindModelwithChatTemplate
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='internlm2-chat-7b-turbomind',
abbr=f'internlm2-chat-7b-lmdeploy',
path='internlm/internlm2-chat-7b',
engine_config=dict(session_len=8192, max_batch_size=16, tp=1),
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=4096),
# inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
# If the model is not supported by 'turbomind', it will fallback to
# 'pytorch'
backend='turbomind',
# For the detailed engine config and generation config, please refer to
# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
engine_config=dict(tp=1),
gen_config=dict(do_sample=False),
max_seq_len=8192,
max_out_len=4096,
batch_size=16,
# the max number of prompts that LMDeploy receives
# in `generate` function
batch_size=5000,
run_cfg=dict(num_gpus=1),
)
]

@ -25,8 +25,6 @@ from .interntrain import InternTrain # noqa: F401
from .krgpt_api import KrGPT # noqa: F401
from .lightllm_api import LightllmAPI, LightllmChatAPI # noqa: F401
from .llama2 import Llama2, Llama2Chat # noqa: F401
from .lmdeploy_pytorch import LmdeployPytorchModel # noqa: F401
from .lmdeploy_tis import LmdeployTisModel # noqa: F401
from .minimax_api import MiniMax, MiniMaxChatCompletionV2 # noqa: F401
from .mistral_api import Mistral # noqa: F401
from .mixtral import Mixtral # noqa: F401
@ -41,7 +39,6 @@ from .rendu_api import Rendu # noqa: F401
from .sensetime_api import SenseTime # noqa: F401
from .stepfun_api import StepFun # noqa: F401
from .turbomind import TurboMindModel # noqa: F401
from .turbomind_tis import TurboMindTisModel # noqa: F401
from .turbomind_with_tf_above_v4_33 import \
TurboMindModelwithChatTemplate # noqa: F401
from .unigpt_api import UniGPT # noqa: F401

@ -1,188 +0,0 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.models.base import BaseModel
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret
class LmdeployPytorchModel(BaseModel):
"""Model wrapper for lmdeploy pytorch engine through python API.
Args:
path (str): path of the supported pytorch model.
max_seq_len (int): The maximum allowed sequence length of a model.
Note that the length of prompt + generated tokens shall not exceed
this value. Defaults to 2048.
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
engine_config (Dict, optional): The engine config to set
arguments like session_len, max_batch_size for TurboMind.
gen_config (Dict, optional): Generation config to set
arguments like top_k, top_p, temperature.
end_str (str, optional): Whether to trim generated strings with end_str
if the model has special ending strings that are not handled well.
Defaults to None.
"""
def __init__(self,
path: str,
concurrency: int = 8,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
engine_config: Optional[Dict] = None,
gen_config: Optional[Dict] = None,
end_str: Optional[str] = None):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.pytorch import engine as tm
from lmdeploy.version import version_info
if engine_config is not None:
from lmdeploy.messages import PytorchEngineConfig
engine_config = PytorchEngineConfig(**engine_config)
# set thread_safe
if hasattr(engine_config, 'thread_safe'):
engine_config.thread_safe = True
if gen_config is not None:
from lmdeploy.messages import GenerationConfig
gen_config = GenerationConfig(**gen_config)
self.logger = get_logger()
tm_model = tm.Engine(path, engine_config)
self.tokenizer = tm_model.tokenizer
self.generators = [
tm_model.create_instance() for i in range(concurrency)
]
self.generator_ids = [i + 1 for i in range(concurrency)]
from transformers import GenerationConfig
try:
generation_config = GenerationConfig.from_pretrained(path)
except Exception:
generation_config = None
if generation_config and hasattr(generation_config, 'eos_token_id'):
if gen_config.stop_words is None:
stop_words = []
if isinstance(generation_config.eos_token_id, int):
stop_words.append(generation_config.eos_token_id)
else:
assert isinstance(generation_config.eos_token_id, list)
for token_id in generation_config.eos_token_id:
stop_words.append(token_id)
gen_config.stop_words = stop_words
if version_info >= (0, 6, 0):
gen_config.stop_token_ids = stop_words
self.gen_config = gen_config
self.end_str = end_str
self.major_version, self.minor_version = version_info[:2]
def generate(
self,
inputs: List[str],
max_out_len: int = 512,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of prompts
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
assert isinstance(
inputs, List), f'List(str) is expected, but got {type(inputs)}'
# split inputs into batches
batch_size = len(self.generators)
batch_inputs = [
inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
]
results = []
for batch_input in batch_inputs:
with ThreadPoolExecutor() as executor:
_results = list(
executor.map(
self._generate,
self.generators[:len(batch_input)],
self.generator_ids[:len(batch_input)],
batch_input,
[self.gen_config] * len(batch_input),
[self.end_str] * len(batch_input),
))
results += _results
return results
def get_token_len(self, prompt: str) -> int:
input_ids = self.tokenizer.encode(prompt)
return len(input_ids)
def wait(self):
"""Wait till the next query can be sent.
Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()
def _generate(self,
generator,
session_id,
prompt: PromptType,
gen_config=None,
end_str: Optional[str] = None) -> str:
"""Generate results given a list of inputs.
Args:
prompt (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
gen_config (GenerationConfig, optional): Generation
config to set arguments like top_k, top_p, temperature.
end_str (str, optional): Whether to trim generated strings
with end_str if the model has special ending strings
that are not handled well.
Defaults to None.
Returns:
str: The generated string.
"""
assert type(
prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt)
if self.major_version >= 0 and self.minor_version >= 4:
outputs = generator.infer(session_id,
input_ids,
gen_config=gen_config)
output_ids = outputs.token_ids
else:
_, output_ids, _ = generator.infer(session_id,
input_ids,
gen_config=gen_config)
# stop engine
if hasattr(generator, 'end'):
generator.end(session_id)
# decode output
response_all = self.tokenizer.decode(output_ids)
# trim output
if end_str:
response_all = response_all.split(end_str)[0]
# remove invalid characters
response_all = valid_str(response_all)
return response_all

@ -1,200 +0,0 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from queue import Queue
from typing import Dict, List, Optional, Union
import numpy as np
from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret
def prepare_tensor(name, input_tensor):
"""Create grpcclient's InferInput instance according to a given tensor."""
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
t = grpcclient.InferInput(name, list(input_tensor.shape),
np_to_triton_dtype(input_tensor.dtype))
t.set_data_from_numpy(input_tensor)
return t
def stream_callback(que, result, error):
"""callback function invoked by triton client."""
que.put((result, error))
class LmdeployTisModel(BaseModel):
"""Model wrapper for LMDeploy Python Backend Triton Inference Server gRPC
API.
Args:
path (str): The name of OpenAI's model.
tis_addr (str): The address (ip:port format) of turbomind's
triton inference server
max_seq_len (int): The maximum allowed sequence length of a model.
Note that the length of prompt + generated tokens shall not exceed
this value. Defaults to 2048.
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
"""
is_api: bool = True
def __init__(self,
path: str,
tis_addr: str = '0.0.0.0:33337',
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
end_str: Optional[str] = None):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.tokenizer import Tokenizer
self.logger = get_logger()
self.template_parser = LMTemplateParser(meta_template)
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
self.tis_addr = tis_addr
self.tokenizer = Tokenizer(path)
self.end_str = end_str
def generate(
self,
inputs: List[str or PromptList],
max_out_len: int = 512,
temperature: float = 1.0,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str or PromptList]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 0.7.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs),
[temperature] * len(inputs),
[self.end_str] * len(inputs)))
return results
def wait(self):
"""Wait till the next query can be sent.
Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()
def get_token_len(self, prompt: str) -> int:
input_ids = self.tokenizer.encode(prompt)
return len(input_ids)
def _call_triton_server(self, prompt, tis_addr, session_id,
request_output_len, temperature, res_que):
import tritonclient.grpc as grpcclient
with grpcclient.InferenceServerClient(tis_addr) as client:
inputs = [
prepare_tensor('prompt',
np.array([prompt.encode()], dtype=np.object_)),
prepare_tensor('max_tokens',
np.array([request_output_len], dtype=np.int32)),
prepare_tensor('temperature',
np.array([temperature], dtype=np.float_)),
prepare_tensor('top_p', np.array([1.0], dtype=np.float_)),
prepare_tensor('top_k', np.array([1], dtype=np.int32)),
prepare_tensor('ignore_eos', np.array([False],
dtype=np.bool_)),
prepare_tensor('stream', np.array([True], dtype=np.bool_)),
]
# async_stream
client.start_stream(partial(stream_callback, res_que))
client.async_stream_infer('lmdeploy_model',
inputs,
sequence_id=session_id,
sequence_start=True,
sequence_end=True)
res_que.put(None)
return
def _process_result(self, que):
text = ''
while True:
res = que.get()
if res is not None:
result, err = res
if err is not None:
print(err)
else:
res = result.as_numpy('response').item().decode()
text += res
else:
return text
def _generate(self,
prompt: str or PromptList,
max_out_len: int,
temperature: float,
end_str: Optional[str] = None) -> str:
"""Generate results given a list of inputs.
Args:
prompt (str or PromptList): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
Returns:
str: The generated string.
"""
assert type(
prompt
) is str, 'We only support string for LMDeploy Python Backend TIS API'
res_que = Queue()
self._call_triton_server(prompt=prompt,
tis_addr=self.tis_addr,
session_id=threading.currentThread().ident,
request_output_len=max_out_len,
temperature=temperature,
res_que=res_que)
text = self._process_result(res_que)
response = valid_str(text)
if end_str:
response = response.split(end_str)[0]
return response

@ -1,135 +0,0 @@
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret
class TurboMindTisModel(BaseModel):
"""Model wrapper for TurboMind Triton Inference Server gRPC API.
Args:
path (str): The name of OpenAI's model.
tis_addr (str): The address (ip:port format) of turbomind's
triton inference server
max_seq_len (int): The maximum allowed sequence length of a model.
Note that the length of prompt + generated tokens shall not exceed
this value. Defaults to 2048.
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
"""
is_api: bool = True
def __init__(
self,
path: str,
tis_addr: str = '0.0.0.0:33337',
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.serve.turbomind.utils import Preprocessor
self.preprocess = Preprocessor(tis_addr)
self.logger = get_logger()
self.template_parser = LMTemplateParser(meta_template)
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
self.tis_addr = tis_addr
def generate(
self,
inputs: List[PromptType],
max_out_len: int = 512,
temperature: float = 1.0,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 0.7.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs),
[temperature] * len(inputs)))
return results
def get_token_len(self, prompt: str) -> int:
input_ids, _ = self.preprocess(prompt)
return input_ids.shape[-1]
def wait(self):
"""Wait till the next query can be sent.
Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()
def _generate(self, prompt: PromptType, max_out_len: int,
temperature: float) -> str:
"""Generate results given a list of inputs.
Args:
prompt (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
Returns:
str: The generated string.
"""
assert type(
prompt) is str, 'We only support string for TurboMind RPC API'
from lmdeploy.serve.turbomind.chatbot import Chatbot
chatbot = Chatbot(self.tis_addr,
temperature=temperature,
capability='completion',
top_k=1,
log_level=logging.ERROR)
for status, text, n_token in chatbot.stream_infer(
session_id=threading.currentThread().ident,
prompt=prompt,
request_output_len=max_out_len,
sequence_start=True,
sequence_end=True):
continue
response = valid_str(text)
response = response.replace('<eoa>', '')
return response

@ -1,7 +1,6 @@
# flake8: noqa
# yapf: disable
import copy
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.models.base import BaseModel
@ -31,38 +30,32 @@ class TurboMindModelwithChatTemplate(BaseModel):
self,
path: str,
tokenizer_only: bool = False,
backend: str = 'turbomind',
engine_config: Dict = {},
gen_config: Dict = {},
concurrency: int = 8,
max_seq_len: int = None,
meta_template: Optional[Dict] = None,
fastchat_template: Optional[str] = None,
stop_words: List[str] = [],
):
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.turbomind import TurboMind
from lmdeploy.version import version_info
from transformers import AutoTokenizer
self.logger = get_logger()
self.path = path
self.tokenizer_only = tokenizer_only
self.template_parser = _get_meta_template(meta_template)
self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
self.origin_tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
from lmdeploy import version_info
from transformers import AutoTokenizer
self.version_info = version_info
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
if not tokenizer_only:
DEFAULT_ENGING_CONFIG = {'session_len': self.max_seq_len}
_engine_config = DEFAULT_ENGING_CONFIG.copy()
_engine_config.update(engine_config)
engine_config = TurbomindEngineConfig(**_engine_config)
tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
self.tokenizer = tm_model.tokenizer
self.generators = [tm_model.create_instance() for i in range(concurrency)]
self.generator_ids = [i + 1 for i in range(concurrency)]
self.concurrency = concurrency
self.pipe = self._build_pipe(path, backend, _engine_config)
else:
self.pipe = None
self.gen_config = gen_config
self.version_info = version_info
self.fastchat_template = fastchat_template
self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
self.logger.info(f'using stop words: {self.stop_words}')
@ -76,23 +69,23 @@ class TurboMindModelwithChatTemplate(BaseModel):
generation_config = None
if generation_config and hasattr(generation_config, 'eos_token_id'):
if isinstance(generation_config.eos_token_id, int):
potential_stop_words.append(self.origin_tokenizer.decode(generation_config.eos_token_id))
potential_stop_words.append(self.tokenizer.decode(generation_config.eos_token_id))
else:
assert isinstance(generation_config.eos_token_id, list)
for token_id in generation_config.eos_token_id:
potential_stop_words.append(self.origin_tokenizer.decode(token_id))
if self.origin_tokenizer.eos_token is not None:
potential_stop_words.append(self.origin_tokenizer.eos_token)
potential_stop_words.append(self.tokenizer.decode(token_id))
if self.tokenizer.eos_token is not None:
potential_stop_words.append(self.tokenizer.eos_token)
potential_stop_words = list(set(potential_stop_words))
potential_stop_words = [s for s in potential_stop_words if s]
return potential_stop_words
def generate(self,
inputs: List[str],
max_out_len: int = 512,
max_out_len: int,
stopping_criteria: List[str] = [],
do_sample: Optional[bool] = None,
temperature: int = 1,
temperature: float = 1.0,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
@ -104,93 +97,45 @@ class TurboMindModelwithChatTemplate(BaseModel):
List[str]: A list of generated strings.
"""
assert isinstance(inputs, List), f'List(str) is expected, but got {type(inputs)}'
messages = _convert_chat_messages(inputs)
if self.fastchat_template:
messages = _format_with_fast_chat_template(messages, self.fastchat_template)
else:
messages = [self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
# split messages into batches
batch_messages = [messages[i:i + self.concurrency] for i in range(0, len(messages), self.concurrency)]
messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
stop_words = list(set(self.stop_words + stopping_criteria))
encode_stop_words = []
if stop_words is not None and len(stop_words) > 0:
for words in stop_words:
encode_stop_words += self.tokenizer.encode(words, add_bos=False)
DEFAULT_GEN_CONFIG = {
'max_new_tokens': max_out_len,
'min_new_tokens': 1,
'top_k': 1,
'stop_words': encode_stop_words,
'stop_words': stop_words,
}
gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
gen_config.update(self.gen_config)
if do_sample:
gen_config['top_k'] = 1000
gen_config['top_k'] = 40
gen_config['temperature'] = temperature
else:
if self.version_info >= (0, 6, 0):
gen_config['do_sample'] = False
else:
gen_config['top_k'] = 1
from lmdeploy.messages import GenerationConfig
from lmdeploy import GenerationConfig
gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)}
gen_config = GenerationConfig(**gen_config)
if self.version_info >= (0, 6, 0):
gen_config.stop_words = stop_words
gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
results = []
for batch_message in batch_messages:
n = len(batch_message)
with ThreadPoolExecutor() as executor:
_results = list(
executor.map(
self._generate,
self.generators[:n],
self.generator_ids[:n],
batch_message,
[gen_config] * n,
))
results += _results
outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False)
for output in outputs:
text = self.tokenizer.decode(output.token_ids)
results.append(text)
for s in stop_words:
results = [r.split(s)[0] for r in results]
return results
def _generate(self,
generator,
session_id,
prompt: PromptType,
gen_config=None) -> str:
"""Generate results given a list of inputs.
Args:
prompt (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
gen_config (GenerationConfig, optional): Generation
config to set arguments like top_k, top_p, temperature.
Returns:
str: The generated string.
"""
assert type(prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt, add_bos=False)
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[input_ids],
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
step=0,
stream_output=False):
if self.version_info >= (0, 4, 0):
output_ids = outputs.token_ids
else:
_, output_ids, _ = outputs
response = self.tokenizer.decode(output_ids)
response = valid_str(response)
return response
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
@ -201,5 +146,20 @@ class TurboMindModelwithChatTemplate(BaseModel):
int: Length of the input tokens
"""
m = _convert_chat_messages([prompt])[0]
t = self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
return len(t['input_ids'])
def _build_pipe(self, model_path, backend, engine_config):
from lmdeploy import (PytorchEngineConfig, TurbomindEngineConfig,
pipeline)
assert backend in ['pytorch', 'turbomind'], \
f'unsupported backend type: {backend}'
if backend == 'turbomind':
filtered = {k: v for k, v in engine_config.items() if hasattr(TurbomindEngineConfig, k)}
backend_config = TurbomindEngineConfig(**filtered)
else:
filtered = {k: v for k, v in engine_config.items() if hasattr(PytorchEngineConfig, k)}
backend_config = PytorchEngineConfig(**filtered)
return pipeline(model_path, backend_config=backend_config, log_level='INFO', max_log_len=10)
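For reference, the refactored wrapper can be exercised outside a full evaluation run. The sketch below is illustrative only (model path and prompt are placeholders): it instantiates the class with the arguments shown in this diff and calls `generate`, which now applies the chat template and sends the whole batch through the single pipeline created by `_build_pipe`.
```python
# Sketch: drive the pipeline-backed wrapper directly (not part of this diff).
from opencompass.models import TurboMindModelwithChatTemplate

model = TurboMindModelwithChatTemplate(
    path='internlm/internlm2-chat-7b',  # HF repo id or local path
    backend='turbomind',                # or 'pytorch'
    engine_config=dict(tp=1),
    gen_config=dict(do_sample=False),
    max_seq_len=8192,
)
outputs = model.generate(['Say hello in one sentence.'], max_out_len=64)
print(outputs[0])
```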

@ -9,7 +9,7 @@ from mmengine.config import Config
from opencompass.datasets.custom import make_custom_dataset_config
from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel,
HuggingFaceCausalLM, HuggingFaceChatGLM3,
HuggingFacewithChatTemplate, TurboMindModel,
HuggingFacewithChatTemplate,
TurboMindModelwithChatTemplate,
VLLMwithChatTemplate)
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
@ -233,7 +233,7 @@ def change_accelerator(models, accelerator):
model_accels = []
for model in models:
logger.info(f'Transforming {model["abbr"]} to {accelerator}')
# change HuggingFace model to VLLM or TurboMindModel
# change HuggingFace model to VLLM or LMDeploy
if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3, f'{HuggingFaceBaseModel.__module__}.{HuggingFaceBaseModel.__name__}']:
gen_args = dict()
if model.get('generation_kwargs') is not None:
@ -254,10 +254,10 @@ def change_accelerator(models, accelerator):
if accelerator == 'lmdeploy':
logger.info(f'Transforming {model["abbr"]} to {accelerator}')
mod = TurboMindModel
mod = TurboMindModelwithChatTemplate
acc_model = dict(
type=f'{mod.__module__}.{mod.__name__}',
abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind',
abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
path=model['path'],
engine_config=dict(session_len=model['max_seq_len'],
max_batch_size=model['batch_size'],
@ -270,7 +270,6 @@ def change_accelerator(models, accelerator):
max_out_len=model['max_out_len'],
max_seq_len=model['max_seq_len'],
batch_size=model['batch_size'],
concurrency=model['batch_size'],
run_cfg=model['run_cfg'],
)
for item in ['meta_template']:
@ -312,7 +311,7 @@ def change_accelerator(models, accelerator):
mod = TurboMindModelwithChatTemplate
acc_model = dict(
type=f'{mod.__module__}.{mod.__name__}',
abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind',
abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
path=model['path'],
engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']),
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),