[Feature] Add general math, llm judge evaluator (#1892)

* update_doc

* update llm_judge

* update README

* update md file name
This commit is contained in:
Linchen Xiao 2025-02-26 15:08:50 +08:00 committed by GitHub
parent fd6fbf01a2
commit bdb2d46f59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1075 additions and 47 deletions

View File

@ -57,6 +57,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.02.15\]** We have added two powerful evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluations and `MATHEvaluator` for mathematical reasoning assessments. Check out the documentation for [LLM Judge](docs/en/advanced_guides/llm_judge.md) and [Math Evaluation](docs/en/advanced_guides/general_math.md) for more details! 🔥🔥🔥
- **\[2025.01.16\]** We now support the [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) model which has enhanced performance on reasoning and knowledge-intensive tasks.
- **\[2024.12.17\]** We have provided the evaluation script for the December [CompassAcademic](examples/eval_academic_leaderboard_202412.py), which allows users to easily reproduce the official evaluation results by configuring it.
- **\[2024.11.14\]** OpenCompass now offers support for a sophisticated benchmark designed to evaluate complex reasoning skills — [MuSR](https://arxiv.org/pdf/2310.16049). Check out the [demo](examples/eval_musr.py) and give it a spin! 🔥🔥🔥

View File

@ -57,6 +57,7 @@
## 🚀 最新进展 <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.02.15\]** 我们新增了两个实用的评测工具用于LLM作为评判器的`GenericLLMEvaluator`和用于数学推理评估的`MATHEvaluator`。查看[LLM评判器](docs/zh_cn/advanced_guides/llm_judge.md)和[数学能力评测](docs/zh_cn/advanced_guides/general_math.md)文档了解更多详情!🔥🔥🔥
- **\[2025.01.16\]** 我们现已支持 [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) 模型,该模型在推理、知识类任务上取得同量级最优性能,欢迎尝试。
- **\[2024.12.17\]** 我们提供了12月CompassAcademic学术榜单评估脚本 [CompassAcademic](configs/eval_academic_leaderboard_202412.py),你可以通过简单地配置复现官方评测结果。
- **\[2024.10.14\]** 现已支持OpenAI多语言问答数据集[MMMLU](https://huggingface.co/datasets/openai/MMMLU),欢迎尝试! 🔥🔥🔥

View File

@ -0,0 +1,252 @@
# LLM as Judge Evaluation
## Introduction
The GenericLLMEvaluator is particularly useful for scenarios where rule-based methods (like regular expressions) cannot perfectly judge outputs, such as:
- Cases where models output answer content without option identifiers
- Factual judgment datasets that are difficult to evaluate with rules
- Open-ended responses requiring complex understanding and reasoning
- Evaluation that requires a lot of rules to be designed
OpenCompass provides the GenericLLMEvaluator component to facilitate LLM-as-judge evaluations.
## Dataset Format
The dataset for LLM judge evaluation should be in either JSON Lines (.jsonl) or CSV format. Each entry should contain at least:
- A problem or question
- A reference answer or gold standard
- (The model's prediction will be generated during evaluation)
Example JSONL format:
```json
{"problem": "What is the capital of France?", "answer": "Paris"}
```
Example CSV format:
```csv
problem,answer
"What is the capital of France?","Paris"
```
## Configuration
To set up an LLM judge evaluation, you'll need to configure three main components:
1. Dataset Reader Configuration
```python
reader_cfg = dict(
input_columns=['problem'], # Column name for the question
output_column='answer' # Column name for the reference answer
)
```
2. Inference Configuration
```python
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}', # Template for prompting the model
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. Evaluation Configuration with LLM Judge
```python
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator, # Using LLM as evaluator
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE), # Template for the judge
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=YOUR_JUDGE_MODEL_CONFIG, # Configuration for the judge model
dict_postprocessor=dict(type=generic_llmjudge_postprocess), # Post-processing the judge's output
),
)
```
## Using CustomDataset with GenericLLMEvaluator
Here's how to set up a complete configuration for LLM judge evaluation:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# Import your judge model configuration
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as judge_model,
)
# Define your judge template
JUDGE_TEMPLATE = """
Please evaluate whether the following response correctly answers the question.
Question: {problem}
Reference Answer: {answer}
Model Response: {prediction}
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""".strip()
# Dataset reader configuration
reader_cfg = dict(input_columns=['problem'], output_column='answer')
# Inference configuration for the model being evaluated
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Evaluation configuration with LLM judge
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=judge_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
pred_role='BOT',
)
# Dataset configuration
datasets = [
dict(
type=CustomDataset,
abbr='my-dataset',
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
]
# Model configuration for the model being evaluated
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='model-to-evaluate',
path='path/to/your/model',
# ... other model configurations
)
]
# Output directory
work_dir = './outputs/llm_judge_eval'
```
## GenericLLMEvaluator
The GenericLLMEvaluator is designed to use an LLM as a judge for evaluating model outputs. Key features include:
1. Flexible prompt templates for instructing the judge
2. Support for various judge models (local or API-based)
3. Customizable evaluation criteria through prompt engineering
4. Post-processing of judge outputs to extract structured evaluations
**Important Note**: The current generic version of the judge template only supports outputs in the format of "A" (correct) or "B" (incorrect), and does not support other output formats (like "CORRECT" or "INCORRECT"). This is because the post-processing function `generic_llmjudge_postprocess` is specifically designed to parse this format.
The evaluator works by:
1. Taking the original problem, reference answer, and model prediction
2. Formatting them into a prompt for the judge model
3. Parsing the judge's response to determine the evaluation result (looking for "A" or "B")
4. Aggregating results across the dataset
If you would like to see the full details of evaluation results, you can add `--dump-eval-details` to the command line when you start the job.
Example evaluation output:
```python
{
'accuracy': 75.0, # Percentage of responses judged as correct
'details': [
{
'origin_prompt': """
Please evaluate whether the following response correctly answers the question.
Question: What is the capital of France?
Reference Answer: Paris
Model Response: Paris
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""",
'gold': 'Paris',
'prediction': 'A',
},
# ... more results
]
}
```
## Complete Example
For a complete working example, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving using an LLM judge.

