fix dataset-index & use official llm_judge_postprocess

This commit is contained in:
huihui 2025-05-08 04:31:11 +00:00
parent bc9ba0126f
commit 5e8bfee3f4
3 changed files with 19 additions and 79 deletions

View File

@ -1023,3 +1023,9 @@
paper: https://arxiv.org/pdf/2402.09391
configpath: opencompass/configs/datasets/SmolInstruct/smolinstruct_gen.py
configpath_llmjudge: ''
- SciKnowEval:
name: SciKnowEval
category: Science
paper: https://arxiv.org/abs/2406.09098
configpath: opencompass/configs/datasets/SciKnowEval/SciKnowEval_gen_ebe47d.py
configpath_llmjudge: opencompass/configs/datasets/SciKnowEval/SciKnowEval_llmjudge_gen_ebe47d.py

View File

@ -1,5 +1,6 @@
from opencompass.datasets import SciKnowEvalDataset, SciKnowEvalEvaluator, SciKnowEval_llmjudge_postprocess
from opencompass.datasets import SciKnowEvalDataset, SciKnowEvalEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.evaluator import GenericLLMEvaluator
@ -90,7 +91,7 @@ eval_cfg_biology = dict(
reader_cfg=reader_cfg,
),
judge_cfg=dict(),
dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)
@ -120,7 +121,7 @@ eval_cfg_chemistry = dict(
subset='chemistry',
),
judge_cfg=dict(),
dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)
@ -150,7 +151,7 @@ eval_cfg_material = dict(
subset='material',
),
judge_cfg=dict(),
dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)
@ -180,7 +181,7 @@ eval_cfg_physics = dict(
subset='physics',
),
judge_cfg=dict(),
dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
dict_postprocessor=dict(type=generic_llmjudge_postprocess),
),
)

View File

@ -27,13 +27,18 @@ class SciKnowEvalDataset(BaseDataset):
@staticmethod
def load(path: str, prompt_mode: str, **kwargs):
def capitalize_first_letter(s):
if not s: # 检查字符串是否为空
return s
return s[0].upper() + s[1:]
subset = kwargs['subset']
data_files = {'test': f'data/{capitalize_first_letter(subset)}/sciknoweval_{subset}_test.jsonl'}
dataset = load_dataset(path, data_files=data_files, split='test')
data_files = {
'test':
f'data/{capitalize_first_letter(subset)}/sciknoweval_{subset}_test.jsonl'
}
dataset = load_dataset(path, data_files=data_files, split='test')
# dataset = dataset.select(range(20))
if prompt_mode == 'zero-shot':
dataset = dataset.map(
@ -101,75 +106,3 @@ def answer_cleansing(
prediction[0] = prediction[0][:-1]
return prediction[0]
def _generic_llmjudge_postprocess(judgement: str):
match = re.search(r'(A|B)', judgement)
grade_letter = (match.group(0) if match else 'B'
) # Default to "INCORRECT" if no match
return grade_letter
def SciKnowEval_llmjudge_postprocess(
output: dict,
output_path: str,
dataset: Dataset,
) -> dict:
# Get the original dataset
original_dataset = dataset.reader.dataset['test']
judged_answers = []
original_responses = []
references = []
details = []
total_correct = 0
total_count = 0
for k, v in output.items():
idx = int(k) # Convert key to integer for indexing
original_responses.append(v['prediction'])
processed_judge = _generic_llmjudge_postprocess(v['prediction'])
sample = original_dataset[idx]
# Record the judgment
if processed_judge is not None:
judged_answers.append(processed_judge)
try:
gold = v['gold']
references.append(gold)
except KeyError:
get_logger().warning(
f'No gold answer for {k}, use empty string as reference!')
gold = ''
references.append('')
# Check if the answer is correct (A means correct)
is_correct = processed_judge == 'A'
total_count += 1
if is_correct:
total_correct += 1
# Add to details
details.append({
'id': k,
'question': sample['question'],
'origin_prompt': v['origin_prompt'],
'llm_judge': processed_judge,
'gold': gold,
'is_correct': is_correct,
})
# Calculate overall accuracy with two decimal places
overall_accuracy = (round(
(total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
# Initialize results dictionary
results = {
'accuracy': overall_accuracy,
'total_correct': total_correct,
'total_count': total_count,
'details': details,
}
return results