OpenCompass/opencompass/datasets/charm.py
jxd 12b84aeb3b
[Feature] Update CHARM Memorization (#1230)
* update gemini api and add gemini models

* add openai models

* update CHARM evaluation

* add CHARM memorization tasks

* add CharmMemSummarizer (output eval details for memorization-independent reasoning analysis)

* update CHARM readme

---------

Co-authored-by: wujiang <wujiang@pjlab.org.cn>
2024-07-26 18:42:30 +08:00

import json
import os.path as osp
import re
from typing import List, Union

from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator, LMEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
                                  TEXT_POSTPROCESSORS)
from opencompass.utils import build_dataset_from_cfg

from .base import BaseDataset


@TEXT_POSTPROCESSORS.register_module('charm-reason')
def charm_reason_postprocess(text: str) -> str:
    """Extract the answer option letter (e.g. 'B') from a model response."""
    ans = text
    ans_line = ans.split('answer is ')
    if len(ans_line) != 1:
        ans = ans_line[1].strip()

    match = re.search(r'\(([A-Z])\)*', ans)
    if match:
        return match.group(1)
    match = re.search(r'([A-Z])', ans)
    if match:
        return match.group(1)
    return ans
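
# A minimal usage sketch (illustrative inputs, not part of the original
# file): the postprocessor reduces a free-form chain-of-thought answer to a
# bare option letter, which CharmReasonEvaluator then compares against the
# gold label.
#
#   >>> charm_reason_postprocess('Let us think step by step... answer is (B).')
#   'B'
#   >>> charm_reason_postprocess('B) is correct')
#   'B'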


@ICL_EVALUATORS.register_module()
class CharmReasonEvaluator(BaseEvaluator):

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {'error': 'preds and refs have different length'}
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}
            if pred == ref:
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        return {'score': score, 'details': details}
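
# A minimal sketch of how the evaluator is invoked (hypothetical inputs):
# predictions are the postprocessed option letters, references the gold ones.
#
#   >>> result = CharmReasonEvaluator().score(['A', 'B', 'C'], ['A', 'B', 'D'])
#   >>> round(result['score'], 2)
#   66.67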


# Chinese phrases expressing uncertainty ('not sure', 'cannot determine',
# 'cannot answer', 'do not know', 'not familiar with'); an answer containing
# any of them is judged wrong.
UNCERTAIN_LIST = ['不确定', '无法确定', '无法回答', '不知道', '不认识']


def charm_memory_eval(pred: str, ref: Union[str, List[str]]) -> str:
    """Rule-based judging of a CHARM memorization answer.

    Returns '[正确]' (correct) or '[错误]' (wrong).
    """
    # An answer that merely expresses uncertainty is judged wrong.
    for uncertain in UNCERTAIN_LIST:
        if uncertain in pred:
            return '[错误]'

    is_negative = False
    if isinstance(ref, str):
        if ref.startswith('[not]'):
            # Some CHARM memorization questions have a ref of the form
            # '[not]xxx', meaning xxx is a negative example that must NOT
            # appear in pred. See e.g.
            # https://github.com/opendatalab/CHARM/blob/v1.0/data/CHARM/memorization/Chinese_Movie_and_Music_Recommendation.json#L45
            is_negative = True
            ref = ref[5:]  # strip the '[not]' prefix, keeping xxx
        references = [ref]
    else:
        references = ref  # some CHARM memorization refs are List[str]
    assert isinstance(references, list)

    for r in references:
        if r in pred:  # pred contains this reference
            return '[错误]' if is_negative else '[正确]'

    # pred contains none of the references: correct iff ref was a negative.
    return '[正确]' if is_negative else '[错误]'
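
# Illustrative calls (hypothetical inputs) covering the three ref forms:
#
#   >>> charm_memory_eval('这部电影由周星驰主演', '周星驰')  # plain string ref
#   '[正确]'
#   >>> charm_memory_eval('我无法确定', '周星驰')  # uncertainty phrase in pred
#   '[错误]'
#   >>> charm_memory_eval('推荐《甲》', '[not]乙')  # negative ref: 乙 must not appear
#   '[正确]'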


class CharmMemoryEvaluator(LMEvaluator):
    """A rule-based evaluator for CHARM memorization answers, used only for
    the Chinese_Movie_and_Music_Recommendation task.

    The other CHARM memorization tasks need an LLM as judge and use
    LMEvaluator, so the whole eval runs as a SubjectiveEvalTask; this
    evaluator therefore keeps the same inputs and outputs as LMEvaluator.
    """

    def __init__(self, prompt_template=None, *nargs, **kwargs):
        if prompt_template is None:
            # LMEvaluator requires a prompt template; this placeholder is
            # never used for rule-based judging.
            prompt_template = dict(
                type='PromptTemplate',
                template=dict(round=[dict(role='HUMAN', prompt='')]))
        super().__init__(prompt_template, *nargs, **kwargs)

    def score(self, predictions, references, **kwargs):
        assert isinstance(predictions, dict)  # single-model scoring
        references = [{} for _ in range(len(predictions['model_preds']))
                      ] if references is None else references
        predictions = predictions['model_preds']
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        eval_results = [
            charm_memory_eval(pred, ref)
            for pred, ref in zip(predictions, references)
        ]

        dataset = None
        if self.dataset_cfg:
            dataset = build_dataset_from_cfg(self.dataset_cfg)

        output = dict()
        for i in range(len(predictions)):
            # Rebuild the question text when the dataset is available.
            question = ''
            if dataset is not None:
                for col in dataset.reader.input_columns:
                    question += dataset.reader['test'][col][i] + '\n'
            output[str(i)] = {
                'origin_prompt': [{
                    'role':
                    'HUMAN',
                    'prompt':
                    f"[Question]: {question}[Assistant's Answer]: {predictions[i]}"  # noqa
                }],
                'prediction':
                eval_results[i],
                'gold':
                references[i],
            }
        return output
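
# The returned mapping follows LMEvaluator's per-sample output format, e.g.
# (hypothetical values):
#
#   {'0': {'origin_prompt': [{'role': 'HUMAN', 'prompt': '...'}],
#          'prediction': '[正确]',
#          'gold': '周星驰'}}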


@LOAD_DATASET.register_module()
class CharmDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
        with open(osp.join(path, f'{name}.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)['examples']
        dataset = Dataset.from_list(data)
        return dataset
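
# A minimal loading sketch; the path and task name below are assumptions
# based on the CHARM data layout referenced above:
#
#   >>> ds = CharmDataset.load('data/CHARM/memorization',
#   ...                        'Chinese_Movie_and_Music_Recommendation')
#   >>> ds[0]  # dict holding the fields of the first example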