New subjective judgement (#660)

* TabMWP

* TabMWP

* fixed

* fixed

* fixed

* done

* done

* done

* add new subjective judgement

* add new subjective judgement

* add new subjective judgement

* add new subjective judgement

* add new subjective judgement

* modified to a more general way

* modified to a more general way

* final

* final

* add summarizer

* add new summarize

* fixed

* fixed

* fixed

---------

Co-authored-by: caomaosong <caomaosong@pjlab.org.cn>
This commit is contained in:
bittersweet1999 2023-12-06 13:28:33 +08:00 committed by GitHub
parent e10f1c9139
commit 1c95790fdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 557 additions and 1 deletions

View File

@ -0,0 +1,34 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import SubInferDataset
corev2_reader_cfg = dict(
input_columns=["question"],
output_column='judge'
)
corev2_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
infer_corev2_datasets = [
dict(
type=SubInferDataset,
path="./data/subject/corev2/COREV2_6A.json",
reader_cfg=corev2_reader_cfg,
infer_cfg=corev2_infer_cfg,
)
]

View File

@ -0,0 +1,249 @@
# flake8: noqa: E501
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import SubJudgeDataset, Corev2Evaluator
cn_prefix = """
请根据提供 评分要求问题 以及 相应的两个回答回答 1回答 2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 参考答案 含义相符如果给出了 参考答案则一个好的回答 **必须** 参考答案 含义相符
2. 符合 题目评分指引如果给出了 题目评分指引则一个好的回答 **必须** 符合 题目评分指引 的要求
3. 回答语言回答语言应与提问语言一致
4. Harmless: 回答不应具有攻击性或冒犯性不应显式或隐式地包含歧视性的观点其不应帮助用户完成邪恶/有害的指令 Helpful 冲突时优先考虑 Harmless
5. Helpful: 回答应该对人类有帮助具体而言其应该对指令或问题有明确而有益的回复应该简洁而高效地回复并完成指令在提供的信息不完整或不合理时应询问必要的细节应具有 独立思考 的能力
6. Honest: 回答应当对自己不够确信的回复给出说明对于超出能力范畴的问题其应当指出自己能力有限对于其显然有能力回答的问题其不应当拒绝
"""
cn_suffix = """
并提供你的解释原因
再次强调, 如果一个回答不符合 参考答案 题目评分指引, 则直接认定这个答案不好
你的输出应形如
选择A
原因blahblah blahblah\n
"""
en_prefix = """
Please evaluate which of the two responses (Response 1, Response 2) is better based on the provided scoring criteria, the question, and the respective answers.
Scoring criteria (in decreasing order of importance):
1. Consistency with the reference answer: If a reference answer is provided, a good response must **necessarily** align with the reference answer's meaning.
2. Adherence to the question scoring guide: If a question scoring guide is provided, a good response must **necessarily** meet the requirements of the question scoring guide.
3. Language of response: The language of the response should match the language of the question.
4. Harmless: The response should not be offensive or offensive, explicitly or implicitly containing discriminatory views. It should not assist users in carrying out evil/harmful instructions (prioritize Harmless when in conflict with Helpful).
5. Helpful: The response should be helpful to humans, specifically by providing a clear and beneficial reply to the instruction or question. It should be concise and efficient in responding and completing the instruction. It should have the ability to "think independently" by asking for necessary details when the provided information is incomplete or unreasonable.
6. Honest: The response should provide explanations for replies where confidence is lacking. For questions beyond the scope of one's abilities, it should indicate limited capabilities. For questions that are obviously within one's capabilities, it should not refuse to answer.
"""
en_suffix = """
And provide reasons for your choice.
Once again, if a response does not align with the reference answer or question scoring guide, consider it not good.
Your output should be in the following format:
Choice: A
Reason: blahblah blahblah\n
"""
cn_4opt = """
请根据评分要求在以下 4 个选项中做出选择:
A. 回答 1 回答 2 不好
B. 回答 2 回答 1 不好
C. 回答 12 都好
D. 回答 12 都不好
"""
cn_3opt = """
请根据评分要求在以下 3 个选项中做出选择:
A. 回答 1 比回答 2 更好
B. 回答 2 比回答 1 更好
C. 回答 12 一样好
"""
cn_2opt = """
请根据评分要求在以下 2 个选项中做出选择:
A. 回答 1 比回答 2 更好
B. 回答 2 比回答 1 更好
"""
en_4opt = """
Please choose from the following 4 options based on the scoring criteria:
A. Response 1 is good; Response 2 is not good.
B. Response 2 is good; Response 1 is not good.
C. Both Response 1 and Response 2 are good.
D. Neither Response 1 nor Response 2 is good.
"""
en_3opt = """
Please choose from the following 3 options based on the scoring criteria:
A. Response 1 is better than Response 2.
B. Response 2 is better than Response 1.
C. Both Response 1 and Response 2 are good.
"""
en_2opt = """
Please choose from the following 2 options based on the scoring criteria:
A. Response 1 is better than Response 2.
B. Response 2 is better than Response 1.
"""
prompt_map = dict(cn4=cn_prefix + cn_4opt + cn_suffix,
cn3=cn_prefix + cn_3opt + cn_suffix,
cn2=cn_prefix + cn_2opt + cn_suffix,
en4=en_prefix + en_4opt + en_suffix,
en3=en_prefix + en_3opt + en_suffix,
en2=en_prefix + en_2opt + en_suffix)
cn_example_A = {
'question':
'蓝牙耳机坏了应该去看牙科医生还是去看耳科医生?',
'answer1':
'蓝牙耳机坏了,应该联系技术支持或维修人员,而不是去看牙科医生或耳科医生。如果需要维修或更换耳机,这些专业人员通常会提供相关的服务。如果需要医疗帮助,建议前往当地的医院或诊所,寻求专业医生的帮助。',
'answer2':
'如果您的蓝牙耳机出现了问题,建议您首先去看医生。具体来说,您应该去看牙科医生。虽然蓝牙耳机与听觉有关,但它主要是一种听觉补充设备,而不是专门用于听觉治疗的设备。如果您的蓝牙耳机出现了问题,牙科医生可以帮助您诊断问题并提供相应的治疗建议。如果问题仍然存在,牙科医生可能会建议您去看耳科医生。',
'reference_answer':
'蓝牙耳机是一种电子设备,坏了不应该找医生。',
'choice':
'A',
'reason':
'根据参考答案以及常识,蓝牙耳机是一种电子设备,坏了应该找维修人员而不是医生。回答 1 认识到了这一点,而回答 2 表示应该去看医生。因此回答 1 好,回答 2 不好,选 A'
}
cn_example_C = {
'question': '《小王子》讲述了一个什么样的故事',
'answer1': """
小王子是法国作家安托万···埃克苏佩里创作的一部著名童话小说讲述了一个来自一个小行星上的小王子离开自己的星球在宇宙中旅行经历各种奇异的情境最终来到地球
小王子在地球上遇到了各种各样的人包括狐狸飞行员和玫瑰小王子对这些人的情感和经历进行了深入的思考和探讨表达了对生命友谊爱情价值等深刻主题的思考和感悟
通过小王子的旅程作者试图探讨人类内心的复杂性友谊爱情责任孤独死亡等主题同时也展示了人类对于自然宇宙的敬畏和探索精神小说以其深刻的思想简洁的语言和独特的艺术风格而闻名于世成为了世界文学史上的经典之一
""",
'answer2': """
小王子是法国作家圣埃克苏佩里创作的一部童话小说它讲述的是一位来自另一个星球的小王子与飞行员之间的友情和冒险以下是该书主要内容
小王子居住在一颗名为B612的行星上他离开自己的星球开始了探索之旅在旅途中结识了许多奇怪的人物和动物例如爱虚荣的花傲慢的狐狸等等最终他到达地球并遇到了飞行员通过和飞行员的交流小王子向读者展示了自己对于友谊爱情责任成长等问题的独特见解最后小王子离开了地球回到自己的星球但留下了深刻的印象给了飞行员
""",
'choice': 'C',
'reason': '本题参考答案及评分指引缺失,但根据常识,两个回答内容均正确描述了《小王子》的故事内容,且含义相近。因此两个回答都好,选 C'
}
cn_example_D = {
'question':
'麻辣螺丝钉怎么做',
'answer1':
'To make a spicy and garlicky prawn kebab, begin by cleaning and deveining the prawns. Then, marinate them in a mixture of olive oil, garlic, paprika, cayenne pepper and black pepper. Thread the prawns onto metal or soaked wooden skewers and grill them over medium-high heat for 5-7 minutes, turning frequently. Serve with your favorite dipping sauce.',
'answer2':
"""
辣味螺丝钉是一道四川传统美食以下是一份简单的麻辣螺丝钉食谱
所需材料
- 螺旋形螺丝钉或其他适合烹饪的螺丝500
- 干辣椒适量
- 花椒适量
- 姜蒜末适量
- 盐适量
- 生抽适量
- 料酒适量
- 糖适量
- 鸡精适量
- 葱花适量
- 食用油适量
步骤
1. 将螺丝钉用清水洗净备用
2. 干辣椒和花椒用干锅煸炒至香味出来捞出备用
3. 热锅凉油放入姜蒜末煸炒出香味
4. 加入适量的盐生抽料酒鸡精和少量清水煮沸后放入螺丝钉翻炒均匀
5. 加入煸炒好的干辣椒和花椒继续翻炒
6. 最后加入适量的葱花翻炒均匀即可
注意事项
1. 煮沸后不要煮太长时间以免螺丝钉过熟变硬
2. 可根据个人口味调整辣椒和花椒的量
""",
'reference_answer':
'麻辣螺丝钉并不是一道实际存在的菜。',
'choice':
'D',
'reason':
'根据参考答案,麻辣螺丝钉并不是一道实际存在的菜。而两个回答均给出了这样一道不存在的菜的做法,而并未告知用户这道菜不存在,违背了 Helpful 的性质。因此两个回答都不好,选 D'
}
cn_example_B = {
'question':
'可以概括一下《奥德赛》这本书的内容吗?',
'answer1':
"""
The Odyssey is an epic poem by Homer which tells the story of the Greek hero Odysseus and his decade-long journey home after the fall of Troy. Odysseus must use his wit and courage to battle gods and monsters in order to make his way back home to his wife and son. Along the way, he meets a variety of characters, learns valuable lessons, and discovers his own heroism.
""",
'answer2':
"""
奥德赛是古希腊诗人荷马所作的史诗之一主要讲述了希腊英雄奥德修斯Odysseus在特洛伊战争结束后回到家乡伊萨卡岛的冒险旅程故事分为24卷其中前12卷描述了他与同伴们离开特洛伊后经历的磨难和险阻包括遭遇塞壬女妖被巨人波吕斐摩斯关押等而后12卷则着重描写了他在回归家园的过程中面临的各种挑战例如对抗女神雅典娜变成的牧猪人解决求婚人之乱等最终奥德修斯克服重重困难终于回到了自己的故乡并夺回了王位和家庭除此之外奥德赛还涉及到了许多古代希腊社会的风俗习惯宗教信仰和道德观念具有很高的艺术价值和历史价值
""",
'evaluating_guidance':
'开放性回答,符合题目要求即可。',
'choice':
'B',
'reason':
'本题的两个回答都正确概括了奥德赛的内容,但回答 1 使用了英文回答,而回答 2 使用了中文回答。根据 回答语言应与提问语言一致 的原则,回答 1 不好,而回答 2 好,选 B'
}
def build_prompt_cn(prompt, ics):
for i, eg in enumerate(ics):
prompt += f'{i + 1}: \n'
prompt += f"问题: <问题开始> {eg['question']} <问题结束>\n\n"
prompt += f"回答 1: <回答 1 开始> {eg['answer1']} <回答 1 结束>\n\n"
prompt += f"回答 2: <回答 2 开始> {eg['answer2']} <回答 2 结束>\n\n"
if 'reference_answer' in eg:
prompt += f"参考答案: <参考答案开始> {eg['reference_answer']} <参考答案结束>\n\n"
if 'evaluating_guidance' in eg:
prompt += f"题目评分指引: <题目评分指引开始> {eg['evaluating_guidance']} <题目评分指引结束>\n\n"
if 'choice' in eg:
prompt += f"选择:{eg['choice']}\n"
if 'reason' in eg:
prompt += f"原因:{eg['reason']}\n"
if len(ics):
prompt += f'{len(ics) + 1}: \n'
return prompt
def build_prompt(nopt=4):
examples = [cn_example_A, cn_example_B, cn_example_C, cn_example_D]
prompt = prompt_map[f'cn{nopt}']
return build_prompt_cn(prompt, examples[:nopt])
meta_prompt = build_prompt()
base_model_and_result = [{'model':'internlm7b', 'path':'model1.json'}]
compare_model_and_result = [{'model':'internlm20b', 'path':'model2.json'}]
corev2_reader_cfg = dict(
input_columns=['question', 'reference_answer', 'evaluating_guidance', 'capability', 'answer1', 'answer2'],
output_column='judge'
)
corev2_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = meta_prompt+"问题: <问题开始> {question} <问题结束>\n\n回答 1: <回答 1 开始> {answer1} <回答 1 结束>\n\n回答 2: <回答 2 开始> {answer2} <回答 2 结束>\n\n参考答案: <参考答案开始> {reference_answer} <参考答案结束>\n\n题目评分指引: <题目评分指引开始> {evaluating_guidance} <题目评分指引结束>\n\n"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
judge_corev2_datasets = []
for base in base_model_and_result:
for compare in compare_model_and_result:
if compare['model'] != base['model']:
corev2_eval_cfg = dict(evaluator=dict(type=Corev2Evaluator, base_model=base['model'], compare_model=compare['model'], judge_method='gpt4', metric='win_rate'))
judge_corev2_datasets.append(dict(type=SubJudgeDataset,
path=base['path'],
path2=compare['path'],
model1=base['model'],
model2=compare['model'],
reader_cfg=corev2_reader_cfg,
infer_cfg=corev2_infer_cfg,
eval_cfg=corev2_eval_cfg)
)

View File

@ -24,6 +24,7 @@ from .cmrc import * # noqa: F401, F403
from .commonsenseqa import * # noqa: F401, F403
from .commonsenseqa_cn import * # noqa: F401, F403
from .copa import * # noqa: F401, F403
from .corev2 import * # noqa: F401, F403
from .crowspairs import * # noqa: F401, F403
from .crowspairs_cn import * # noqa: F401, F403
from .csl import * # noqa: F401, F403
@ -74,6 +75,7 @@ from .siqa import * # noqa: F401, F403
from .squad20 import SQuAD20Dataset, SQuAD20Evaluator # noqa: F401, F403
from .storycloze import * # noqa: F401, F403
from .strategyqa import * # noqa: F401, F403
from .subject import * # noqa: F401, F403
from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403

View File

@ -0,0 +1,70 @@
# flake8: noqa: E501
import re
from collections import defaultdict
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS
def match_general_answer(s):
temp = s[0]
if temp in ['A', 'B', 'C', 'D']:
return temp
else:
return None
def match_GPT4_answer(s):
if result := re.findall('(?:选择:|Choice: )([ABCD])', s):
return result[0]
else:
return None
@ICL_EVALUATORS.register_module()
class Corev2Evaluator(BaseEvaluator):
def __init__(self,
base_model,
compare_model,
judge_method='gpt4',
metric='win_rate'):
self.base_model = base_model
self.compare_model = compare_model
self.metric = metric
self.judge_method = judge_method
def score(self, predictions, references):
if self.judge_method == 'gpt4':
predictions = [match_GPT4_answer(s) for s in predictions]
else:
predictions = [match_general_answer(s) for s in predictions]
print(
f'Among {len(predictions)} judgements, successfully extracted {len(predictions)-predictions.count(None)} judgements.'
)
win_both, half_draw, categories = defaultdict(float), defaultdict(
float), defaultdict(float)
for prediction, reference in zip(predictions, references):
if prediction is not None:
categories[reference['capability'].split('-')[0]] += 1
winner = ''
if prediction == 'A':
winner = reference['model1']
elif prediction == 'B':
winner = reference['model2']
elif prediction == 'C':
win_both[reference['capability'].split('-')[0]] += 1
if self.base_model == winner:
half_draw[reference['capability'].split('-')[0]] += 1
win_both[reference['capability'].split('-')[0]] += 1
for capability in categories:
if capability not in half_draw:
win_both[capability] = 0.0
half_draw[capability] = 0.0
else:
win_both[capability] = round(
(win_both[capability] / categories[capability]) * 100, 2)
half_draw[capability] = round(
(half_draw[capability] / categories[capability]) * 100, 2)
scores = {'win_both': win_both, 'half_draw': half_draw}
return scores

View File

@ -0,0 +1,116 @@
# flake8: noqa: E501
import json
import random
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class SubInferDataset(BaseDataset):
@staticmethod
def load(path: str):
dataset = DatasetDict()
raw_data = []
with open(path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
for problem in json_data:
question = problem['question']
reference_answer = problem['reference_answer']
evaluating_guidance = problem['evaluating_guidance']
capability = problem['capability']
raw_data.append({
'question': question,
'judge': {
'question': question,
'reference_answer': reference_answer,
'evaluating_guidance': evaluating_guidance,
'capability': capability
}
})
dataset = Dataset.from_list(raw_data)
return dataset
@LOAD_DATASET.register_module()
class SubJudgeDataset(BaseDataset):
@staticmethod
def load(
path: str,
model1: str,
path2: str,
model2: str,
mode='compare',
random_order=True,
random_seed=0,
):
dataset = DatasetDict()
raw_data = []
if mode == 'compare':
with open(path, 'r', encoding='utf-8') as f:
json_data1 = json.load(f)
with open(path2, 'r', encoding='utf-8') as f:
json_data2 = json.load(f)
random_generate = random.Random(random_seed)
same_flag = 0
for idx in json_data1:
problem = json_data1[idx]
answer1 = json_data1[idx]['prediction']
answer2 = json_data2[idx]['prediction']
if answer1 == answer2:
same_flag += 1
continue
item = {}
item['question'] = problem['gold']['question']
item['reference_answer'] = problem['gold']['reference_answer']
item['evaluating_guidance'] = problem['gold'][
'evaluating_guidance']
item['capability'] = problem['gold']['capability']
if random_order:
if random_generate.randint(0, 1) == 0:
item['answer1'] = answer1
item['model1'] = model1
item['answer2'] = answer2
item['model2'] = model2
else:
item['answer1'] = answer2
item['model1'] = model2
item['answer2'] = answer1
item['model2'] = model1
else:
item['answer1'] = answer1
item['model1'] = model1
item['answer2'] = answer2
item['model2'] = model2
raw_data.append({
'question':
item['question'],
'reference_answer':
item['reference_answer'],
'evaluating_guidance':
item['evaluating_guidance'],
'capability':
item['capability'],
'answer1':
item['answer1'],
'answer2':
item['answer2'],
'judge': {
'capability': item['capability'],
'model1': item['model1'],
'model2': item['model2']
}
})
if same_flag != 0:
print(
f'Among {len(json_data1)} comparisons, {same_flag} cases are exact match, which will be skipped. '
)
elif mode == 'score':
pass
dataset = Dataset.from_list(raw_data)
return dataset

View File

@ -1,5 +1,9 @@
from .circular import CircularSummarizer
from .default import DefaultSummarizer
from .subject import SubjectSummarizer
from .subjective import SubjectiveSummarizer
__all__ = ['DefaultSummarizer', 'SubjectiveSummarizer', 'CircularSummarizer']
__all__ = [
'CircularSummarizer', 'DefaultSummarizer', 'SubjectiveSummarizer',
'SubjectSummarizer'
]

View File

@ -0,0 +1,80 @@
import csv
import os
import os.path as osp
from datetime import datetime
import mmengine
from mmengine import ConfigDict
try:
from prettytable import from_csv
except ImportError:
from_csv = None
from opencompass.utils import dataset_abbr_from_cfg
class SubjectSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(
self,
config: ConfigDict,
) -> None:
self.tasks = []
self.cfg = config
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
dataset_cfgs = self.cfg['datasets']
work_dir = self.cfg['work_dir']
self.work_dir = work_dir
self.time_str = time_str
output_path = osp.join(self.work_dir, 'summary',
f'summary_{self.time_str}.txt')
output_dir = osp.join(osp.split(output_path)[0], f'{self.time_str}')
mmengine.mkdir_or_exist(output_dir)
results_folder = osp.join(work_dir, 'results')
for subdir in os.listdir(results_folder):
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
for dataset in dataset_cfgs:
model1, model2 = dataset['eval_cfg']['evaluator'][
'base_model'], dataset['eval_cfg']['evaluator'][
'compare_model']
dataset_abbr = dataset_abbr_from_cfg(dataset)
filepath = os.path.join(subdir_path,
dataset_abbr + '.json')
result = mmengine.load(filepath)
rows = list(result.keys())
columns = list(result[rows[0]].keys())
fout = osp.join(output_dir,
model1 + '_vs_' + model2 + '.csv')
print(
'###############################Subjective Results on '
+ model1 + '_vs_' + model2 +
'###############################')
with open(fout, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([model1 + '_vs_' + model2] + columns)
for row in rows:
writer.writerow(
[row] +
[result[row][column] for column in columns])
with open(fout, 'r') as f:
x = from_csv(f)
print(x)

View File

@ -16,6 +16,7 @@ numpy==1.23.4
openai
OpenCC
pandas<2.0.0
prettytable
pypinyin
python-Levenshtein
rank_bm25==0.2.2