add creationbench (#753)

This commit is contained in:
bittersweet1999 2023-12-29 18:03:44 +08:00 committed by GitHub
parent 8728287a55
commit fe0b717033
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 408 additions and 167 deletions

View File

@ -0,0 +1,60 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CreationBenchDataset
subjective_reader_cfg = dict(
input_columns=['question', 'capability', 'gpt4_prefix', 'gpt4_suffix'],
output_column='judge',
)
subjective_all_sets = [
"creationbench",
]
data_path ="data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=2048),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = "{gpt4_prefix}{prediction}{gpt4_suffix}"
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=CreationBenchDataset,
multi_dimension=True,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

View File

@ -80,6 +80,7 @@ from .storycloze import * # noqa: F401, F403
from .strategyqa import * # noqa: F401, F403
from .subject_alignmentbench import AlignmentBenchDataset # noqa: F401, F403
from .subject_corev2 import Corev2Dataset # noqa: F401, F403
from .subject_creationbench import CreationBenchDataset # noqa: F401, F403
from .subject_creationv01 import Creationv01Dataset # noqa: F401, F403
from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403
from .summedits import * # noqa: F401, F403

View File

@ -95,7 +95,7 @@ class AlignmentBenchDataset(SubjectiveCmpDataset):
else:
alignmentbench_config = None
dataset = list(super().load(path, name))
corev2_dataset = []
alignbench_dataset = []
for data in dataset:
if alignmentbench_config:
dimensions, prefix = prompt_construct(data,
@ -103,8 +103,8 @@ class AlignmentBenchDataset(SubjectiveCmpDataset):
data['critiquellm_prefix'] = prefix
data['judge']['others'] = data['others']
data['ref'] = data['others']['reference']
corev2_dataset.append(data)
dataset = Dataset.from_list(corev2_dataset)
alignbench_dataset.append(data)
dataset = Dataset.from_list(alignbench_dataset)
return dataset

View File

@ -0,0 +1,138 @@
# flake8: noqa: E501
import json
import os.path as osp
import re
from typing import Optional
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from .subjective_cmp import SubjectiveCmpDataset
eng_base_prefix = """
You are an assistant skilled at evaluating the quality of creative text.
Please evaluate the quality of an AI model's response to a creative question in the capacity of an impartial judge. You'll need to assess the response on the following dimensions: Creativity, Richness, User Demand Fulfillment, and Logical Coherence. We will provide you with a creative question and the AI model's response for evaluation. As you begin your assessment, follow this process:
1. Evaluate the AI model's answers on different dimensions, pointing out its strengths or weaknesses in each dimension and assigning a score of 1 to 10 for each.
2. Finally, based on the assessments across dimensions, provide an overall score of 1 to 10 for the AI model's response.
3. Your scoring should be as stringent as possible and follow the scoring rules below: Generally, the higher the quality of the model's response, the higher the score.
Creativity Scoring Guidelines:
When the model's response fails to provide any innovative or unique content, the creativity score must be between 1 and 2;
When the model's response partially offers original creative content but of low quality, the creativity score is between 3 and 4;
When the model's response consists mostly of creative content but lacks significant novelty in the creation, with average quality, the creativity score can range from 5 to 6;
When the model's response presents novelty and high-quality creative content, the creativity score ranges from 7 to 8;
When the model's response contains highly innovative and high-quality creative content, the creativity score can only reach 9 to 10.
Richness Scoring Guidelines:
When the model's response lacks richness, lacks depth and breadth, offers extremely limited information, and displays very low diversity in information, the richness score must be between 1 and 2;
When the model's response is somewhat lacking in richness, lacks necessary depth, explanations, and examples, might be less relevant or detailed, and has limited contextual considerations, the richness score is between 3 and 4;
When the model's response is somewhat rich but with limited depth and breadth, moderately diverse information, providing users with the necessary information, the richness score can range from 5 to 6;
When the model's response is rich, and provides some depth, comprehensive contextual considerations, and displays some diversity in information, the richness score ranges from 7 to 8;
When the model's response is extremely rich, offers additional depth and breadth, includes multiple relevant detailed explanations and examples to enhance understanding, comprehensive contextual considerations, and presents highly diverse information, the richness score can only reach 9 to 10.
User Demand Fulfillment Scoring Guidelines:
When the model's response is entirely unrelated to user demands, fails to meet basic user requirements, especially in style, theme, and significant word count differences, the user demand fulfillment score must be between 1 and 2;
When the model's response has limited understanding of user demands, only provides somewhat relevant information, lacks strong connections to user demands, unable to significantly aid in problem-solving, significant style, theme, and word count differences, the user demand fulfillment score is between 3 and 4;
When the model's response partially understands user demands, provides some relevant solutions or responses, the style, theme are generally in line with requirements, and the word count differences are not significant, the user demand fulfillment score can range from 5 to 6;
When the model's response understands user demands fairly well, offers fairly relevant solutions or responses, style, theme, and word count align with problem requirements, the user demand fulfillment score ranges from 7 to 8;
When the model's response accurately understands all user demands, provides highly relevant and personalized solutions or responses, style, theme, and word count entirely align with user requirements, the user demand fulfillment score can only reach 9 to 10.
Logical Coherence Scoring Guidelines:
When the model's response lacks any coherence, lacks any logical sequence, entirely mismatched with the question or known information, the logical coherence score must be between 1 and 2;
When the model's response is somewhat coherent but still has numerous logical errors or inconsistencies, the logical coherence score is between 3 and 4;
When the model's response is mostly coherent, with few logical errors, might lose coherence in certain complex situations, the logical coherence score can range from 5 to 6;
When the model's response excels in logical coherence, handles complex logic well, very few errors, can handle intricate logical tasks, the logical coherence score ranges from 7 to 8;
When the model's response achieves perfect logical coherence, flawless in handling complex or challenging questions, without any logical errors, the logical coherence score can only reach 9 to 10.
Overall Scoring Guidelines:
When the model's response is entirely irrelevant to the question, contains substantial factual errors, or generates harmful content, the overall score must be between 1 and 2;
When the model's response lacks severe errors and is generally harmless but of low quality, fails to meet user demands, the overall score ranges from 3 to 4;
When the model's response mostly meets user requirements but performs poorly in some dimensions, with average quality, the overall score can range from 5 to 6;
When the model's response performs well across all dimensions, the overall score ranges from 7 to 8;
Only when the model's response fully addresses user problems and all demands, achieving near-perfect scores across all dimensions, can the overall score reach 9 to 10.
Please remember, you must evaluate and explain before scoring. After your explanation for each dimension, add the score for that dimension. Finally, at the end of your response, in the format of the dictionary (including brackets), return all your scoring results, ensuring your scores are integers:
{'Dimension One': Score, 'Dimension Two': Score, ..., 'Overall Score': Score}, for example: {'Creativity': 9, 'Richness': 6, ..., 'Overall Score': 7}.\n
"""
chn_base_prefix = """
你是一个擅长评价创作类文本质量的助手
请你以公正的评判者的身份评估一个AI模型对创作类问题的回答的质量你需要从下面的几个维度对回答进行评估创造性丰富度满足用户需求逻辑连贯性
我们会给您提供一个创作类问题和需要你评估的AI模型的回答当你开始你的评估时你需要遵守以下的流程
1. 从不同维度对AI模型的答案进行评价指出AI模型的答案有哪些优点或不足在每个维度的评价之后给每一个维度一个110的分数
2. 最后综合每个维度的评估对AI模型的回答给出一个110的综合得分
3. 你的打分需要尽可能严格并且要遵守下面的评分规则总的来说模型回答的质量越高则分数越高
创造性评分规则
当模型的回答没有能够提供任何创新性或独特性内容时创造性得分必须是1到2分
当模型的回答能够提供部分原创性的创作内容但创作质量较低时创造性得分为3到4分
当模型的回答基本均为创造性内容但在创作上无太多新意质量中等创造性得分可以得5到6分
当模型的回答具有新意且创作内容质量较高时创造性得分得到7到8分
当模型的回答的创作内容非常新颖且质量极高时创造性得分才能得到9到10分
丰富度评分规则
当模型的回答很不丰富缺乏深度和广度提供的信息非常有限信息呈现出很低的多样性时丰富度得分必须是1到2分
当模型的回答较不丰富缺乏必要的深度解释和实例较少且可能不够相关或不够详细上下文考虑有限信息展现出较低的多样性时丰富度得分为3到4分
当模型的回答较为丰富但深度和广度有限信息多样性一般用户能够从回答中获得基本所需的信息时丰富度得分可以得5到6分
当模型的回答丰富并提供了一定的深度上下文考虑较为全面信息展现出一定的多样性用户能够从回答中获得所需以及一些额外的有用信息时丰富度得分得到7到8分
当模型的回答非常丰富提供了额外的深度和广度包含多个相关的详细解释和实例以增强理解上下文考虑全面信息呈现出高度的多样性时丰富度得分才能得到9到10分
满足用户需求评分规则
当模型的回答与用户需求完全不相关无法满足用户的基本需求特别是文体主题完全不符字数要求相差很大时满足用户需求得分必须是1到2分
当模型的回答对用户需求的理解有限只能在很小程度上提供相关信息与用户需求关联性较低不太能够帮助用户解决问题文体主题字数与题目要求相差较大时满足用户需求得分为3到4分
当模型的回答能够部分理解用户需求并提供部分相关的解决方案或回应文体主题基本符合需求字数与要求相差不大时满足用户需求得分可以得5到6分
当模型的回答能够较好地理解用户需求并提供较为相关的解决方案或回应文体主题字数符合问题要求时满足用户需求得分得到7到8分
当模型的回答能够精准地理解用户的所有需求并提供高度相关和个性化的解决方案或回应文体主题字数完全符合用户需求时满足用户需求得分才能得到9到10分
逻辑连贯性评分规则
当模型的回答完全不连贯没有任何逻辑性与问题或已知信息完全不匹配时逻辑连贯性得分必须是1到2分
当模型的回答在一定程度上是逻辑连贯的但仍有不少逻辑错误或不一致之处时逻辑连贯性得分为3到4分
当模型的回答在大多数情况下是逻辑连贯的逻辑错误较少但在某些复杂情况下可能无法保持完全的连贯性时逻辑连贯性得分可以得5到6分
当模型的回答在逻辑连贯性方面表现出色能够很好地处理复杂逻辑错误非常少见且能够处理复杂的逻辑任务时逻辑连贯性得分得到7到8分
当模型的回答在逻辑连贯性方面达到完美无论问题多么复杂或具有挑战性都能够展现出无懈可击的逻辑能力没有任何逻辑错误时逻辑连贯性得分才能得到9到10分
综合得分评分规则
当模型回答存在与问题不相关或者有本质性的事实错误或生成了有害内容时总分必须是1到2分
当模型回答没有严重错误而且基本无害但是质量较低没有满足用户需求总分为3到4分
当模型回答基本满足用户要求但是在部分维度上表现较差质量中等总分可以得5到6分
当模型回答在所有维度上表现良好总分得7到8分
只有当模型回答充分地解决了用户问题和所有需求并且在所有维度上都接近满分的情况下才能得9到10分
请记住你必须在你打分前进行评价和解释在你对每个维度的解释之后需要加上对该维度的打分之后在你回答的末尾按照以下字典格式包括括号返回你所有的打分结果并确保你的打分结果是整数
{'维度一': 打分, '维度二': 打分, ..., '综合得分': 打分}例如{'创造性': 9, '丰富度': 6, ..., '综合得分': 7}\n
"""
def prompt_construct(sample):
lan = sample['others']['language']
question = sample['question']
if lan == 'zh':
prompt = chn_base_prefix + '创作类问题:' + str(question) + '\n[模型回答开始]\n'
suffix = '\n[模型回答结束]\n'
elif lan == 'en':
prompt = eng_base_prefix + 'Creative Question: ' + str(
question) + "\n[Model's response start]\n"
suffix = "\n[Model's response end]\n"
return prompt, suffix
@LOAD_DATASET.register_module()
class CreationBenchDataset(SubjectiveCmpDataset):
def load(self,
path: str,
name: str,
multi_dimension: Optional[bool] = False):
dataset = list(super().load(path, name))
creation_dataset = []
for data in dataset:
if multi_dimension:
prefix, suffix = prompt_construct(data)
data['gpt4_prefix'] = prefix
data['gpt4_suffix'] = suffix
data['judge']['others'] = data['others']
# data['ref'] = data['others']['reference']
creation_dataset.append(data)
dataset = Dataset.from_list(creation_dataset)
return dataset

View File

@ -1,8 +1,8 @@
# flake8: noqa: F401, E501
from .alignmentbench import (AlignmentBenchSummarizer, AutojSummarizer,
JudgeLMSummarizer)
from .alignmentbench import AlignmentBenchSummarizer
from .circular import CircularSummarizer # noqa: F401
from .corev2 import Corev2Summarizer # noqa: F401
from .creationbench import CreationBenchSummarizer
from .creationv01 import Creationv01Summarizer # noqa: F401
from .default import DefaultSummarizer # noqa: F401
from .subjective import SubjectiveSummarizer # noqa: F401

View File

@ -16,6 +16,7 @@ except ImportError:
from opencompass.utils import model_abbr_from_cfg
from .subjective_post_process import post_process_autoj, post_process_judgelm
from .utils import get_judgeanswer_and_reference, get_outdir
CATEGORIES = {
@ -23,62 +24,71 @@ CATEGORIES = {
'中文语言': ['基本任务', '中文理解', '综合问答', '文本写作', '角色扮演', '专业能力'],
}
all_dimensions = [
All_Dimensions = [
'事实正确性', '满足用户需求', '安全无害', '清晰度', '逻辑性', '完备性', '创造性', '可负责程度', '逻辑连贯性',
'公平与可负责程度', '丰富度', '综合得分'
]
def post_process(judgement: str):
def extract_rating(text):
pattern = r'{(.*?)}(?![^{]*{)' # match last brackets
match = re.search(pattern, text)
if match:
dictionary_str = match.group(1)
kv_pattern = r"'(.*?)': (\d+)"
matches = re.findall(kv_pattern, dictionary_str)
result_dict = {key: int(value) for key, value in matches}
return result_dict
else:
return None
def check_rating(rating, all_dimensions):
for k, v in rating.items():
if isinstance(v, (int, float)) and k in all_dimensions: # 确保值是数字
if v >= 0 and v <= 10:
pass
else:
return None
else:
return None
return rating
def post_process_alignbench(judgement: str,
all_dimensions=All_Dimensions,
possible_keys=['综合得分']):
"""Input a string like below:
xxx{'事实正确性': 1, '满足用户需求': 1, '清晰度': 2, '完备性': 1, '综合得分': 1}xxx,
and extract each score
"""
def extract_rating(text):
pattern = r'{(.*?)}(?![^{]*{)' # match last brackets
match = re.search(pattern, text)
if match:
dictionary_str = match.group(1)
kv_pattern = r"'(.*?)': (\d+)"
matches = re.findall(kv_pattern, dictionary_str)
result_dict = {key: int(value) for key, value in matches}
return result_dict
else:
return None
def extract_score(text):
pattern = r'\'综合得分\': (\d+(\.\d{1,2})?)'
keys_pattern = '|'.join(map(re.escape, possible_keys))
pattern = rf"({'|'.join(possible_keys)}): (\d+(\.\d{{1,2}})?)"
match = re.search(pattern, text)
if match:
return float(match.group(1))
return -1
def check_rating(rating):
for k, v in rating.items():
if isinstance(v, (int, float)) and k in all_dimensions: # 确保值是数字
if v >= 0 and v <= 10:
pass
else:
return None
else:
return None
return rating
judgement = judgement.replace('\n', '')
rating = extract_rating(judgement)
if rating is not None:
score = rating.get('综合得分', -1)
score = -1
for key in possible_keys:
score = rating.get(key, -1)
if score != -1:
break
if score == -1:
score = extract_score(judgement)
if score >= 0 and score <= 10:
pass
else:
score = -1
rating = check_rating(rating)
rating = check_rating(rating, all_dimensions)
else:
score = -1
if rating == None or score == -1:
@ -87,55 +97,21 @@ def post_process(judgement: str):
return {'rating': rating, 'score': score}
def post_process_autoj(judgement: str):
"""Input a string like below:
xxx[[5]]xxx, and extract the score
"""
pattern = r'\[(\d+)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
score = int(matched_result[0])
else:
return None
return {'score': score}
def post_process_judgelm(judgement: str):
"""Input a string like below:
5, reason:xxx and extract the score
"""
if len(judgement) >= 2:
first_two_chars = judgement[:2]
if first_two_chars.isdigit() and first_two_chars == '10':
score = 10
else:
first_char = judgement[0]
if first_char.isdigit() and 0 <= int(first_char) <= 9:
score = int(first_char)
else:
return None
elif len(judgement) == 1:
if judgement.isdigit() and 0 <= int(judgement) <= 9:
score = int(judgement)
else:
return None
else:
return None
return {'score': score}
def get_dimension_results(judged_answers, references, fout, fout_flag, model):
dimension_ratings = defaultdict(int)
dimension_counts = defaultdict(int)
for ans, ref in zip(judged_answers, references):
for k, v in ans['rating'].items():
if k != '综合得分':
if k != '综合得分' or k != 'Overall Score':
dimension_ratings[k] += v
dimension_counts[k] += 1
dimension_ratings['综合得分'] += ans['score']
dimension_counts['综合得分'] += 1
else:
if k == '综合得分':
dimension_ratings['综合得分'] += ans['score']
dimension_counts['综合得分'] += 1
else:
dimension_ratings['Overall Score'] += ans['score']
dimension_counts['Overall Score'] += 1
dimension_avg_ratings = defaultdict(float)
for dimension, total_score in dimension_ratings.items():
@ -155,7 +131,12 @@ def get_dimension_results(judged_answers, references, fout, fout_flag, model):
[scores[row][column] for column in columns])
def get_capability_results(judged_answers, references, fout, fout_flag, model):
def get_capability_results(judged_answers,
references,
fout,
fout_flag,
model,
categories=CATEGORIES):
capability_ratings = defaultdict(int)
capability_counts = defaultdict(int)
for ans, ref in zip(judged_answers, references):
@ -168,14 +149,19 @@ def get_capability_results(judged_answers, references, fout, fout_flag, model):
capability_avg_ratings[
capability] = total_score / capability_counts[capability]
capability_avg_ratings['中文推理总分'] = np.mean(
[np.mean(capability_avg_ratings[cat]) for cat in CATEGORIES['中文推理']])
capability_avg_ratings['中文语言总分'] = np.mean(
[np.mean(capability_avg_ratings[cat]) for cat in CATEGORIES['中文语言']])
capability_avg_ratings['总分'] = (capability_avg_ratings['中文推理总分'] +
capability_avg_ratings['中文语言总分']) / 2
temp_list = []
for category, sub_categories in categories.items():
capability_avg_ratings[category + '总分'] = np.mean([
np.mean(capability_avg_ratings[cat])
for cat in categories[category]
])
temp_list.append(category + '总分')
capability_avg_ratings['总分'] = 0
for temp in temp_list:
capability_avg_ratings['总分'] += capability_avg_ratings[temp]
capability_avg_ratings['总分'] /= len(temp_list)
scores = {model: capability_avg_ratings}
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
@ -183,13 +169,13 @@ def get_capability_results(judged_answers, references, fout, fout_flag, model):
writer.writerow(num_header)
header = ['模型', '总分']
for category, sub_categories in CATEGORIES.items():
for category, sub_categories in categories.items():
header.append(category)
header.extend([None for _ in range(len(sub_categories))])
writer.writerow(header)
sub_header = ['模型', '总分']
for category, sub_categories in CATEGORIES.items():
for category, sub_categories in categories.items():
sub_header.extend([category + '总分'])
sub_header.extend(sub_categories)
writer.writerow(sub_header)
@ -197,7 +183,7 @@ def get_capability_results(judged_answers, references, fout, fout_flag, model):
row = [model]
row.append(scores[model]['总分'])
for category, sub_categories in CATEGORIES.items():
for category, sub_categories in categories.items():
row.append(scores[model][category + '总分'])
for sub_category in sub_categories:
row.append(scores[model][sub_category])
@ -212,7 +198,7 @@ class AlignmentBenchSummarizer:
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict) -> None:
def __init__(self, config: ConfigDict, judge_type: str) -> None:
self.tasks = []
self.cfg = config
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
@ -220,6 +206,15 @@ class AlignmentBenchSummarizer:
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_type = judge_type
assert self.judge_type in ['general', 'autoj', 'judgelm']
self.judge_map = {
'general': post_process_alignbench,
'autoj': post_process_autoj,
'judgelm': post_process_judgelm
}
self.judge_function = self.judge_map[self.judge_type]
self.category = CATEGORIES
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
@ -239,93 +234,27 @@ class AlignmentBenchSummarizer:
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model, judge_model = eval_model_abbr, self.judge_abbr
fout = osp.join(output_dir,
'judged-by--' + judge_model + '-dimension.csv')
if self.judge_type == 'general':
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-dimension.csv')
fout2 = osp.join(
output_dir,
'judged-by--' + judge_model + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, post_process)
get_dimension_results(judged_answers, references, fout,
fout_flag, model)
dataset, subdir_path, self.judge_function)
if self.judge_type == 'general':
get_dimension_results(judged_answers, references, fout,
fout_flag, model)
get_capability_results(judged_answers, references, fout2,
fout_flag2, model)
fout_flag2, model, self.category)
else:
print(subdir_path + ' is not exist! please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
if self.judge_type == 'general':
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
with open(fout2, 'r') as f:
x = from_csv(f)
print(x)
class AutojSummarizer(AlignmentBenchSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict) -> None:
super().__init__(config)
def summarize(self,
post_process=post_process_autoj,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag = 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model, judge_model = eval_model_abbr, self.judge_abbr
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, post_process)
get_capability_results(judged_answers, references, fout,
fout_flag, model)
else:
print(subdir_path + ' is not exist! please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
class JudgeLMSummarizer(AutojSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict) -> None:
super().__init__(config)
def summarize(self,
post_process=post_process_judgelm,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
super().summarize(post_process, time_str)

View File

@ -0,0 +1,73 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import numpy as np
from mmengine import ConfigDict
try:
from prettytable import from_csv
except ImportError:
from_csv = None
from opencompass.utils import model_abbr_from_cfg
from .alignmentbench import AlignmentBenchSummarizer, post_process_alignbench
from .subjective_post_process import post_process_autoj, post_process_judgelm
from .utils import get_judgeanswer_and_reference, get_outdir
CATEGORIES = {
'中文': ['内容扩写_ZH', '内容续写_ZH', '内容改写_ZH'],
'英文': ['内容扩写_EN', '内容续写_EN', '内容改写_EN'],
}
All_Dimensions = [
'Creativity', 'Richness', 'User Demand Fulfillment', 'Logical Coherence',
'Overall Score', '创造性', '丰富度', '满足用户需求', '逻辑连贯性', '综合得分'
]
def post_process_creationbench(judgement: str,
all_dimensions=All_Dimensions,
possible_keys=['综合得分', 'Overall Score']):
"""Input a string like below:
xxx{'事实正确性': 1, '满足用户需求': 1, '清晰度': 2, '完备性': 1, '综合得分': 1}xxx,
and extract each score
"""
return post_process_alignbench(judgement, all_dimensions, possible_keys)
class CreationBenchSummarizer(AlignmentBenchSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type: str) -> None:
super().__init__(config, judge_type)
self.judge_map = {
'general': post_process_creationbench,
'autoj': post_process_autoj,
'judgelm': post_process_judgelm
}
self.judge_function = self.judge_map[self.judge_type]
self.category = CATEGORIES
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
super().summarize(time_str)

View File

@ -0,0 +1,40 @@
import re
def post_process_autoj(judgement: str):
"""Input a string like below:
xxx[[5]]xxx, and extract the score
"""
pattern = r'\[(\d+)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
score = int(matched_result[0])
else:
return None
return {'score': score}
def post_process_judgelm(judgement: str):
"""Input a string like below:
5, reason:xxx and extract the score
"""
if len(judgement) >= 2:
first_two_chars = judgement[:2]
if first_two_chars.isdigit() and first_two_chars == '10':
score = 10
else:
first_char = judgement[0]
if first_char.isdigit() and 0 <= int(first_char) <= 9:
score = int(first_char)
else:
return None
elif len(judgement) == 1:
if judgement.isdigit() and 0 <= int(judgement) <= 9:
score = int(judgement)
else:
return None
else:
return None
return {'score': score}