2025-01-22 19:27:00 +08:00
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
|
|
|
|
from datasets import Dataset, DatasetDict
|
|
|
|
from fuzzywuzzy import fuzz
|
2025-02-02 14:46:39 +08:00
|
|
|
|
2025-01-22 19:27:00 +08:00
|
|
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
2025-02-02 14:46:39 +08:00
|
|
|
|
2025-01-22 19:27:00 +08:00
|
|
|
from ..base import BaseDataset
|
|
|
|
|
|
|
|
|
|
|
|
class HuStandardFIBDataset(BaseDataset):
|
|
|
|
|
|
|
|
@staticmethod
|
2025-02-02 14:46:39 +08:00
|
|
|
def load(filepath):
|
|
|
|
assert os.path.isfile(filepath)
|
|
|
|
assert filepath.endswith('.jsonl')
|
2025-01-22 19:27:00 +08:00
|
|
|
dataset = DatasetDict()
|
2025-02-02 14:46:39 +08:00
|
|
|
f = open(filepath, 'r', encoding='utf-8')
|
2025-01-22 19:27:00 +08:00
|
|
|
lines = f.readlines()
|
|
|
|
objs = []
|
|
|
|
|
|
|
|
for line in lines:
|
|
|
|
obj = json.loads(line)
|
|
|
|
objs.append(obj)
|
|
|
|
|
|
|
|
out_dict_list = []
|
|
|
|
|
|
|
|
for obj in objs:
|
2025-02-02 14:46:39 +08:00
|
|
|
instruction = obj['question'] # TODO: question -> instruction
|
|
|
|
questions = obj[
|
|
|
|
'question_sub'] # TODO: update question_sub -> questions
|
|
|
|
hu_specific_dim = obj['hu_specific_dim']
|
2025-01-22 19:27:00 +08:00
|
|
|
tmp = obj
|
2025-02-02 14:46:39 +08:00
|
|
|
new_obj = dict(instruction=instruction,
|
|
|
|
questions=questions,
|
|
|
|
hu_specific_dim=hu_specific_dim,
|
|
|
|
reference=tmp)
|
2025-01-22 19:27:00 +08:00
|
|
|
out_dict_list.append(new_obj)
|
|
|
|
dataset = Dataset.from_list(out_dict_list)
|
|
|
|
return dataset
|
|
|
|
|
|
|
|
|
|
|
|
class HuStandardFIBEvaluator(BaseEvaluator):
|
|
|
|
"""
|
|
|
|
ref: opencompass.openicl.icl_evaluator.AccwithDetailsEvaluator
|
|
|
|
"""
|
|
|
|
|
|
|
|
def score(self, predictions, references, origin_prompt) -> dict:
|
|
|
|
if len(predictions) != len(references):
|
|
|
|
return {'error': 'preds and refers have different length.'}
|
|
|
|
details = {}
|
|
|
|
blank_correct, blank_total = 0, 0
|
|
|
|
question_correct, question_total = 0, 0
|
|
|
|
|
|
|
|
for idx, (pred, refer, prompt) in enumerate(
|
|
|
|
zip(predictions, references, origin_prompt)):
|
|
|
|
std_ans = [
|
|
|
|
re.sub(r'#\d+#', '', ans).split(';')
|
2025-02-02 14:46:39 +08:00
|
|
|
for ans in refer['answer'] # TODO: answer -> answers
|
|
|
|
] # Remove "#0#" and "#1#", then split refer['formatted_std_ans']
|
2025-01-22 19:27:00 +08:00
|
|
|
model_ans = []
|
|
|
|
pred = pred.strip()
|
|
|
|
match = re.search(r'\{.*?\}', pred, re.DOTALL)
|
|
|
|
if match:
|
|
|
|
json_str = match.group(0)
|
|
|
|
else:
|
|
|
|
blank_total += len(std_ans)
|
|
|
|
question_total += 1
|
|
|
|
details[idx] = {
|
|
|
|
'detail': refer,
|
|
|
|
'model_ans': model_ans,
|
|
|
|
'gt': std_ans,
|
|
|
|
'prompt': prompt,
|
|
|
|
'raw_pred': pred,
|
|
|
|
}
|
|
|
|
continue
|
|
|
|
json_str = json_str.strip()
|
|
|
|
json_str = json_str.replace('\\xa0', '')
|
|
|
|
formatted_json_str = json_str
|
|
|
|
|
|
|
|
to_end_flag = False
|
|
|
|
if isinstance(formatted_json_str, str):
|
|
|
|
try:
|
|
|
|
data = json.loads(formatted_json_str)
|
|
|
|
to_end_flag = True
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
print(f'Invalid JSON format. {idx}')
|
|
|
|
blank_total += len(std_ans)
|
|
|
|
question_total += 1
|
|
|
|
|
|
|
|
elif isinstance(formatted_json_str, dict):
|
|
|
|
data = formatted_json_str
|
|
|
|
to_end_flag = True
|
|
|
|
else:
|
|
|
|
blank_total += len(std_ans)
|
|
|
|
question_total += 1
|
|
|
|
|
|
|
|
model_ans = []
|
|
|
|
if to_end_flag:
|
|
|
|
model_ans = [
|
|
|
|
re.sub(r'#\d+#', '', ans).split(';')
|
2025-02-02 14:46:39 +08:00
|
|
|
for ans in data.get('answers', [])
|
2025-01-22 19:27:00 +08:00
|
|
|
] # Preprocess model_ans in the same way as std_ans
|
|
|
|
is_question_correct = True
|
|
|
|
for idx, ans_list in enumerate(std_ans):
|
|
|
|
if idx >= len(model_ans):
|
|
|
|
is_question_correct = False
|
|
|
|
break
|
|
|
|
|
|
|
|
model_list = model_ans[idx]
|
|
|
|
for ans in ans_list:
|
|
|
|
best_match = max(
|
|
|
|
model_list,
|
|
|
|
key=lambda model: fuzz.ratio(ans, model))
|
|
|
|
if fuzz.ratio(ans, best_match) > 70: # check threshold
|
|
|
|
blank_correct += 1
|
|
|
|
else:
|
|
|
|
is_question_correct = False
|
|
|
|
|
|
|
|
blank_total += len(std_ans)
|
|
|
|
question_total += 1
|
|
|
|
question_correct += 1 if is_question_correct else 0
|
|
|
|
|
|
|
|
details[idx] = {
|
|
|
|
'detail': refer,
|
|
|
|
'model_ans': model_ans,
|
|
|
|
'gt': std_ans,
|
|
|
|
'prompt': prompt,
|
|
|
|
'raw_pred': pred,
|
|
|
|
}
|
|
|
|
results = {
|
|
|
|
'blank_level_correctness':
|
|
|
|
round(blank_correct / blank_total * 100, 2),
|
|
|
|
'question_level_correctness':
|
|
|
|
round(question_correct / question_total * 100, 2),
|
|
|
|
'details':
|
|
|
|
details
|
|
|
|
}
|
|
|
|
|
|
|
|
return results
|