This commit is contained in:
wujiang 2025-02-07 18:50:03 +08:00 committed by jxd
parent 9ae714a577
commit 61ceb02c23

View File

@ -53,7 +53,7 @@ class HuStandardFIBEvaluator(BaseEvaluator):
blank_correct, blank_total = 0, 0 blank_correct, blank_total = 0, 0
question_correct, question_total = 0, 0 question_correct, question_total = 0, 0
for i, (pred, refer, prompt) in enumerate( for idx, (pred, refer, prompt) in enumerate(
zip(predictions, references, origin_prompt)): zip(predictions, references, origin_prompt)):
std_ans = [ std_ans = [
re.sub(r'#\d+#', '', ans).split(';') re.sub(r'#\d+#', '', ans).split(';')
@ -67,7 +67,7 @@ class HuStandardFIBEvaluator(BaseEvaluator):
else: else:
blank_total += len(std_ans) blank_total += len(std_ans)
question_total += 1 question_total += 1
details[i] = { details[idx] = {
'reference': refer, 'reference': refer,
'model_ans': model_ans, 'model_ans': model_ans,
'gt': std_ans, 'gt': std_ans,
@ -87,7 +87,7 @@ class HuStandardFIBEvaluator(BaseEvaluator):
data = json.loads(formatted_json_str) data = json.loads(formatted_json_str)
to_end_flag = True to_end_flag = True
except json.JSONDecodeError: except json.JSONDecodeError:
print(f'Invalid JSON format. {i}') print(f'Invalid JSON format. {idx}')
blank_total += len(std_ans) blank_total += len(std_ans)
question_total += 1 question_total += 1
@ -106,13 +106,13 @@ class HuStandardFIBEvaluator(BaseEvaluator):
re.sub(r'#\d+#', '', ans).split(';') re.sub(r'#\d+#', '', ans).split(';')
for ans in data.get('answers', []) for ans in data.get('answers', [])
] # Preprocess model_ans in the same way as std_ans ] # Preprocess model_ans in the same way as std_ans
for idx, ans_list in enumerate(std_ans): for ans_idx, ans_list in enumerate(std_ans):
if idx >= len(model_ans): if ans_idx >= len(model_ans):
is_question_correct = False is_question_correct = False
blank_wise_correctness.append(False) blank_wise_correctness.append(False)
continue continue
model_list = model_ans[idx] model_list = model_ans[ans_idx]
is_blank_correct = True is_blank_correct = True
for ans in ans_list: for ans in ans_list:
best_match = max( best_match = max(
@ -129,7 +129,7 @@ class HuStandardFIBEvaluator(BaseEvaluator):
question_total += 1 question_total += 1
question_correct += 1 if is_question_correct else 0 question_correct += 1 if is_question_correct else 0
details[i] = { details[idx] = {
'reference': refer, 'reference': refer,
'std_ans': std_ans, 'std_ans': std_ans,
'model_ans': model_ans, 'model_ans': model_ans,