mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Update DeepSeek-R1 example
This commit is contained in:
parent
8103c0d245
commit
9843e3c63c
@@ -39,6 +39,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
     user_guides/evaluation.md
     user_guides/experimentation.md
     user_guides/metrics.md
+    user_guides/deepseek_r1.md

 .. _Prompt:

 .. toctree::
docs/en/user_guides/deepseek_r1.md (new file, 192 lines)

@@ -0,0 +1,192 @@
# Tutorial for Evaluating Reasoning Models

OpenCompass provides an evaluation tutorial for the DeepSeek R1 series of reasoning models (mathematical datasets).

- At the model level, we recommend sampling-based decoding to reduce the repetition caused by greedy decoding.
- At the dataset level, we run benchmarks with few samples multiple times and take the average.
- For answer validation, we use LLM-based verification to reduce misjudgments from rule-based evaluation.

## Installation and Preparation

Please follow OpenCompass's installation guide.
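
For a fresh setup, the steps look roughly like the following (a minimal sketch; the official installation guide remains authoritative, and LMDeploy is only needed for the TurboMind backend used below):

```bash
# Create an isolated environment (optional but recommended)
conda create -n opencompass python=3.10 -y
conda activate opencompass

# Install OpenCompass from PyPI
pip install -U opencompass

# Install LMDeploy for the TurboMind inference backend
pip install lmdeploy
```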

## Evaluation Configuration Setup

We provide example configurations in `example/eval_deepseek_r1.py`. Below is the configuration explanation:
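
The snippets below omit their imports for brevity; they assume something along these lines (module paths may differ slightly across OpenCompass versions):

```python
from opencompass.models import OpenAISDK, TurboMindModelwithChatTemplate
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask
# Assumed location of the reasoning-content postprocessor
from opencompass.utils.text_postprocessors import extract_non_reasoning_content
```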

### Configuration Interpretation

#### 1. Dataset and Validator Configuration

```python
# Dataset configuration that supports multiple runs (example)
from opencompass.configs.datasets.aime2024.aime2024_llmverify_repeat8_gen_e8fcee import aime2024_datasets

# Collect every imported *_datasets list into one flat list
datasets = sum(
    (v for k, v in locals().items() if k.endswith('_datasets')),
    [],
)

# LLM validator configuration. Deploy an API service via LMDeploy/vLLM/SGLang,
# or point this at any OpenAI-compatible endpoint.
verifier_cfg = dict(
    abbr='qwen2-5-32B-Instruct',
    type=OpenAISDK,
    path='Qwen/Qwen2.5-32B-Instruct',  # Replace with the actual model path
    key='YOUR_API_KEY',  # Use a real API key
    openai_api_base=['http://your-api-endpoint'],  # Replace with the API endpoint
    query_per_second=16,
    batch_size=1024,
    temperature=0.001,
    max_out_len=16384,
)

# Apply the validator to all datasets that expect a judge
for item in datasets:
    if 'judge_cfg' in item['eval_cfg']['evaluator']:
        item['eval_cfg']['evaluator']['judge_cfg'] = verifier_cfg
```
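
If you do not already have a judge service running, one way to get an OpenAI-compatible endpoint is LMDeploy's `api_server` (a sketch; the port and tensor-parallel degree here are arbitrary choices, not requirements):

```bash
# Serve Qwen2.5-32B-Instruct with an OpenAI-compatible API on port 23333
lmdeploy serve api_server Qwen/Qwen2.5-32B-Instruct --server-port 23333 --tp 2
```

The validator's `openai_api_base` would then be `http://127.0.0.1:23333/v1`.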

#### 2. Model Configuration

We provide an example that uses LMDeploy as the inference backend; change `path` (i.e., the HuggingFace model path) to evaluate other models.

```python
# LMDeploy model configuration example
models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='deepseek-r1-distill-qwen-7b-turbomind',
        path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
        engine_config=dict(session_len=32768, max_batch_size=128, tp=1),
        gen_config=dict(
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            max_new_tokens=32768,
        ),
        max_seq_len=32768,
        batch_size=64,
        run_cfg=dict(num_gpus=1),
        # Strip the <think> reasoning segment before answer extraction
        pred_postprocessor=dict(type=extract_non_reasoning_content),
    ),
    # Extendable 14B/32B configurations...
]
```

#### 3. Evaluation Process Configuration

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=1),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)

# Evaluation configuration
eval = dict(
    partitioner=dict(type=NaivePartitioner, n=8),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLEvalTask)),
)
```

#### 4. Summary Configuration

```python
# Configuration for averaging results over multiple runs
summary_groups = [
    {
        'name': 'AIME2024-Average8',
        'subsets': [[f'aime2024-run{idx}', 'accuracy'] for idx in range(8)],
    },
    # Average configurations for other datasets...
]

summarizer = dict(
    dataset_abbrs=[
        ['AIME2024-Average8', 'naive_average'],
        # Metrics for other datasets...
    ],
    summary_groups=summary_groups,
)

# Work directory configuration
work_dir = "outputs/deepseek_r1_reasoning"
```
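
The `subsets` entry assumes the repeated dataset instances are registered as `aime2024-run0` through `aime2024-run7`; the list comprehension expands to `[['aime2024-run0', 'accuracy'], ..., ['aime2024-run7', 'accuracy']]`, and the summarizer reports their `naive_average`.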

## Evaluation Execution

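The three scenarios below differ only in how many GPUs each model replica occupies and how many inference workers run in parallel; the total GPU count is the product of the two. A toy sketch of the arithmetic (the helper function is illustrative, not part of OpenCompass):

```python
def total_gpus(num_worker: int, num_gpus_per_model: int) -> int:
    """Total GPUs used when each of num_worker parallel inference
    tasks holds one model replica on num_gpus_per_model GPUs."""
    return num_worker * num_gpus_per_model

assert total_gpus(num_worker=1, num_gpus_per_model=1) == 1  # Scenario 1
assert total_gpus(num_worker=8, num_gpus_per_model=1) == 8  # Scenario 2
assert total_gpus(num_worker=4, num_gpus_per_model=2) == 8  # Scenario 3
```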

### Scenario 1: Model loaded on 1 GPU, data evaluated by 1 worker, using a total of 1 GPU

```bash
opencompass example/eval_deepseek_r1.py --debug --dump-eval-details
```

Evaluation logs will be printed to the command line.

### Scenario 2: Model loaded on 1 GPU, data evaluated by 8 workers, using a total of 8 GPUs

Modify the `infer` configuration in the configuration file and set `num_worker` to 8:

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=8),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)
```

At the same time, remove the `--debug` flag from the evaluation command:

```bash
opencompass example/eval_deepseek_r1.py --dump-eval-details
```

In this mode, OpenCompass uses multiple threads to launch `$num_worker` tasks. Per-task logs are not displayed in the command line; detailed evaluation logs appear under `$work_dir`.
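
For example, assuming the default output layout (a timestamped run directory containing `logs`, `predictions`, `results`, and `summary` subdirectories), the inference logs can be followed with something like:

```bash
# The run directory name is a timestamp; adjust the glob to your run
tail -f outputs/deepseek_r1_reasoning/*/logs/infer/*/*.out
```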

### Scenario 3: Model loaded on 2 GPUs, data evaluated by 4 workers, using a total of 8 GPUs

In the model configuration, `num_gpus` in `run_cfg` must be set to 2 (when using an inference backend, its parallelism parameters must be changed to match, e.g. `tp=2` in LMDeploy), and `num_worker` in the `infer` configuration must be set to 4:

```python
models += [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='deepseek-r1-distill-qwen-14b-turbomind',
        path='deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
        engine_config=dict(session_len=32768, max_batch_size=128, tp=2),
        gen_config=dict(
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            max_new_tokens=32768,
        ),
        max_seq_len=32768,
        max_out_len=32768,
        batch_size=128,
        run_cfg=dict(num_gpus=2),
        pred_postprocessor=dict(type=extract_non_reasoning_content),
    ),
]
```

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=4),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)
```

### Evaluation Results

The evaluation results are displayed as follows:

```bash
dataset            version  metric         mode  deepseek-r1-distill-qwen-7b-turbomind
-----------------  -------  -------------  ----  -------------------------------------
MATH               -        -              -     -
AIME2024-Average8  -        naive_average  gen   56.25
```

## Performance Baseline

Since the model uses sampling for decoding and the AIME dataset is small, performance may still fluctuate by 1-3 points even when averaged over 8 evaluation runs.
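
A back-of-envelope check makes this plausible (assuming AIME2024 has 30 problems and the runs are independent):

```python
import math

p, n_items, n_runs = 0.56, 30, 8  # accuracy, problems per run, number of runs
per_run_std = 100 * math.sqrt(p * (1 - p) / n_items)  # ~9.1 points
avg_std = per_run_std / math.sqrt(n_runs)             # ~3.2 points
print(f"per-run std = {per_run_std:.1f} pts, 8-run average std = {avg_std:.1f} pts")
```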

| Model                        | Dataset  | Metric   | Value |
| ---------------------------- | -------- | -------- | ----- |
| DeepSeek-R1-Distill-Qwen-7B  | AIME2024 | Accuracy | 56.3  |
| DeepSeek-R1-Distill-Qwen-14B | AIME2024 | Accuracy | 74.2  |
| DeepSeek-R1-Distill-Qwen-32B | AIME2024 | Accuracy | 74.2  |

@@ -40,6 +40,7 @@ OpenCompass getting-started roadmap
     user_guides/evaluation.md
     user_guides/experimentation.md
     user_guides/metrics.md
+    user_guides/deepseek_r1.md

 .. _提示词:

 .. toctree::

docs/zh_cn/user_guides/deepseek_r1.md (new file, 192 lines)

@@ -0,0 +1,192 @@
# Tutorial for Evaluating Reasoning Models

OpenCompass provides an evaluation tutorial for the DeepSeek R1 series of reasoning models (mathematical datasets).

- At the model level, we recommend sampling-based decoding to reduce the heavy repetition caused by greedy decoding.
- At the dataset level, we run benchmarks with few samples multiple times and take the average.
- For answer validation, we uniformly use LLM-based verification to reduce the misjudgments of rule-based evaluation.

## Installation and Preparation

Please follow the OpenCompass installation guide.

## Evaluation Configuration Setup

We provide an example configuration in `example/eval_deepseek_r1.py`; it is explained below.

### Configuration Interpretation

#### 1. Dataset and Validator Configuration

```python
# Dataset configuration that supports multiple runs (example)
from opencompass.configs.datasets.aime2024.aime2024_llmverify_repeat8_gen_e8fcee import aime2024_datasets

# Collect every imported *_datasets list into one flat list
datasets = sum(
    (v for k, v in locals().items() if k.endswith('_datasets')),
    [],
)

# LLM validator configuration. Launch an API service in advance via LMDeploy/vLLM/SGLang,
# or use any model service compatible with the OpenAI API.
verifier_cfg = dict(
    abbr='qwen2-5-32B-Instruct',
    type=OpenAISDK,
    path='Qwen/Qwen2.5-32B-Instruct',  # Replace with the actual path
    key='YOUR_API_KEY',  # Replace with a real API key
    openai_api_base=['http://your-api-endpoint'],  # Replace with the API endpoint
    query_per_second=16,
    batch_size=1024,
    temperature=0.001,
    max_out_len=16384,
)

# Apply the validator to all datasets
for item in datasets:
    if 'judge_cfg' in item['eval_cfg']['evaluator']:
        item['eval_cfg']['evaluator']['judge_cfg'] = verifier_cfg
```

#### 2. Model Configuration

We provide an evaluation example with LMDeploy as the inference backend; change `path` (i.e., the HF path) to evaluate other models.

```python
# LMDeploy model configuration example
models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='deepseek-r1-distill-qwen-7b-turbomind',
        path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
        engine_config=dict(session_len=32768, max_batch_size=128, tp=1),
        gen_config=dict(
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            max_new_tokens=32768,
        ),
        max_seq_len=32768,
        batch_size=64,
        run_cfg=dict(num_gpus=1),
        pred_postprocessor=dict(type=extract_non_reasoning_content),
    ),
    # Extendable 14B/32B configurations...
]
```

#### 3. Evaluation Process Configuration

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=1),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)

# Evaluation configuration
eval = dict(
    partitioner=dict(type=NaivePartitioner, n=8),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLEvalTask)),
)
```

#### 4. Summary Configuration

```python
# Configuration for averaging results over multiple runs
summary_groups = [
    {
        'name': 'AIME2024-Average8',
        'subsets': [[f'aime2024-run{idx}', 'accuracy'] for idx in range(8)],
    },
    # Average configurations for other datasets...
]

summarizer = dict(
    dataset_abbrs=[
        ['AIME2024-Average8', 'naive_average'],
        # Metrics for other datasets...
    ],
    summary_groups=summary_groups,
)

# Work directory configuration
work_dir = "outputs/deepseek_r1_reasoning"
```

## Evaluation Execution

### Scenario 1: Model loaded on 1 GPU, data evaluated by 1 worker, using a total of 1 GPU

```bash
opencompass example/eval_deepseek_r1.py --debug --dump-eval-details
```

Evaluation logs will be printed to the command line.

### Scenario 2: Model loaded on 1 GPU, data evaluated by 8 workers, using a total of 8 GPUs

Modify the `infer` configuration in the configuration file and set `num_worker` to 8:

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=8),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)
```

At the same time, remove the `--debug` flag from the evaluation command:

```bash
opencompass example/eval_deepseek_r1.py --dump-eval-details
```

In this mode, OpenCompass uses multiple threads to launch `$num_worker` tasks; the command line does not show per-task logs, and detailed evaluation logs appear under `$work_dir`.

### Scenario 3: Model loaded on 2 GPUs, data evaluated by 4 workers, using a total of 8 GPUs

In the model configuration, `num_gpus` in `run_cfg` must be set to 2 (when using an inference backend, its parameters must be changed to match, e.g. `tp=2` in LMDeploy), and `num_worker` in the `infer` configuration must be set to 4:

```python
models += [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='deepseek-r1-distill-qwen-14b-turbomind',
        path='deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
        engine_config=dict(session_len=32768, max_batch_size=128, tp=2),
        gen_config=dict(
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            max_new_tokens=32768,
        ),
        max_seq_len=32768,
        max_out_len=32768,
        batch_size=128,
        run_cfg=dict(num_gpus=2),
        pred_postprocessor=dict(type=extract_non_reasoning_content),
    ),
]
```

```python
# Inference configuration
infer = dict(
    partitioner=dict(type=NumWorkerPartitioner, num_worker=4),
    runner=dict(type=LocalRunner, task=dict(type=OpenICLInferTask)),
)
```

### Evaluation Results

The evaluation results are displayed as follows:

```bash
dataset            version  metric         mode  deepseek-r1-distill-qwen-7b-turbomind
-----------------  -------  -------------  ----  -------------------------------------
MATH               -        -              -     -
AIME2024-Average8  -        naive_average  gen   56.25
```

## Performance Baseline

Since the model uses sampling for decoding and the AIME dataset is small, performance may still fluctuate by 1-3 points even when averaged over 8 evaluation runs.

| Model                        | Dataset  | Metric   | Value |
| ---------------------------- | -------- | -------- | ----- |
| DeepSeek-R1-Distill-Qwen-7B  | AIME2024 | Accuracy | 56.3  |
| DeepSeek-R1-Distill-Qwen-14B | AIME2024 | Accuracy | 74.2  |
| DeepSeek-R1-Distill-Qwen-32B | AIME2024 | Accuracy | 74.2  |
@@ -22,12 +22,14 @@ from opencompass.models import (
 with read_base():
     # You can comment out the datasets you don't want to evaluate
+
+    # Datasets
     # from opencompass.configs.datasets.math.math_prm800k_500_llmverify_gen_6ff468 import math_datasets # 1 Run
     from opencompass.configs.datasets.aime2024.aime2024_llmverify_repeat8_gen_e8fcee import aime2024_datasets # 8 Run
     # from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_llmverify_gen_be8b13 import olympiadbench_datasets
     # from opencompass.configs.datasets.omni_math.omni_math_llmverify_gen_ccf9c0 import omnimath_datasets # 1 Run
     # from opencompass.configs.datasets.livemathbench.livemathbench_hard_custom_llmverify_gen_85d0ef import livemathbench_datasets

     # Summarizer
     from opencompass.configs.summarizers.groups.OlympiadBench import OlympiadBenchMath_summary_groups
@@ -205,6 +207,6 @@ summarizer = dict(
 # PART 5 Utils #
 #######################################################################

-work_dir = "outputs/deepseek_r1_reasoning"
+work_dir = 'outputs/deepseek_r1_reasoning'
@@ -130,6 +130,7 @@ class TurboMindModelwithChatTemplate(BaseModel):
         if self.fastchat_template:
             messages = _format_with_fast_chat_template(messages, self.fastchat_template)
         else:
+            # NOTE: the DeepSeek-R1 series chat template adds <think> after the generation prompt
             messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
             # LMDeploy tokenizes prompts via AutoTokenizer with its default parameter "add_special_tokens=True"
             # OC adds bos_token to the prompt, which requires tokenizing prompts with "add_special_tokens=False"