update llm_judge

MaiziXiao 2025-02-25 10:49:06 +00:00
parent 5b9f4a4e7b
commit d369dfe30f
6 changed files with 695 additions and 36 deletions

View File

@ -0,0 +1,253 @@
# LLM as Judge Evaluation
## Introduction
OpenCompass provides the GenericLLMEvaluator component to facilitate LLM-as-judge evaluations. It is particularly useful in scenarios where rule-based methods (such as regular expressions) cannot reliably judge outputs, for example:
- Cases where models output the answer content without option identifiers
- Factual judgment datasets that are difficult to evaluate with rules
- Open-ended responses that require complex understanding and reasoning
- Evaluation criteria that would otherwise require many hand-crafted rules
## Dataset Format
The dataset for LLM judge evaluation should be in either JSON Lines (.jsonl) or CSV format. Each entry should contain at least:
- A problem or question
- A reference answer or gold standard
- (The model's prediction will be generated during evaluation)
Example JSONL format:
```json
{"problem": "What is the capital of France?", "answer": "Paris"}
```
Example CSV format:
```csv
problem,answer
"What is the capital of France?","Paris"
```
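For instance, a dataset file in the expected format can be produced with a few lines of Python (a minimal sketch; the records and file name are placeholders):
```python
# Write a small JSONL dataset compatible with the reader configuration below.
import json

records = [
    {'problem': 'What is the capital of France?', 'answer': 'Paris'},
    {'problem': 'What is 2 + 2?', 'answer': '4'},
]

with open('your_dataset.jsonl', 'w', encoding='utf-8') as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
```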
## Configuration
To set up an LLM judge evaluation, you'll need to configure three main components:
1. Dataset Reader Configuration
```python
reader_cfg = dict(
input_columns=['problem'], # Column name for the question
output_column='answer' # Column name for the reference answer
)
```
2. Inference Configuration
```python
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}', # Template for prompting the model
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. Evaluation Configuration with LLM Judge
```python
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator, # Using LLM as evaluator
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE), # Template for the judge
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=YOUR_JUDGE_MODEL_CONFIG, # Configuration for the judge model
dict_postprocessor=dict(type=generic_llmjudge_postprocess), # Post-processing the judge's output
),
pred_role='BOT',
)
```
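The `judge_cfg` field accepts any standard OpenCompass model configuration, so the judge can be served locally (as in the complete example below) or through an OpenAI-compatible API. The following is only an illustrative sketch of an API-based judge config; the field values are placeholders, and the exact parameters should be checked against the model configs shipped with your OpenCompass version:
```python
from opencompass.models.openai_api import OpenAISDK

# Illustrative sketch: an API-served judge model; adjust fields to your deployment.
api_judge_cfg = dict(
    type=OpenAISDK,
    abbr='api-judge',
    path='your-judge-model-name',                # model name exposed by the API
    openai_api_base='http://localhost:8000/v1',  # OpenAI-compatible endpoint
    key='YOUR_API_KEY',
    query_per_second=1,
    max_out_len=1024,
    batch_size=8,
    temperature=0.0,
)

# Then pass it to the evaluator, e.g. judge_cfg=api_judge_cfg in eval_cfg above.
```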
## Using CustomDataset with GenericLLMEvaluator
Here's how to set up a complete configuration for LLM judge evaluation:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# Import your judge model configuration
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as judge_model,
)
# Define your judge template
JUDGE_TEMPLATE = """
Please evaluate whether the following response correctly answers the question.
Question: {problem}
Reference Answer: {answer}
Model Response: {prediction}
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""".strip()
# Dataset reader configuration
reader_cfg = dict(input_columns=['problem'], output_column='answer')
# Inference configuration for the model being evaluated
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Evaluation configuration with LLM judge
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=judge_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
pred_role='BOT',
)
# Dataset configuration
datasets = [
dict(
type=CustomDataset,
abbr='my-dataset',
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
]
# Model configuration for the model being evaluated
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='model-to-evaluate',
path='path/to/your/model',
# ... other model configurations
)
]
# Output directory
work_dir = './outputs/llm_judge_eval'
```
## GenericLLMEvaluator
The GenericLLMEvaluator is designed to use an LLM as a judge for evaluating model outputs. Key features include:
1. Flexible prompt templates for instructing the judge
2. Support for various judge models (local or API-based)
3. Customizable evaluation criteria through prompt engineering
4. Post-processing of judge outputs to extract structured evaluations
**Important Note**: The current generic version of the judge template only supports outputs in the format of "A" (correct) or "B" (incorrect), and does not support other output formats (like "CORRECT" or "INCORRECT"). This is because the post-processing function `generic_llmjudge_postprocess` is specifically designed to parse this format.
The evaluator works by:
1. Taking the original problem, reference answer, and model prediction
2. Formatting them into a prompt for the judge model
3. Parsing the judge's response to determine the evaluation result (looking for "A" or "B"; see the sketch after this list)
4. Aggregating results across the dataset
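The post-processing in steps 3 and 4 boils down to scanning the judge's reply for an "A" or "B" verdict and computing accuracy over the dataset. The sketch below is a simplified, hypothetical illustration of that idea; the actual implementation used by OpenCompass is `generic_llmjudge_postprocess`, whose behavior may differ in its details:
```python
# Hypothetical verdict parsing, for illustration only.
import re
from typing import List, Optional


def parse_judge_verdict(judge_output: str) -> Optional[bool]:
    """Return True for an "A" (correct) verdict, False for "B", None if unparseable."""
    match = re.search(r'\b([AB])\b', judge_output.strip())
    if match is None:
        return None
    return match.group(1) == 'A'


def aggregate_accuracy(judge_outputs: List[str]) -> float:
    """Aggregate per-sample verdicts into an accuracy percentage."""
    verdicts = [parse_judge_verdict(output) for output in judge_outputs]
    correct = sum(v is True for v in verdicts)
    return 100.0 * correct / max(len(verdicts), 1)


print(aggregate_accuracy(['A', 'B', 'A', 'A']))  # 75.0
```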
If you would like to see the full details of evaluation results, you can add `--dump-eval-details` to the command line when you start the job.
Example evaluation output:
```python
{
'accuracy': 75.0, # Percentage of responses judged as correct
'details': [
{
'origin_prompt': """
Please evaluate whether the following response correctly answers the question.
Question: What is the capital of France?
Reference Answer: Paris
Model Response: Paris
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""",
'gold': 'Paris',
'prediction': 'A',
},
# ... more results
]
}
```
## Complete Example
For a complete working example, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving using an LLM judge.

View File

@ -39,7 +39,6 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
user_guides/evaluation.md
user_guides/experimentation.md
user_guides/metrics.md
user_guides/summarizer.md
.. _Prompt:
.. toctree::
@ -60,13 +59,13 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
advanced_guides/new_dataset.md
advanced_guides/custom_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_lmdeploy.md
advanced_guides/accelerator_intro.md
advanced_guides/general_math.md
advanced_guides/llm_judge.md
advanced_guides/code_eval.md
advanced_guides/code_eval_service.md
advanced_guides/subjective_evaluation.md
advanced_guides/circular_eval.md
advanced_guides/needleinahaystack_eval.md
.. _Tools:
.. toctree::

View File

@ -0,0 +1,252 @@
# LLM as Judge
## Introduction
OpenCompass provides the GenericLLMEvaluator component to enable LLM-as-judge evaluation. It is particularly useful in scenarios that are hard to judge reliably with rule-based methods (such as regular expressions), for example:
- Cases where the model outputs only the answer content without option identifiers
- Datasets that require factual judgment
- Open-ended responses that require complex understanding and reasoning
- Judgments that would otherwise require many hand-crafted rules
## Dataset Format
The dataset for LLM judge evaluation should be in JSON Lines (.jsonl) or CSV format. Each entry should contain at least:
- A problem or question
- A reference answer or gold standard
- (The model's prediction is generated during evaluation)
Example JSONL format:
```json
{"problem": "法国的首都是什么?", "answer": "巴黎"}
```
Example CSV format:
```csv
problem,answer
"法国的首都是什么?","巴黎"
```
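If you want to sanity-check a CSV file before launching an evaluation, a small script like the following can help (a minimal sketch; the file name is a placeholder):
```python
# Check that the CSV dataset has the columns the reader configuration expects.
import csv

with open('your_dataset.csv', newline='', encoding='utf-8') as f:
    rows = list(csv.DictReader(f))

assert rows and {'problem', 'answer'} <= set(rows[0]), \
    "each row needs a 'problem' and an 'answer' column"
print(f"{len(rows)} examples; first problem: {rows[0]['problem']}")
```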
## Configuration
To set up an LLM judge evaluation, you need to configure three main components:
1. Dataset Reader Configuration
```python
reader_cfg = dict(
input_columns=['problem'], # Column name for the question
output_column='answer' # Column name for the reference answer
)
```
2. Inference Configuration
```python
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}', # Template for prompting the model
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
```
3. Evaluation Configuration with LLM Judge
```python
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator, # Using LLM as evaluator
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="你是一个负责评估模型输出正确性和质量的助手。",
)
],
round=[
dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE), # Template for the judge
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=YOUR_JUDGE_MODEL_CONFIG, # Configuration for the judge model
dict_postprocessor=dict(type=generic_llmjudge_postprocess), # Post-processing the judge's output
),
pred_role='BOT',
)
```
## Using CustomDataset with GenericLLMEvaluator
Here's how to set up a complete configuration for LLM judge evaluation:
```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
# Import your judge model configuration
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as judge_model,
)
# Define your judge template
JUDGE_TEMPLATE = """
Please evaluate whether the following response correctly answers the question.
Question: {problem}
Reference Answer: {answer}
Model Response: {prediction}
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""".strip()
# Dataset reader configuration
reader_cfg = dict(input_columns=['problem'], output_column='answer')
# Inference configuration for the model being evaluated
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Evaluation configuration with LLM judge
eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="你是一个负责评估模型输出正确性和质量的助手。",
)
],
round=[
dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
),
judge_cfg=judge_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
pred_role='BOT',
)
# Dataset configuration
datasets = [
dict(
type=CustomDataset,
abbr='my-dataset',
path='path/to/your/dataset',
file_name='your_dataset.jsonl',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
]
# Model configuration for the model being evaluated
models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='model-to-evaluate',
path='path/to/your/model',
# ... other model configurations
)
]
# Output directory
work_dir = './outputs/llm_judge_eval'
```
## GenericLLMEvaluator
The GenericLLMEvaluator is designed to use an LLM as a judge for evaluating model outputs. Key features include:
1. Flexible prompt templates for instructing the judge
2. Support for various judge models (local or API-based)
3. Customizable evaluation criteria through prompt engineering
4. Post-processing of judge outputs to extract structured evaluations
**Important Note**: The current generic version of the judge template only supports outputs in the format of "A" (correct) or "B" (incorrect), and does not support other output formats (such as "CORRECT" or "INCORRECT"). This is because the post-processing function `generic_llmjudge_postprocess` is specifically designed to parse this format.
The evaluator works by:
1. Taking the original problem, reference answer, and model prediction
2. Formatting them into a prompt for the judge model
3. Parsing the judge's response to determine the evaluation result (looking for "A" or "B")
4. Aggregating results across the dataset
If you would like to see the full details of the evaluation results, add `--dump-eval-details` to the command line when launching the job.
Example evaluation output:
```python
{
'accuracy': 75.0, # Percentage of responses judged as correct
'details': [
{
'origin_prompt': """
Please evaluate whether the following response correctly answers the question.
Question: 法国的首都是什么?
Reference Answer: 巴黎
Model Response: 法国的首都是巴黎。
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".""",
'gold': '巴黎',
'prediction': 'A',
},
# ... more results
]
}
```
## Complete Example
For a complete working example, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving using an LLM judge.

View File

@ -40,7 +40,6 @@ OpenCompass 上手路线
user_guides/evaluation.md
user_guides/experimentation.md
user_guides/metrics.md
user_guides/summarizer.md
.. _提示词:
.. toctree::
@ -60,13 +59,13 @@ OpenCompass 上手路线
advanced_guides/new_dataset.md
advanced_guides/custom_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_lmdeploy.md
advanced_guides/accelerator_intro.md
advanced_guides/general_math.md
advanced_guides/llm_judge.md
advanced_guides/code_eval.md
advanced_guides/code_eval_service.md
advanced_guides/subjective_evaluation.md
advanced_guides/circular_eval.md
advanced_guides/needleinahaystack_eval.md
.. _工具:
.. toctree::

examples/eval_llm_judge.py (new file, 116 lines)
View File

@ -0,0 +1,116 @@
from mmengine.config import read_base
from opencompass.models.openai_api import OpenAISDK
# Import pre-configured models from OpenCompass
with read_base():
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import (
models as lmdeploy_qwen2_5_7b_instruct_model,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as lmdeploy_qwen2_5_14b_instruct_model,
)
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.datasets import CustomDataset
# Dataset reader configuration
math_reader_cfg = dict(input_columns=['problem'], output_column='answer')
# Inference configuration
math_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{problem}\nRemember to put your final answer within \\boxed{}.',
),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Template for the LLM judge
GRADER_TEMPLATE = """
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer.
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters "A" or "B", with no text around it.
Here is your task. Simply reply with either "A" or "B". Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
<Original Question Begin>: \n{problem}\n<Original Question End>\n\n
<Gold Target Begin>: \n{answer}\n<Gold Target End>\n\n
<Predicted Answer Begin>: \n{prediction}\n<Predicted Answer End>\n\n
Judging the correctness of candidates' answers:
""".strip()
# Evaluation configuration using LLM as judge
math_eval_cfg = dict(
evaluator=dict(
type=GenericLLMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
)
],
round=[
dict(role='HUMAN', prompt=GRADER_TEMPLATE),
],
),
),
dataset_cfg=dict(
type=CustomDataset,
path='opencompass/math',
file_name='test_prm800k_500.jsonl',
reader_cfg=math_reader_cfg,
),
judge_cfg=lmdeploy_qwen2_5_14b_instruct_model[0],
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)
# Dataset configuration
datasets = [
dict(
type=CustomDataset,
path='opencompass/math',
file_name='test_prm800k_500.jsonl',
reader_cfg=math_reader_cfg,
infer_cfg=math_infer_cfg,
eval_cfg=math_eval_cfg,
)
]
# Model to be evaluated
models = lmdeploy_qwen2_5_7b_instruct_model
# Limiting test to first 8 examples for quick testing
math_reader_cfg['test_range'] = '[0:8]'
# Output directory
work_dir = 'outputs/llm_judge'

View File

@ -3,6 +3,7 @@ import copy
import math
import os
import os.path as osp
import random
import statistics
import sys
import time
@ -37,18 +38,31 @@ class OpenICLEvalTask(BaseTask):
        super().__init__(cfg)
        self.logger = get_logger()
        self.num_gpus = max(
            max(
                c.get('eval_cfg', {}).get('num_gpus', 0),
                c.get('eval_cfg', {}).get('evaluator', {}).get(
                    'judge_cfg', {}).get('run_cfg', {}).get('num_gpus', 0),
            ) for c in sum(self.dataset_cfgs, []))
        self.num_procs = max(
            c.get('eval_cfg', {}).get('evaluator', {}).get(
                'judge_cfg', {}).get('run_cfg', {}).get('num_procs', 1)
            for c in sum(self.dataset_cfgs, []))
        self.dump_details = (cfg.get('eval', {}).get('runner', {}).get(
            'task', {}).get('dump_details', False))
        self.cal_extract_rate = (cfg.get('eval', {}).get('runner', {}).get(
            'task', {}).get('cal_extract_rate', False))

    def get_command(self, cfg_path, template):
        sys.path.append(os.getcwd())
        script_path = __file__
        if self.num_gpus > 1:
            port = random.randint(12000, 32000)
            command = (f'torchrun --master_port={port} '
                       f'--nproc_per_node {self.num_procs} '
                       f'{script_path} {cfg_path}')
        else:
            python = sys.executable
            command = f'{python} {script_path} {cfg_path}'
        return template.format(task_cmd=command)

    def run(self):
@ -63,8 +77,10 @@
                    dataset_cfg['reader_cfg']['output_column'])
                out_path = get_infer_output_path(
                    self.model_cfg,
                    self.dataset_cfg,
                    osp.join(self.work_dir, 'results'),
                )
                if osp.exists(out_path):
                    continue
        self._score()
@ -86,8 +102,10 @@
        # Load predictions
        filename = get_infer_output_path(
            self.model_cfg,
            self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'),
        )
        # in case the prediction is partial
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext
@ -123,6 +141,7 @@
                and not MODELS.get(self.model_cfg['type']).is_api):
            # Create a prompt template for role config parsing
            from opencompass.models.base import LMTemplateParser

            parser = LMTemplateParser(self.model_cfg['meta_template'])
            role = parser.roles[self.eval_cfg['pred_role']]
            if sc_size is not None:
@ -131,15 +150,19 @@
                                'must be list.')
            if pred_list_flag:
                pred_strs = [[
                    extract_role_pred(
                        _pred,
                        role.get('begin', None),
                        role.get('end', None),
                    ) for _pred in pred
                ] for pred in pred_strs]
            else:
                pred_strs = [
                    extract_role_pred(
                        pred,
                        role.get('begin', None),
                        role.get('end', None),
                    ) for pred in pred_strs
                ]

        # Postprocess predictions if necessary
@ -195,8 +218,10 @@
        icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
        # need results dir to save other files
        out_path = get_infer_output_path(
            self.model_cfg,
            self.dataset_cfg,
            osp.join(self.work_dir, 'results'),
        )
        icl_evaluator._out_dir = osp.splitext(out_path)[
            0]  # strip extension
@ -235,9 +260,13 @@
            details = result.get('details', None)
            try:
                result['details'] = self.format_details(
                    pred_strs,
                    model_pred_strs,
                    test_set[self.output_column],
                    details,
                    model_details,
                    pred_dicts,
                )
                self.logger.warning(
                    f"result['details'] : {result['details']}"),
                result['type'] = result['details'].pop('type', None)
@ -247,8 +276,8 @@
                if 'PPL' in str(
                        self.dataset_cfg.infer_cfg.inferencer.type):
                    result['correct_bpb'], result['incorrect_bpb'] = (
                        self.calculate_bpb(pred_dicts))
            except Exception as e:
                self.logger.warning(f'Skip dumping details due to: {e}.')
        else:
@ -281,8 +310,11 @@
            f'{task_abbr_from_cfg(self.cfg)}:{model_result_wo_details}')
        # Save result
        out_path = get_infer_output_path(
            self.model_cfg,
            self.dataset_cfg,
            osp.join(self.work_dir, 'results'),
        )
        mkdir_or_exist(osp.split(out_path)[0])
        mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
@ -305,8 +337,15 @@
        success_rate = 100 - len(invalid_extractions) / len(details) * 100
        return success_rate

    def format_details(
        self,
        predictions,
        model_pred_strs,
        references,
        details,
        model_details,
        pred_dicts,
    ):
        """This function is responsible for formatting prediction details.

        Args:
@ -344,8 +383,9 @@
            result['references'] = str(references[i])
            result['correct'] = str(predictions[i]) == str(references[i])
        elif details is not None and model_details is not None:
            assert (
                model_pred_strs != []
            ), 'Model details is not None, but model_pred_strs is empty'
            self.logger.info(
                f"model_details[i]['pred']: {model_details[i]['pred']}")
            results['type'] = 'GEN'