View File

@ -0,0 +1,190 @@
# General Math Evaluation Guidance
## Introduction
Mathematical reasoning is a crucial capability for large language models (LLMs). To evaluate a model's mathematical abilities, we need to test its capability to solve mathematical problems step by step and provide accurate final answers. OpenCompass provides a convenient way to evaluate mathematical reasoning through the CustomDataset and MATHEvaluator components.
## Dataset Format
The math evaluation dataset should be in either JSON Lines (.jsonl) or CSV format. Each problem should contain at least:
- A problem statement
- A solution/answer (typically in LaTeX format with the final answer in \\boxed{})
Example JSONL format:
```json
{"problem": "Find the value of x if 2x + 3 = 7", "solution": "Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"}
```
Example CSV format:
```csv
problem,solution
"Find the value of x if 2x + 3 = 7","Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"
```
## Configuration
To evaluate mathematical reasoning, you'll need to set up three main components:
1. Dataset Reader Configuration
```python
math_reader_cfg = dict(
input_columns=['problem'], # Column name for the question
output_column='solution' # Column name for the answer
)
```
2. Inference Configuration
```python
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. Evaluation Configuration
```python
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator),
)
```
## Using CustomDataset
Here's how to set up a complete configuration for math evaluation:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
math_datasets = [
dict(
type=CustomDataset,
abbr='my-math-dataset', # Dataset abbreviation
path='path/to/your/dataset', # Path to your dataset file
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
```
## MATHEvaluator
The MATHEvaluator is specifically designed to evaluate mathematical answers. It is developed based on the math_verify library, which provides mathematical expression parsing and verification capabilities, supporting extraction and equivalence verification for both LaTeX and general expressions.
The MATHEvaluator implements:
1. Extracts answers from both predictions and references using LaTeX extraction
2. Handles various LaTeX formats and environments
3. Verifies mathematical equivalence between predicted and reference answers
4. Provides detailed evaluation results including:
- Accuracy score
- Detailed comparison between predictions and references
- Parse results of both predicted and reference answers
The evaluator supports:
- Basic arithmetic operations
- Fractions and decimals
- Algebraic expressions
- Trigonometric functions
- Roots and exponents
- Mathematical symbols and operators
Example evaluation output:
```python
{
'accuracy': 85.0, # Percentage of correct answers
'details': [
{
'predictions': 'x = 2', # Parsed prediction
'references': 'x = 2', # Parsed reference
'correct': True # Whether they match
},
# ... more results
]
}
```
## Complete Example
Here's a complete example of how to set up math evaluation:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.openicl.icl_evaluator.math_evaluator import MATHEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# Dataset reader configuration
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
# Inference configuration
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Evaluation configuration
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator),
)
# Dataset configuration
math_datasets = [
dict(
type=CustomDataset,
abbr='my-math-dataset',
path='path/to/your/dataset.jsonl', # or .csv
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
# Model configuration
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='your-model-name',
path='your/model/path',
# ... other model configurations
)
]
# Output directory
work_dir = './outputs/math_eval'
```

View File

@ -39,8 +39,6 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
user_guides/evaluation.md
user_guides/experimentation.md
user_guides/metrics.md
user_guides/summarizer.md
user_guides/corebench.md
.. _Prompt:
.. toctree::
@ -62,16 +60,12 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
advanced_guides/custom_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_lmdeploy.md
advanced_guides/evaluation_lightllm.md
advanced_guides/accelerator_intro.md
advanced_guides/math_verify.md
advanced_guides/llm_judge.md
advanced_guides/code_eval.md
advanced_guides/code_eval_service.md
advanced_guides/prompt_attack.md
advanced_guides/longeval.md
advanced_guides/subjective_evaluation.md
advanced_guides/circular_eval.md
advanced_guides/contamination_eval.md
advanced_guides/needleinahaystack_eval.md
.. _Tools:
.. toctree::

View File

@ -0,0 +1,251 @@
# LLM 作为评判器
## 简介
GenericLLMEvaluator组件特别适用于那些难以通过规则式方法如正则表达式进行完美判断的场景例如
- 模型不输出选项标识而只输出选项内容的情况
- 需要事实性判断的数据集
- 需要复杂理解和推理的开放式回答
- 需要设计大量规则的判断
OpenCompass提供了GenericLLMEvaluator组件来实现LLM作为评判器的评估。
## 数据集格式
用于LLM评判的数据集应该是JSON Lines (.jsonl)或CSV格式。每个条目至少应包含
- 问题或任务
- 参考答案或标准答案
- (模型的预测将在评估过程中生成)
JSONL格式示例
```json
{"problem": "法国的首都是什么?", "answer": "巴黎"}
```
CSV格式示例
```csv
problem,answer
"法国的首都是什么?","巴黎"
```
## 配置说明
要设置LLM评判评估你需要配置三个主要组件
1. 数据集读取配置
```python
reader_cfg = dict(
input_columns=['problem'], # 问题列的名称
output_column='answer' # 参考答案列的名称
)
```
2. 推理配置
```python
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}', # 提示模型的模板
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. 使用LLM评判器的评估配置
```python
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator, # 使用LLM作为评估器
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="你是一个负责评估模型输出正确性和质量的助手。",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE), # 评判器的模板
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=YOUR_JUDGE_MODEL_CONFIG, # 评判模型的配置
dict_postprocessor=dict(type=generic_llmjudge_postprocess), # 处理评判器输出的后处理器
),
)
```
## 使用CustomDataset和GenericLLMEvaluator
以下是如何设置完整的LLM评判评估配置
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# 导入评判模型配置
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as judge_model,
)
# 定义评判模板
JUDGE_TEMPLATE = """
请评估以下回答是否正确地回答了问题。
问题:{problem}
参考答案:{answer}
模型回答:{prediction}
模型回答是否正确?如果正确,请回答"A";如果不正确,请回答"B"。
""".strip()
# 数据集读取配置
reader_cfg = dict(input_columns=['problem'], output_column='answer')
# 被评估模型的推理配置
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# 使用LLM评判器的评估配置
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="你是一个负责评估模型输出正确性和质量的助手。",
)
],
round=[
dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=judge_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
pred_role='BOT',
)
# 数据集配置
datasets = [
dict(
type=CustomDataset,
abbr='my-dataset',
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
]
# 被评估模型的配置
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='model-to-evaluate',
path='path/to/your/model',
# ... 其他模型配置
)
]
# 输出目录
work_dir = './outputs/llm_judge_eval'
```
## GenericLLMEvaluator
GenericLLMEvaluator专为使用LLM作为评判器评估模型输出而设计。主要特点包括
1. 灵活的提示模板,用于指导评判器
2. 支持各种评判模型本地或基于API
3. 通过提示工程自定义评估标准
4. 对评判器输出进行后处理以提取结构化评估
**重要说明**:目前通用版本的评判模板只支持输出"A"(正确)或"B"(不正确)的格式,不支持其他输出格式(如"正确"或"不正确")。这是因为后处理函数`generic_llmjudge_postprocess`专门设计为解析这种格式。
评估器的工作原理:
1. 获取原始问题、参考答案和模型预测
2. 将它们格式化为评判模型的提示
3. 解析评判器的响应以确定评估结果(寻找"A"或"B"
4. 汇总整个数据集的结果
如果需要查看评估的详细结果,可以在启动任务时添加`--dump-eval-details`到命令行。
评估输出示例:
```python
{
'accuracy': 75.0, # 被判断为正确的回答百分比
'details': [
{
'origin_prompt': """
请评估以下回答是否正确地回答了问题。
问题:法国的首都是什么?
参考答案:巴黎
模型回答:法国的首都是巴黎。
模型回答是否正确?如果正确,请回答"A";如果不正确,请回答"B"。""",
'gold': '巴黎',
'prediction': 'A',
},
# ... 更多结果
]
}
```
## 完整示例
有关完整的工作示例请参考examples目录中的`eval_llm_judge.py`文件该文件演示了如何使用LLM评判器评估数学问题解决能力。

View File

@ -0,0 +1,190 @@
# 数学能力评测
## 简介
数学推理能力是大语言模型(LLMs)的一项关键能力。为了评估模型的数学能力我们需要测试其逐步解决数学问题并提供准确最终答案的能力。OpenCompass 通过 CustomDataset 和 MATHEvaluator 组件提供了一种便捷的数学推理评测方式。
## 数据集格式
数学评测数据集应该是 JSON Lines (.jsonl) 或 CSV 格式。每个问题至少应包含:
- 问题陈述
- 解答/答案(通常使用 LaTeX 格式,最终答案需要用 \\boxed{} 括起来)
JSONL 格式示例:
```json
{"problem": "求解方程 2x + 3 = 7", "solution": "让我们逐步解决:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\n因此\\boxed{2}"}
```
CSV 格式示例:
```csv
problem,solution
"求解方程 2x + 3 = 7","让我们逐步解决:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\n因此\\boxed{2}"
```
## 配置说明
要进行数学推理评测,你需要设置三个主要组件:
1. 数据集读取配置
```python
math_reader_cfg = dict(
input_columns=['problem'], # 问题列的名称
output_column='solution' # 答案列的名称
)
```
2. 推理配置
```python
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\n请逐步推理并将最终答案放在 \\boxed{} 中。',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. 评测配置
```python
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator),
)
```
## 使用 CustomDataset
以下是如何设置完整的数学评测配置:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
math_datasets = [
dict(
type=CustomDataset,
abbr='my-math-dataset', # 数据集简称
path='path/to/your/dataset', # 数据集文件路径
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
```
## MATHEvaluator
MATHEvaluator 是专门设计用于评估数学答案的评测器。它基于 math_verify 库进行开发,该库提供了数学表达式解析和验证功能,支持 LaTeX 和一般表达式的提取与等价性验证。
MATHEvaluator 具有以下功能:
1. 使用 LaTeX 提取器从预测和参考答案中提取答案
2. 处理各种 LaTeX 格式和环境
3. 验证预测答案和参考答案之间的数学等价性
4. 提供详细的评测结果,包括:
- 准确率分数
- 预测和参考答案的详细比较
- 预测和参考答案的解析结果
评测器支持:
- 基本算术运算
- 分数和小数
- 代数表达式
- 三角函数
- 根式和指数
- 数学符号和运算符
评测输出示例:
```python
{
'accuracy': 85.0, # 正确答案的百分比
'details': [
{
'predictions': 'x = 2', # 解析后的预测答案
'references': 'x = 2', # 解析后的参考答案
'correct': True # 是否匹配
},
# ... 更多结果
]
}
```
## 完整示例
以下是设置数学评测的完整示例:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.openicl.icl_evaluator.math_evaluator import MATHEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# 数据集读取配置
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
# 推理配置
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\n请逐步推理并将最终答案放在 \\boxed{} 中。',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# 评测配置
math_eval_cfg = dict(
evaluator=dict(type=MATHEvaluator),
)
# 数据集配置
math_datasets = [
dict(
type=CustomDataset,
abbr='my-math-dataset',
path='path/to/your/dataset.jsonl', # 或 .csv
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
# 模型配置
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='your-model-name',
path='your/model/path',
# ... 其他模型配置
)
]
# 输出目录
work_dir = './outputs/math_eval'
```

View File

@ -40,8 +40,6 @@ OpenCompass 上手路线
user_guides/evaluation.md
user_guides/experimentation.md
user_guides/metrics.md
user_guides/summarizer.md
user_guides/corebench.md
.. _提示词:
.. toctree::
@ -62,17 +60,12 @@ OpenCompass 上手路线
advanced_guides/custom_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_lmdeploy.md
advanced_guides/evaluation_lightllm.md
advanced_guides/accelerator_intro.md
advanced_guides/math_verify.md
advanced_guides/llm_judge.md
advanced_guides/code_eval.md
advanced_guides/code_eval_service.md
advanced_guides/prompt_attack.md
advanced_guides/longeval.md
advanced_guides/subjective_evaluation.md
advanced_guides/circular_eval.md
advanced_guides/contamination_eval.md
advanced_guides/compassbench_intro.md
advanced_guides/needleinahaystack_eval.md
.. _工具:
.. toctree::

116
examples/eval_llm_judge.py Normal file
View File

@ -0,0 +1,116 @@
from mmengine.config import read_base
from opencompass.models.openai_api import OpenAISDK
# Import pre-configured models from OpenCompass
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import (
models as lmdeploy_qwen2_5_7b_instruct_model,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as lmdeploy_qwen2_5_14b_instruct_model,
)
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.datasets import CustomDataset
# Dataset reader configuration
math_reader_cfg = dict(input_columns=['problem'], output_column='answer')
# Inference configuration
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\nRemember to put your final answer within \\boxed{}.',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Template for the LLM judge
GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer.
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters "A" or "B", with no text around it.
Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
<Original Question Begin>: \n{problem}\n<Original Question End>\n\n
<Gold Target Begin>: \n{answer}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{prediction}\n<Predicted End>\n\n
Judging the correctness of candidates' answers:
""".strip()
# Evaluation configuration using LLM as judge
math_eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=GRADER_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='opencompass/math',
file_name='test_prm800k_500.jsonl',
reader_cfg=math_reader_cfg,
),
judge_cfg=lmdeploy_qwen2_5_14b_instruct_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)
# Dataset configuration
datasets = [
dict(
type=CustomDataset,
path='opencompass/math',
file_name='test_prm800k_500.jsonl',
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
# Model to be evaluated
models = lmdeploy_qwen2_5_7b_instruct_model
# Limiting test to first 8 examples for quick testing
math_reader_cfg['test_range'] = '[0:8]'
# Output directory
work_dir = 'outputs/llm_judge'

View File

@ -3,6 +3,7 @@ import copy
import math
import os
import os.path as osp
import random
import statistics
import sys
import time
@ -37,16 +38,29 @@ class OpenICLEvalTask(BaseTask):
super().__init__(cfg)
self.logger = get_logger()
self.num_gpus = max(
c.get('eval_cfg', {}).get('num_gpus', 0)
max(
c.get('eval_cfg', {}).get('num_gpus', 0),
c.get('eval_cfg', {}).get('evaluator', {}).get(
'judge_cfg', {}).get('run_cfg', {}).get('num_gpus', 0),
) for c in sum(self.dataset_cfgs, []))
self.num_procs = max(
c.get('eval_cfg', {}).get('evaluator', {}).get(
'judge_cfg', {}).get('run_cfg', {}).get('num_procs', 1)
for c in sum(self.dataset_cfgs, []))
self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('dump_details', False)
self.cal_extract_rate = cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('cal_extract_rate', False)
self.dump_details = (cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('dump_details', False))
self.cal_extract_rate = (cfg.get('eval', {}).get('runner', {}).get(
'task', {}).get('cal_extract_rate', False))
def get_command(self, cfg_path, template):
sys.path.append(os.getcwd())
script_path = __file__
if self.num_gpus > 1:
port = random.randint(12000, 32000)
command = (f'torchrun --master_port={port} '
f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}')
else:
python = sys.executable
command = f'{python} {script_path} {cfg_path}'
return template.format(task_cmd=command)
@ -63,8 +77,10 @@ class OpenICLEvalTask(BaseTask):
dataset_cfg['reader_cfg']['output_column'])
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))
self.model_cfg,
self.dataset_cfg,
osp.join(self.work_dir, 'results'),
)
if osp.exists(out_path):
continue
self._score()
@ -86,8 +102,10 @@ class OpenICLEvalTask(BaseTask):
# Load predictions
filename = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'predictions'))
self.model_cfg,
self.dataset_cfg,
osp.join(self.work_dir, 'predictions'),
)
# in case the prediction is partial
root, ext = osp.splitext(filename)
partial_filename = root + '_0' + ext
@ -123,6 +141,7 @@ class OpenICLEvalTask(BaseTask):
and not MODELS.get(self.model_cfg['type']).is_api):
# Create a prompt template for role config parsing
from opencompass.models.base import LMTemplateParser
parser = LMTemplateParser(self.model_cfg['meta_template'])
role = parser.roles[self.eval_cfg['pred_role']]
if sc_size is not None:
@ -131,15 +150,19 @@ class OpenICLEvalTask(BaseTask):
'must be list.')
if pred_list_flag:
pred_strs = [[
extract_role_pred(_pred, role.get('begin', None),
role.get('end', None))
for _pred in pred
extract_role_pred(
_pred,
role.get('begin', None),
role.get('end', None),
) for _pred in pred
] for pred in pred_strs]
else:
pred_strs = [
extract_role_pred(pred, role.get('begin', None),
role.get('end', None))
for pred in pred_strs
extract_role_pred(
pred,
role.get('begin', None),
role.get('end', None),
) for pred in pred_strs
]
# Postprocess predictions if necessary
@ -195,8 +218,10 @@ class OpenICLEvalTask(BaseTask):
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
# need results dir to save other files
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))
self.model_cfg,
self.dataset_cfg,
osp.join(self.work_dir, 'results'),
)
icl_evaluator._out_dir = osp.splitext(out_path)[
0] # strip extension
@ -235,9 +260,13 @@ class OpenICLEvalTask(BaseTask):
details = result.get('details', None)
try:
result['details'] = self.format_details(
pred_strs, model_pred_strs,
test_set[self.output_column], details, model_details,
pred_dicts)
pred_strs,
model_pred_strs,
test_set[self.output_column],
details,
model_details,
pred_dicts,
)
self.logger.warning(
f"result['details'] : {result['details']}"),
result['type'] = result['details'].pop('type', None)
@ -247,8 +276,8 @@ class OpenICLEvalTask(BaseTask):
if 'PPL' in str(
self.dataset_cfg.infer_cfg.inferencer.type):
result['correct_bpb'], result['incorrect_bpb'] = \
self.calculate_bpb(pred_dicts)
result['correct_bpb'], result['incorrect_bpb'] = (
self.calculate_bpb(pred_dicts))
except Exception as e:
self.logger.warning(f'Skip dumping details due to: {e}.')
else:
@ -281,8 +310,11 @@ class OpenICLEvalTask(BaseTask):
f'{task_abbr_from_cfg(self.cfg)}:{model_result_wo_details}')
# Save result
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))
out_path = get_infer_output_path(
self.model_cfg,
self.dataset_cfg,
osp.join(self.work_dir, 'results'),
)
mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
@ -305,8 +337,15 @@ class OpenICLEvalTask(BaseTask):
success_rate = 100 - len(invalid_extractions) / len(details) * 100
return success_rate
def format_details(self, predictions, model_pred_strs, references, details,
model_details, pred_dicts):
def format_details(
self,
predictions,
model_pred_strs,
references,
details,
model_details,
pred_dicts,
):
"""This function is responsible for formatting prediction details.
Args:
@ -344,8 +383,9 @@ class OpenICLEvalTask(BaseTask):
result['references'] = str(references[i])
result['correct'] = str(predictions[i]) == str(references[i])
elif details is not None and model_details is not None:
assert model_pred_strs != [], \
'Model details is not None, but model_pred_strs is empty'
assert (
model_pred_strs != []
), 'Model details is not None, but model_pred_strs is empty'
self.logger.info(
f"model_details[i]['pred']: {model_details[i]['pred']}")
results['type'] = 'GEN'