[Fix] Fix turbomind and update docs (#808)

* update

* update docs

* add engine_config and gen_config in eval_config

* update

* fix

* fix

* fix

* fix docstr

* fix url
RunningLeon 2024-01-18 14:41:35 +08:00 committed by GitHub
parent 9e5746d3d8
commit 61fe873c89
5 changed files with 145 additions and 322 deletions

View File

@@ -1,17 +1,16 @@
 from mmengine.config import read_base
 from opencompass.models.turbomind import TurboMindModel

 with read_base():
     # choose a list of datasets
     from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
-    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
-    # from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
-    # from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
-    # from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
-    # from .datasets.race.race_gen_69ee4f import race_datasets
-    # from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
+    from .datasets.race.race_gen_69ee4f import race_datasets
+    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
     # and output the results in a choosen format
     from .summarizers.medium import summarizer
@@ -24,56 +23,18 @@ internlm_meta_template = dict(round=[
 ],
     eos_token_id=103028)

-llama2_meta_template = dict(
-    round=[
-        dict(role='HUMAN', begin='[INST] ', end=' [/INST]'),
-        dict(role='BOT', generate=True),
-    ],
-    eos_token_id=2)
-
-qwen_meta_template = dict(round=[
-    dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
-    dict(role='BOT',
-         begin='\n<|im_start|>assistant\n',
-         end='<|im_end|>',
-         generate=True)
-])
-
-baichuan2_meta_template = dict(round=[
-    dict(role='HUMAN', begin='<reserved_106>'),
-    dict(role='BOT', begin='<reserved_107>', generate=True)
-])
-
 # config for internlm-chat-7b
 internlm_chat_7b = dict(
     type=TurboMindModel,
     abbr='internlm-chat-7b-turbomind',
     path='internlm/internlm-chat-7b',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-internlm_chat_7b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-7b-w4-turbomind',
-    path='internlm/internlm-chat-7b-w4',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for internlm-chat-7b-w4kv8 model
-internlm_chat_7b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-7b-w4kv8-turbomind',
-    path='internlm/internlm-chat-7b-w4kv8',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=32,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=32,
@@ -87,6 +48,13 @@ internlm_chat_20b = dict(
     type=TurboMindModel,
     abbr='internlm-chat-20b-turbomind',
     path='internlm/internlm-chat-20b',
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=8,
@@ -95,108 +63,4 @@ internlm_chat_20b = dict(
     run_cfg=dict(num_gpus=1, num_procs=1),
 )

-# config for internlm-chat-20b-w4 model
-internlm_chat_20b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-20b-w4-turbomind',
-    path='internlm/internlm-chat-20b-w4',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for internlm-chat-20b-w4kv8 model
-internlm_chat_20b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-chat-20b-w4kv8-turbomind',
-    path='internlm/internlm-chat-20b-w4kv8',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=internlm_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-7b
-llama2_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-7b-turbomind',
-    path='meta-llama/Llama-2-7b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-13b
-llama2_chat_13b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-13b-turbomind',
-    path='meta-llama/Llama-2-13b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for llama2-chat-70b
-llama2_chat_70b = dict(
-    type=TurboMindModel,
-    abbr='llama2-chat-70b-turbomind',
-    path='meta-llama/Llama-2-70b-chat-hf',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=8,
-    concurrency=8,
-    meta_template=llama2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for qwen-chat-7b
-qwen_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='qwen-chat-7b-turbomind',
-    path='Qwen/Qwen-7B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=qwen_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for qwen-chat-7b
-qwen_chat_14b = dict(
-    type=TurboMindModel,
-    abbr='qwen-chat-14b-turbomind',
-    path='Qwen/Qwen-14B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=qwen_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for baichuan2-chat-7b
-baichuan2_chat_7b = dict(
-    type=TurboMindModel,
-    abbr='baichuan2-chat-7b-turbomind',
-    path='baichuan-inc/Baichuan2-7B-Chat',
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=32,
-    meta_template=baichuan2_meta_template,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
 models = [internlm_chat_20b]
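The removed per-backend chat entries (llama2, qwen, baichuan2) can still be evaluated by writing an entry in the new style. Below is a minimal sketch, not part of this commit, that re-adds the Llama-2 7B chat entry by reusing the meta template the old config defined together with the new `engine_config`/`gen_config` fields; the engine and batch values are assumptions carried over from the deleted entry.

```python
# Hypothetical re-addition of a removed chat model under the new-style fields.
# The meta template is copied from the deleted config; engine/generation values
# mirror the internlm entries above and are assumptions, not part of the commit.
llama2_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='[INST] ', end=' [/INST]'),
        dict(role='BOT', generate=True),
    ],
    eos_token_id=2)

llama2_chat_7b = dict(
    type=TurboMindModel,
    abbr='llama2-chat-7b-turbomind',
    path='meta-llama/Llama-2-7b-chat-hf',
    engine_config=dict(session_len=2048,
                       max_batch_size=16,
                       rope_scaling_factor=1.0),
    gen_config=dict(top_k=1, top_p=0.8, temperature=1.0, max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=16,
    concurrency=16,
    meta_template=llama2_meta_template,
    run_cfg=dict(num_gpus=1, num_procs=1),
)
```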

View File

@@ -20,30 +20,13 @@ internlm_7b = dict(
     type=TurboMindModel,
     abbr='internlm-7b-turbomind',
     path="internlm/internlm-7b",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# # config for internlm-7b-w4 model
-internlm_7b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-7b-w4-turbomind',
-    path="internlm/internlm-7b-w4",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=32,
-    concurrency=32,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# # config for internlm-7b-w4kv8 model
-internlm_7b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-7b-w4kv8-turbomind',
-    path="internlm/internlm-7b-w4kv8",
+    engine_config=dict(session_len=2048,
+                       max_batch_size=32,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1,
+                    top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=32,
@@ -56,6 +39,12 @@ internlm_20b = dict(
     type=TurboMindModel,
     abbr='internlm-20b-turbomind',
     path="internlm/internlm-20b",
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=8,
@@ -63,29 +52,4 @@ internlm_20b = dict(
     run_cfg=dict(num_gpus=1, num_procs=1),
 )

-# config for internlm-20b-w4 model
-internlm_20b_w4 = dict(
-    type=TurboMindModel,
-    abbr='internlm-20b-w4-turbomind',
-    path="internlm/internlm-20b-w4",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
-# config for internlm-20b-w4kv8 model
-internlm_20b_w4kv8 = dict(
-    type=TurboMindModel,
-    abbr='internlm-20b-w4kv8-turbomind',
-    path="internlm/internlm-20b-w4kv8",
-    max_out_len=100,
-    max_seq_len=2048,
-    batch_size=16,
-    concurrency=16,
-    run_cfg=dict(num_gpus=1, num_procs=1),
-)
-
 models = [internlm_20b]
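Both `internlm_7b` and `internlm_20b` remain defined in this config; only the last line selects what is evaluated. A small sketch of a variation (not in the commit) that runs both base models and narrows the run to a quick subset, assuming the config imports `mmlu_datasets` and `gsm8k_datasets` inside its `read_base()` block like the chat config above.

```python
# Variation (not in the commit): evaluate both base models in one run and
# restrict the datasets to a quick smoke-test subset.
datasets = [*mmlu_datasets, *gsm8k_datasets]
models = [internlm_7b, internlm_20b]
```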

View File

@@ -18,21 +18,50 @@ pip install lmdeploy
 ## Evaluation

-OpenCompass integrates both turbomind's python API and gRPC API for evaluation. And the former is highly recommended.
+OpenCompass integrates turbomind's python API for evaluation.

-We take the InternLM-20B as example. Please download it from huggingface:
+We take the InternLM-20B as example. Firstly, we prepare the evaluation config `configs/eval_internlm_turbomind.py`:

-```shell
-# Download InternLM model(or use the cached model's checkpoint)
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
-git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
+```python
+from mmengine.config import read_base
+from opencompass.models.turbomind import TurboMindModel
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
+
+# config for internlm-20b model
+internlm_20b = dict(
+    type=TurboMindModel,
+    abbr='internlm-20b-turbomind',
+    path="internlm/internlm-20b",  # this path should be same as in huggingface
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    concurrency=8,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+)
+
+models = [internlm_20b]
 ```

-### Evaluation with Turbomind Python API (recommended)
-
-In the home folder of OpenCompass, start evaluation by the following command:
+Then, in the home folder of OpenCompass, start evaluation by the following command:

 ```shell
 python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@@ -42,42 +71,7 @@ You are expected to get the evaluation results after the inference and evaluation.
 **Note**:

+- If you want to pass more arguments for `engine_config` and `gen_config` in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#turbomindengineconfig)
+  and [EngineGenerationConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#generationconfig)
 - If you evaluate the InternLM Chat model, please use configuration file `eval_internlm_chat_turbomind.py`
 - If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by changing to the setting `models = [internlm_7b]` in the last line.
-- If you want to evaluate other chat models like Llama2, QWen-7B, Baichuan2-7B, you could change to the setting of `models` in `eval_internlm_chat_turbomind.py`.
-
-### Evaluation with Turbomind gRPC API (optional)
-
-Convert the model to TurboMind format using lmdeploy
-
-```shell
-lmdeploy convert internlm /path/to/internlm-20b \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-**Note**:
-
-If evaluating the InternLM Chat model, make sure to pass `internlm-chat` as the model name instead of `internlm` when converting the model format. The specific command is:
-
-```shell
-lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-In the home folder of OpenCompass, launch the Triton Inference Server:
-
-```shell
-bash turbomind/service_docker_up.sh
-```
-
-And start evaluation by the following command:
-
-```shell
-python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
-```
-
-**Note**:
-
-- If the InternLM Chat model is requested to be evaluated, please use config file `eval_internlm_chat_turbomind_tis.py`
-- In `eval_internlm_turbomind_tis.py`, the configured Triton Inference Server (TIS) address is `tis_addr='0.0.0.0:33337'`. Please modify `tis_addr` to the IP address of the machine where the server is launched.
-- If evaluating the InternLM 7B model, please modify the `models` configuration in `eval_internlm_xxx_turbomind_tis.py`.
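The new note defers extra `engine_config` and `gen_config` arguments to lmdeploy's `TurbomindEngineConfig` and `EngineGenerationConfig`. A hedged sketch of what such an entry might look like follows; the extra field names (`tp`, `repetition_penalty`) are assumptions about the installed lmdeploy version and should be checked against the linked documentation rather than taken from this commit.

```python
# Hypothetical sketch: passing extra engine/generation arguments through the
# same dicts. Field names such as tp and repetition_penalty are assumptions
# about the installed lmdeploy version; verify them in TurbomindEngineConfig
# and EngineGenerationConfig before relying on them.
internlm_20b = dict(
    type=TurboMindModel,
    abbr='internlm-20b-turbomind',
    path='internlm/internlm-20b',
    engine_config=dict(session_len=2048,
                       max_batch_size=8,
                       tp=2,                      # assumed: tensor parallel size
                       rope_scaling_factor=1.0),
    gen_config=dict(top_k=1,
                    top_p=0.8,
                    temperature=1.0,
                    repetition_penalty=1.0,       # assumed field name
                    max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=8,
    concurrency=8,
    run_cfg=dict(num_gpus=2, num_procs=1),        # num_gpus raised to match tp=2
)
```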

View File

@@ -18,21 +18,50 @@ pip install lmdeploy
 ## 评测

-OpenCompass 支持分别通过 turbomind python API 和 gRPC API 评测数据集。我们强烈推荐使用前者进行评测。
+OpenCompass 支持通过 turbomind python API 评测数据集。

-下文以 InternLM-20B 模型为例,介绍如何评测。首先,从 huggingface 上下载 InternLM 模型:
+下文以 InternLM-20B 模型为例,介绍如何评测。首先我们准备好测试配置文件`configs/eval_internlm_turbomind.py`:

-```shell
-# Download InternLM model(or use the cached model's checkpoint)
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
-git clone https://huggingface.co/internlm/internlm-20b /path/to/internlm-20b
+```python
+from mmengine.config import read_base
+from opencompass.models.turbomind import TurboMindModel
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
+    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
+    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
+    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
+    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
+
+# config for internlm-20b model
+internlm_20b = dict(
+    type=TurboMindModel,
+    abbr='internlm-20b-turbomind',
+    path="internlm/internlm-20b",  # 注意路径与 huggingface 保持一致
+    engine_config=dict(session_len=2048,
+                       max_batch_size=8,
+                       rope_scaling_factor=1.0),
+    gen_config=dict(top_k=1, top_p=0.8,
+                    temperature=1.0,
+                    max_new_tokens=100),
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    concurrency=8,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+)
+
+models = [internlm_20b]
 ```

-### 通过 TurboMind Python API 评测(推荐)
-
-在 OpenCompass 的项目目录下,执行如下命令可得到评测结果:
+然后,在 OpenCompass 的项目目录下,执行如下命令可得到评测结果:

 ```shell
 python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
@@ -40,41 +69,6 @@ python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b
 **注:**

+- 如果想在测评配置文件中`engine_config`和`gen_config`字段传递更多参数,请参考 [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#turbomindengineconfig) 和 [EngineGenerationConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#generationconfig)
 - 如果评测 InternLM Chat 模型,请使用配置文件 `eval_internlm_chat_turbomind.py`
 - 如果评测 InternLM 7B 模型,请修改 `eval_internlm_turbomind.py` 或者 `eval_internlm_chat_turbomind.py`。将`models`字段配置为`models = [internlm_7b]` 。
-- 如果评测其他模型如 Llama2, QWen-7B, Baichuan2-7B, 请修改`eval_internlm_chat_turbomind.py`中`models`字段 。
-
-### 通过 TurboMind gRPC API 评测(可选)
-
-首先需要将模型转换为 turbomind 格式
-
-```shell
-lmdeploy convert internlm /path/to/internlm-20b \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-注意:如果评测 InternLM Chat 模型,那么在转换模型格式的时候,模型名字要填写 `internlm-chat`。具体命令是:
-
-```shell
-lmdeploy convert internlm-chat /path/to/internlm-20b-chat \
-    --dst-path {/home/folder/of/opencompass}/turbomind
-```
-
-在 OpenCompass 的项目目录下,启动 triton inference server:
-
-```shell
-bash turbomind/service_docker_up.sh
-```
-
-然后,执行如下命令进行评测:
-
-```shell
-python run.py configs/eval_internlm_turbomind_tis.py -w outputs/turbomind-tis/internlm-20b
-```
-
-**注:**
-
-- 如果评测 InternLM Chat 模型,请使用配置文件 `eval_internlm_chat_turbomind_tis.py`
-- 在配置文件中 triton inference server(TIS) 地址是 `tis_addr='0.0.0.0:33337'`。请把配置中的`tis_addr`修改为 server 所在机器的 IP 地址。
-- 如果评测 InternLM 7B 模型,请修改 `eval_internlm_xxx_turbomind_tis.py` 中`models`字段。

View File

@@ -30,43 +30,49 @@ class TurboMindModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        engine_config (Dict, optional): The engine config to set
+            arguments like session_len, max_batch_size for TurboMind.
+        gen_config (Dict, optional): Generation config to set
+            arguments like top_k, top_p, temperature.
     """

-    def __init__(
-        self,
-        path: str,
-        concurrency: int = 8,
-        max_seq_len: int = 2048,
-        meta_template: Optional[Dict] = None,
-    ):
-        from lmdeploy.turbomind import TurboMind
+    def __init__(self,
+                 path: str,
+                 concurrency: int = 8,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 engine_config: Optional[Dict] = None,
+                 gen_config: Optional[Dict] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
+        from lmdeploy.turbomind import TurboMind
+        if engine_config is not None:
+            from lmdeploy.messages import TurbomindEngineConfig
+            engine_config = TurbomindEngineConfig(**engine_config)
+        if gen_config is not None:
+            from lmdeploy.messages import EngineGenerationConfig
+            gen_config = EngineGenerationConfig(**gen_config)
         self.logger = get_logger()
-        tm_model = TurboMind.from_pretrained(path)
+        tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
         self.tokenizer = tm_model.tokenizer
         self.generators = [
             tm_model.create_instance() for i in range(concurrency)
         ]
         self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.gen_config = gen_config

     def generate(
         self,
         inputs: List[str],
         max_out_len: int = 512,
-        temperature: float = 1.0,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
             inputs (List[str]): A list of prompts
             max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic. Defaults to 1.0.

         Returns:
             List[str]: A list of generated strings.
@@ -88,7 +94,7 @@ class TurboMindModel(BaseModel):
                 self.generators[:len(batch_input)],
                 self.generator_ids[:len(batch_input)],
                 batch_input, [max_out_len] * len(batch_input),
-                [temperature] * len(batch_input)))
+                [self.gen_config] * len(batch_input)))
             results += _results
         return results
@@ -103,8 +109,12 @@ class TurboMindModel(BaseModel):
         """
         return self.token_bucket.get_token()

-    def _generate(self, generator, session_id, prompt: str or PromptList,
-                  max_out_len: int, temperature: float) -> str:
+    def _generate(self,
+                  generator,
+                  session_id,
+                  prompt: str or PromptList,
+                  max_out_len: int,
+                  gen_config=None) -> str:
         """Generate results given a list of inputs.

         Args:
@@ -112,10 +122,8 @@ class TurboMindModel(BaseModel):
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic.
+            gen_config (EngineGenerationConfig, optional): Generation
+                config to set arguments like top_k, top_p, temperature.

         Returns:
             str: The generated string.
@@ -127,11 +135,10 @@ class TurboMindModel(BaseModel):
         for outputs in generator.stream_infer(session_id=session_id,
                                               input_ids=[input_ids],
+                                              gen_config=gen_config,
                                               request_output_len=max_out_len,
                                               sequence_start=True,
                                               sequence_end=True,
-                                              top_k=1,
-                                              top_p=0.8,
                                               step=0,
                                               stream_output=False):
             _, output_ids, _ = outputs
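Putting the wrapper changes together, here is a standalone usage sketch based only on the signature visible in this diff; it assumes lmdeploy is installed and that the path resolves to local or cached InternLM-20B weights.

```python
# Usage sketch for the updated wrapper (assumptions: lmdeploy installed,
# weights reachable at the given path). The two dicts are wrapped internally
# into TurbomindEngineConfig and EngineGenerationConfig, as shown above.
from opencompass.models.turbomind import TurboMindModel

model = TurboMindModel(
    path='internlm/internlm-20b',
    concurrency=8,
    max_seq_len=2048,
    engine_config=dict(session_len=2048,
                       max_batch_size=8,
                       rope_scaling_factor=1.0),
    gen_config=dict(top_k=1,
                    top_p=0.8,
                    temperature=1.0,
                    max_new_tokens=100),
)

prompts = ['The capital of France is']
outputs = model.generate(prompts, max_out_len=32)
print(outputs[0])
```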