[Feature] Add CascadeEvaluator (#1992)

* [Feature] Add CascadeEvaluator

* update

* updat
Linchen Xiao 2025-04-08 11:58:14 +08:00 committed by GitHub
parent b564e608b1
commit bb58cfc85d
8 changed files with 681 additions and 17 deletions

View File

@ -57,6 +57,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.04.01\]** OpenCompass now supports `CascadeEvaluator`, a flexible evaluation mechanism that allows multiple evaluators to work in sequence. This enables creating customized evaluation pipelines for complex assessment scenarios. Check out the [documentation](docs/en/advanced_guides/llm_judge.md) for more details! 🔥🔥🔥
- **\[2025.03.11\]** We have supported evaluation for `SuperGPQA` which is a great benchmark for measuring LLM knowledge ability 🔥🔥🔥
- **\[2025.02.28\]** We have added a tutorial for `DeepSeek-R1` series model, please check [Evaluating Reasoning Model](docs/en/user_guides/deepseek_r1.md) for more details! 🔥🔥🔥
- **\[2025.02.15\]** We have added two powerful evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluations and `MATHEvaluator` for mathematical reasoning assessments. Check out the documentation for [LLM Judge](docs/en/advanced_guides/llm_judge.md) and [Math Evaluation](docs/en/advanced_guides/general_math.md) for more details! 🔥🔥🔥

View File

@ -57,8 +57,9 @@
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.04.01\]** OpenCompass now supports `CascadeEvaluator`, which lets multiple evaluators work in sequence so that custom evaluation pipelines can be built for more complex scenarios. See the [documentation](docs/zh_cn/advanced_guides/llm_judge.md) for usage details! 🔥🔥🔥
- **\[2025.03.11\]** `SuperGPQA`, a knowledge benchmark covering 285 graduate-level disciplines, is now supported. Give it a try! 🔥🔥🔥
- **\[2025.02.28\]** We have added a tutorial for the `DeepSeek-R1` series of models; see [Evaluating Reasoning Models](docs/zh_cn/user_guides/deepseek_r1.md) for more details! 🔥🔥🔥
- **\[2025.02.15\]** We have added two practical evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluation and `MATHEvaluator` for mathematical reasoning assessment. See the [LLM Judge](docs/zh_cn/advanced_guides/llm_judge.md) and [Math Evaluation](docs/zh_cn/advanced_guides/general_math.md) documentation for more details! 🔥🔥🔥
- **\[2025.01.16\]** We now support [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct), which achieves the best performance among models of comparable size on reasoning and knowledge tasks. Give it a try!
- **\[2024.12.17\]** We have provided the December CompassAcademic leaderboard evaluation script [CompassAcademic](configs/eval_academic_leaderboard_202412.py); the official results can be reproduced with minimal configuration.

View File

@ -49,7 +49,7 @@ export OC_JUDGE_API_BASE=http://172.30.56.1:4000/v1
Note that by default, OpenCompass will use these three environment variables, but if you use configuration files to configure the evaluation service, these environment variables will not take effect.
### Using LLM for Evaluation via Configuration Files
To set up an LLM judge evaluation, you'll need to configure three main components:
@ -264,6 +264,107 @@ Example evaluation output:
}
```
## CascadeEvaluator
OpenCompass also provides a CascadeEvaluator that combines the strengths of rule-based evaluation and LLM-based evaluation. The cascade evaluator has two modes:
1. **Cascade Mode (parallel=False)**: First evaluates all samples with a rule-based evaluator, then only sends samples that were deemed incorrect by the rule-based evaluation to an LLM judge for re-evaluation. This approach reduces reliance on LLM judgments while maintaining accuracy, thus lowering evaluation costs and time.
2. **Parallel Mode (parallel=True)**: Evaluates all samples with both the rule-based evaluator and LLM judge, then considers a sample correct if either method marks it as correct. This approach can increase the leniency of evaluation but may result in higher costs since all samples require LLM evaluation.
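In short, the two modes differ only in when the LLM judge is invoked and how the per-sample verdicts are combined. The following is a minimal sketch of that decision logic, assuming a hypothetical zero-argument `llm_judge` callable; the actual implementation lives in `cascade_evaluator.py` later in this diff:
```python
def final_verdict(rule_correct: bool, llm_judge, parallel: bool) -> bool:
    """Combine the rule-based verdict with the LLM judge for one sample.

    `llm_judge` is a hypothetical zero-argument callable standing in for a
    call to the LLM evaluator (illustration only, not the real API).
    """
    if parallel:
        # Parallel mode: the LLM judge sees every sample; the sample counts
        # as correct if either evaluator accepts it.
        llm_correct = llm_judge()
        return rule_correct or llm_correct
    # Cascade mode: only samples rejected by the rule-based evaluator
    # are sent to the LLM judge.
    if rule_correct:
        return True
    return llm_judge()
```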
### Configuring CascadeEvaluator
Here's an example of how to configure the CascadeEvaluator:
```python
# Define a rule-based evaluator
rule_evaluator = dict(type=MATHEvaluator)
# Define an LLM judge evaluator
llm_judge_evaluator = dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=YourDataset,
path='path/to/your/dataset',
reader_cfg=reader_cfg,
),
judge_cfg=dict(), # Can use environment variables to configure the judge model
)
# Configure cascade evaluator (cascade mode)
cascade_evaluator = dict(
type=CascadeEvaluator,
llm_evaluator=llm_judge_evaluator,
rule_evaluator=rule_evaluator,
parallel=False # Cascade mode
)
# For parallel mode, set parallel=True
parallel_evaluator = dict(
type=CascadeEvaluator,
llm_evaluator=llm_judge_evaluator,
rule_evaluator=rule_evaluator,
parallel=True # Parallel mode
)
# Use the cascade evaluator in your dataset evaluation config
eval_cfg = dict(evaluator=cascade_evaluator)
```
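Besides a full `rule_evaluator`, the `CascadeEvaluator` added in this PR also accepts a `sample_score_fn` callable for per-sample rule checks (see the constructor in `cascade_evaluator.py` later in this diff). A minimal sketch of that variant, using a deliberately simple, hypothetical `exact_match` function:
```python
# Sketch only: `exact_match` is a made-up rule. Any callable returning a bool
# or a dict with a 'correct' field works (see CascadeEvaluator.sample_score).
def exact_match(prediction: str, reference: str) -> dict:
    return {
        'correct': prediction.strip() == reference.strip(),
        'pred': prediction,
        'answer': reference,
    }

cascade_with_fn = dict(
    type=CascadeEvaluator,
    llm_evaluator=llm_judge_evaluator,  # same LLM judge config as above
    sample_score_fn=exact_match,        # used instead of rule_evaluator
    parallel=False,
)
```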
### Evaluation Results
The cascade evaluator outputs detailed evaluation statistics including:
- Accuracy of the rule-based evaluation
- Accuracy of the LLM evaluation (for samples that failed rule-based evaluation in cascade mode)
- Final combined accuracy
Example output:
```python
{
'accuracy': 85.0, # Final accuracy
'cascade_stats': {
'total_samples': 100,
'rule_correct': 70, # Number of samples correct by rule evaluation
'rule_accuracy': 70.0, # Accuracy of rule evaluation
'llm_evaluated': 30, # Number of samples evaluated by LLM (failed samples in cascade mode)
'llm_correct': 15, # Number of samples correct by LLM evaluation
'llm_accuracy': 50.0, # Accuracy of LLM evaluation
'final_correct': 85, # Total correct samples
'final_accuracy': 85.0, # Final accuracy
'parallel_mode': False, # Whether parallel mode was used
},
'details': [
# Detailed evaluation results for each sample
]
}
```
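To make the bookkeeping concrete, here is how the numbers in the cascade-mode example above fit together (a sanity check only, not part of the evaluator):
```python
total = 100
rule_correct = 70                     # accepted by the rule-based evaluator
llm_evaluated = total - rule_correct  # only rejected samples reach the LLM judge
llm_correct = 15                      # of those, accepted by the LLM judge
final_correct = rule_correct + llm_correct

assert llm_evaluated == 30
assert 100 * llm_correct / llm_evaluated == 50.0  # 'llm_accuracy'
assert 100 * final_correct / total == 85.0        # 'final_accuracy'
```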
The cascade evaluator is particularly useful for:
1. Scenarios that require balancing evaluation cost and accuracy
2. Cases where rule-based evaluators are available but might not be comprehensive
3. Evaluation tasks that need more nuanced judgment for edge cases
## Complete Example
For a complete working example using GenericLLMEvaluator, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving with an LLM judge.
For a complete working example using CascadeEvaluator, refer to the `eval_cascade_evaluator.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving with the cascade evaluator.

View File

@ -263,6 +263,106 @@ GenericLLMEvaluator is designed to evaluate model outputs using an LLM as the judge.
}
```
## CascadeEvaluator
OpenCompass also provides a `CascadeEvaluator`, which combines the strengths of rule-based evaluation and LLM-based evaluation. The cascade evaluator has two modes:
1. **Cascade Mode (parallel=False)**: All samples are first evaluated with the rule-based evaluator, and only the samples that the rule-based evaluation judges incorrect are sent to the LLM judge for re-evaluation. This reduces reliance on LLM judgments while maintaining accuracy, lowering evaluation cost and time.
2. **Parallel Mode (parallel=True)**: All samples are evaluated by both the rule-based evaluator and the LLM judge, and a sample is considered correct if either evaluator marks it as correct. This makes the evaluation more lenient, but may cost more because every sample requires LLM evaluation.
### Configuring CascadeEvaluator
Here is an example of how to configure the `CascadeEvaluator`:
```python
# Define a rule-based evaluator
rule_evaluator = dict(type=MATHEvaluator)
# Define an LLM judge evaluator
llm_judge_evaluator = dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="你是一个负责评估模型输出正确性和质量的助手。",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=YourDataset,
path='path/to/your/dataset',
reader_cfg=reader_cfg,
),
judge_cfg=dict(), # Can use environment variables to configure the judge model
)
# Configure the cascade evaluator (cascade mode)
cascade_evaluator = dict(
type=CascadeEvaluator,
llm_evaluator=llm_judge_evaluator,
rule_evaluator=rule_evaluator,
parallel=False # Cascade mode
)
# For parallel mode, set parallel=True
parallel_evaluator = dict(
type=CascadeEvaluator,
llm_evaluator=llm_judge_evaluator,
rule_evaluator=rule_evaluator,
parallel=True # Parallel mode
)
# Use the cascade evaluator in your dataset evaluation config
eval_cfg = dict(evaluator=cascade_evaluator)
```
### Evaluation Results
The cascade evaluator outputs detailed evaluation statistics, including:
- Accuracy of the rule-based evaluation
- Accuracy of the LLM evaluation (for samples that failed the rule-based evaluation in cascade mode)
- Final combined accuracy
Example output:
```python
{
'accuracy': 85.0, # Final accuracy
'cascade_stats': {
'total_samples': 100,
'rule_correct': 70, # Samples judged correct by the rule-based evaluation
'rule_accuracy': 70.0, # Accuracy of the rule-based evaluation
'llm_evaluated': 30, # Samples evaluated by the LLM (failed samples in cascade mode)
'llm_correct': 15, # Samples judged correct by the LLM evaluation
'llm_accuracy': 50.0, # Accuracy of the LLM evaluation
'final_correct': 85, # Total number of correct samples
'final_accuracy': 85.0, # Final accuracy
'parallel_mode': False, # Whether parallel mode was used
},
'details': [
# Detailed evaluation results for each sample
]
}
```
The cascade evaluator is particularly useful for:
1. Scenarios that need to balance evaluation cost and accuracy
2. Cases where a rule-based evaluator is available but may not be comprehensive enough
3. Evaluation tasks that require more precise judgment of edge cases
## Complete Example
For a complete working example of the generic LLM judge, refer to the `eval_llm_judge.py` file in the examples directory, which shows how to evaluate mathematical problems with an LLM judge.
For a complete working example of the cascade evaluator, refer to the `eval_cascade_evaluator.py` file in the examples directory, which shows how to evaluate mathematical problems with the cascade evaluator.

View File

@ -0,0 +1,127 @@
from mmengine.config import read_base
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.evaluator import GenericLLMEvaluator, CascadeEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_evaluator import MATHEvaluator
from opencompass.datasets import (
MATHDataset,
math_postprocess_v2,
normalize_final_answer,
)
#######################################################################
# PART 0 Essential Configs #
#######################################################################
with read_base():
# Datasets, Summarizer
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import (
models as lmdeploy_qwen2_5_7b_instruct_model,
)
reader_cfg = dict(input_columns=['problem'], output_column='solution')
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
########################## Evaluator #################################
GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer.
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters "A" or "B", with no text around it.
Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
<Original Question Begin>: \n{problem}\n<Original Question End>\n\n
<Gold Target Begin>: \n{solution}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{prediction}\n<Predicted End>\n\n
Judging the correctness of candidates' answers:
""".strip()
llm_judge_evaluator = dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=GRADER_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=MATHDataset,
path='opencompass/math',
file_name='test_prm800k_500.json',
),
judge_cfg=dict(),
)
rule_evaluator = dict(type=MATHEvaluator)
cascade_evaluator = dict(
    type=CascadeEvaluator,
    llm_evaluator=llm_judge_evaluator,
    rule_evaluator=rule_evaluator,
    parallel=False,
)
########################## #################################
eval_cfg = dict()
# eval_cfg['evaluator'] = rule_evaluator
# eval_cfg['evaluator'] = llm_judge_evaluator
eval_cfg['evaluator'] = cascade_evaluator
math_datasets = [
dict(
abbr='math_prm800k_500',
type=MATHDataset,
path='opencompass/math',
file_name='test_prm800k_500.json',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
]
datasets = math_datasets
models = lmdeploy_qwen2_5_7b_instruct_model
work_dir = 'math_prm800k_500_cascade_evaluator'

View File

@ -1 +1,2 @@
from .cascade_evaluator import CascadeEvaluator # noqa
from .generic_llm_evaluator import GenericLLMEvaluator # noqa

View File

@ -0,0 +1,302 @@
import os
from typing import Any, Callable, Dict, List, Optional
import mmengine
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.logging import get_logger
@ICL_EVALUATORS.register_module()
class CascadeEvaluator(BaseEvaluator):
"""Cascade Evaluator.
First uses a rule-based method to judge predictions.
If a sample is marked as incorrect by the rule-based method,
then it uses an LLM judge to re-evaluate it.
Arguments:
llm_evaluator (dict): Configuration for the LLM evaluator.
rule_evaluator (Optional[dict]): Configuration for the
rule-based evaluator.
sample_score_fn (Optional[Callable]): A function to
score individual samples. If provided without rule_evaluator,
this function will be used directly.
parallel (bool): Whether to run in parallel mode.
"""
def __init__(
self,
llm_evaluator: Dict,
rule_evaluator: Optional[Dict] = None,
sample_score_fn: Optional[Callable] = None,
parallel: bool = True,
) -> None:
self.logger = get_logger()
# Initialize the LLM evaluator
llm_evaluator_type = llm_evaluator.pop('type')
if isinstance(llm_evaluator_type, str):
llm_evaluator_type = ICL_EVALUATORS.get(llm_evaluator_type)
self.llm_evaluator = llm_evaluator_type(**llm_evaluator)
# Initialize the rule evaluator if provided
self.rule_evaluator = None
if rule_evaluator:
rule_evaluator_type = rule_evaluator.pop('type')
if isinstance(rule_evaluator_type, str):
rule_evaluator_type = ICL_EVALUATORS.get(rule_evaluator_type)
self.rule_evaluator = rule_evaluator_type(**rule_evaluator)
self.sample_score_fn = sample_score_fn
self.parallel = parallel
# At least one of rule_evaluator or sample_score_fn must be provided
if not self.rule_evaluator and not self.sample_score_fn:
raise ValueError(
'Either rule_evaluator or sample_score_fn must be provided')
def sample_score(self, prediction: str, reference: str) -> Dict[str, Any]:
"""Score a single sample using sample_score_fn or rule_evaluator.
Args:
prediction: The model's prediction.
reference: The ground truth.
Returns:
Dict: A dictionary containing the score and other details.
"""
if self.sample_score_fn:
# Use user-provided function to evaluate a single sample
result = self.sample_score_fn(prediction, reference)
if not isinstance(result, dict):
# Ensure result is a dictionary with at least 'correct' field
result = {
'correct': bool(result),
'pred': prediction,
'answer': reference,
}
return result
else:
# Use rule_evaluator to evaluate a single sample by calling
# the score method with single-element lists
result = self.rule_evaluator.score([prediction], [reference])
if 'details' in result and len(result['details']) > 0:
return result['details'][0]
else:
# Fallback if rule_evaluator doesn't provide detailed results
return {
'correct': result.get('accuracy', 0) > 0,
'pred': prediction,
'answer': reference,
}
def _get_llm_correctness(self, llm_detail):
"""Determine if the LLM judge considers the answer correct.
Args:
llm_detail: The evaluation details from the LLM judge.
Returns:
bool: Whether the answer is correct according to the LLM judge.
"""
if 'prediction' in llm_detail:
response = llm_detail['prediction'].strip().upper()
return response == 'A' or response.startswith('CORRECT')
elif 'correct' in llm_detail:
return llm_detail['correct']
elif 'score' in llm_detail:
return llm_detail['score'] > 0.5
return False
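    # Note (illustration): with the GRADER_TEMPLATE used in the
    # eval_cascade_evaluator.py example in this PR, the judge is asked to
    # reply with the single letter 'A' (correct) or 'B' (incorrect), so the
    # 'prediction' branch above maps 'A'/'CORRECT' to True and everything
    # else to False.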
def score(
self,
predictions: List[str],
references: List[str],
test_set: Optional[Dataset] = None,
) -> Dict[str, Any]:
"""Score predictions using cascade or parallel evaluation.
Args:
predictions: List of model predictions.
references: List of ground truths.
test_set: Huggingface Dataset containing original test samples.
Returns:
Dict: A dictionary containing the scores and details.
"""
self.logger.info(
f"Running {'parallel' if self.parallel else 'cascade'} evaluation")
# Step 1: Evaluate each sample individually using rule-based evaluation
details = []
failed_predictions = []
failed_references = []
failed_indices = []
for i, (pred, ref) in enumerate(zip(predictions, references)):
result = self.sample_score(pred, ref)
result['evaluation_method'] = 'rule'
details.append({'rule_evaluation': result})
# If the sample failed rule-based evaluation or in parallel
# mode, mark it for LLM evaluation
if not result.get('correct', False) or self.parallel:
failed_predictions.append(pred)
failed_references.append(ref)
failed_indices.append(i)
# Calculate initial accuracy based on rule evaluation
initial_correct = sum(
1 for detail in details
if detail['rule_evaluation'].get('correct', False))
initial_accuracy = (100 * initial_correct /
len(predictions) if predictions else 0)
self.logger.info(
f'Rule-based evaluation: {initial_correct}/{len(predictions)} '
f'correct ({initial_accuracy:.2f}%)')
eval_mode = ('parallel (all samples)'
if self.parallel else 'cascade (only failed samples)')
self.logger.info(f'Samples requiring LLM evaluation ({eval_mode}): '
f'{len(failed_indices)}')
# Step 2: If there are samples for LLM evaluation
if failed_predictions and test_set is not None:
self.logger.info(f'Running LLM evaluation in {eval_mode} mode...')
# Create a subset of the test_set for LLM evaluation
failed_subset = test_set.select(failed_indices)
# Add prediction and reference columns to the dataset
failed_subset = failed_subset.add_column('prediction',
failed_predictions)
failed_subset = failed_subset.add_column('reference',
failed_references)
# Set a custom output path for LLM evaluation
original_out_dir = getattr(self.llm_evaluator, '_out_dir', None)
self.llm_evaluator._out_dir = f'{self._out_dir}_llm_judge'
# Check if results already exist to avoid re-evaluation
llm_results_path = f'{self.llm_evaluator._out_dir}.json'
if os.path.exists(llm_results_path):
self.logger.info(
f'Loading existing LLM evaluation results from '
f'{llm_results_path}')
llm_results = mmengine.load(llm_results_path)
# Extract details from loaded results
if llm_results.get('details', []):
loaded_details = llm_results['details']
else:
loaded_details = llm_results
# Strictly verify that the loaded results match
# the current evaluation needs
if len(loaded_details) != len(failed_indices):
error_msg = (
f'Error: Loaded LLM results contain '
f'{len(loaded_details)} samples, but current '
f'evaluation requires {len(failed_indices)} samples. '
f"The cached results at {llm_results_path} don't match"
f'the current evaluation needs. '
f'Please remove the cache file or fix the mismatch.')
self.logger.error(error_msg)
raise ValueError(error_msg)
else:
# Use GenericLLMEvaluator to evaluate samples
# unset dataset_cfg for GenericLLMEvaluator to
# directly use test_set
self.llm_evaluator.dataset_cfg = None
llm_results = self.llm_evaluator.score(
predictions=failed_predictions,
references=failed_references,
test_set=failed_subset,
)
# Restore original output directory
if original_out_dir:
self.llm_evaluator._out_dir = original_out_dir
if llm_results.get('details', []):
llm_details = llm_results['details']
else:
llm_details = llm_results
# Initialize counters for accuracy calculation
final_correct = initial_correct if not self.parallel else 0
llm_correct = 0
llm_evaluated = 0
# Update the details for samples that were evaluated by LLM
for i, llm_detail in enumerate(llm_details.values()):
original_index = failed_indices[i]
# Store original rule-based evaluation result
rule_result = details[original_index].copy()
rule_correct = rule_result['rule_evaluation'].get(
'correct', False)
# Add LLM evaluation details
details[original_index]['llm_evaluation'] = llm_detail
# Determine LLM correctness judgment and store it
is_correct = self._get_llm_correctness(llm_detail)
details[original_index]['llm_evaluation'][
'llm_correct'] = is_correct
# Count LLM evaluation statistics
llm_evaluated += 1
if is_correct:
llm_correct += 1
# Update final_correct counter based on evaluation mode
if self.parallel:
# In parallel mode, either rule-based or LLM evaluations
# should be correct
if rule_correct or is_correct:
final_correct += 1
else:
# In cascade mode, if rule was incorrect but LLM
# correct, increment
# (rule correct samples are already counted
# in initial_correct)
if not rule_correct and is_correct:
final_correct += 1
# Calculate final accuracy
final_accuracy = (100 * final_correct /
len(predictions) if predictions else 0)
llm_accuracy = (100 * llm_correct /
llm_evaluated if llm_evaluated else 0)
self.logger.info(
f'Final evaluation: {final_correct}/{len(predictions)} '
f'correct ({final_accuracy:.2f}%)')
if llm_evaluated > 0:
self.logger.info(
f'LLM evaluation: {llm_correct}/{llm_evaluated} '
f'correct ({llm_accuracy:.2f}%)')
result = {
'accuracy': final_accuracy,
'cascade_stats': {
'total_samples': len(predictions),
'rule_correct': initial_correct,
'rule_accuracy': initial_accuracy,
'llm_evaluated': llm_evaluated,
'llm_correct': llm_correct,
'llm_accuracy': llm_accuracy,
'final_correct': final_correct,
'final_accuracy': final_accuracy,
'parallel_mode': self.parallel,
},
'details': details,
}
return result

View File

@ -3,6 +3,7 @@ import os.path as osp
from typing import Dict, List, Optional
import mmengine
from datasets import Dataset
from mmengine.config import ConfigDict
from opencompass.openicl.icl_evaluator import BaseEvaluator
@ -82,10 +83,19 @@ class GenericLLMEvaluator(BaseEvaluator):
self,
predictions,
references: Optional[List] = None,
test_set: Optional[Dataset] = None,
) -> Dict:
"""Apply to single-model scoring."""
"""Apply to single-model scoring.
Args:
predictions: List of model predictions
references: List of reference answers
test_set: Optional Dataset containing additional
context for evaluation
"""
assert len(predictions) == len(
references), 'predictions and references must have the same length'
# -------------- Build Inferencer ----------------
self.build_inferencer()
@ -93,9 +103,7 @@ class GenericLLMEvaluator(BaseEvaluator):
predictions = self.pred_postprocess(predictions)
# For Single Round Dialogue
prediction_dict = {'prediction': predictions, 'obj_gold': references}
# ---------------- Build Dataset for LLM Judge -----------------
if self.dataset_cfg:
@ -109,19 +117,42 @@ class GenericLLMEvaluator(BaseEvaluator):
dataset.reader.dataset['test'] = dataset.test.add_column(
'reference', references)
else:
# Handle test_set in the else branch
from opencompass.datasets.lmeval import LMEvalDataset
if test_set is not None:
# If test_set is provided, use it as the base
# Ensure necessary columns exist
if 'prediction' not in test_set.column_names:
test_set = test_set.add_column('prediction', predictions)
if 'reference' not in test_set.column_names:
test_set = test_set.add_column('reference', references)
# Prepare input_columns and data dictionary
input_columns = test_set.column_names
data_dict = {
column: test_set[column]
for column in test_set.column_names
}
else:
# Original default dataset building logic
input_columns = list(prediction_dict.keys())
if references:
input_columns.append('reference')
data_dict = prediction_dict.copy()
if references:
data_dict['reference'] = references
# Create LMEvalDataset
dataset = LMEvalDataset(
reader_cfg=dict(
input_columns=input_columns,
output_column=None,
train_split='test',
),
**data_dict,
)
dataset.reader.output_column = 'reference'
retriever = ZeroRetriever(dataset)
# ----------------- LLM Judge ----------------