diff --git a/opencompass/configs/datasets/SciKnowEval/SciKnowEval_llmjudge_gen_ebe47d.py b/opencompass/configs/datasets/SciKnowEval/SciKnowEval_llmjudge_gen_ebe47d.py
index eeb1cd9d..0758a39b 100644
--- a/opencompass/configs/datasets/SciKnowEval/SciKnowEval_llmjudge_gen_ebe47d.py
+++ b/opencompass/configs/datasets/SciKnowEval/SciKnowEval_llmjudge_gen_ebe47d.py
@@ -1,6 +1,5 @@
-from opencompass.datasets import SciKnowEvalDataset, SciKnowEvalEvaluator
+from opencompass.datasets import SciKnowEvalDataset, SciKnowEvalEvaluator, SciKnowEval_llmjudge_postprocess
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.datasets import generic_llmjudge_postprocess
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.evaluator import GenericLLMEvaluator
@@ -91,7 +90,7 @@ eval_cfg_biology = dict(
             reader_cfg=reader_cfg,
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
+        dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
     ),
 )
 
@@ -121,7 +120,7 @@ eval_cfg_chemistry = dict(
             subset='chemistry',
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
+        dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
     ),
 )
 
@@ -151,7 +150,7 @@ eval_cfg_material = dict(
             subset='material',
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
+        dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
     ),
 )
 
@@ -181,7 +180,7 @@ eval_cfg_physics = dict(
             subset='physics',
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
+        dict_postprocessor=dict(type=SciKnowEval_llmjudge_postprocess),
     ),
 )
 
@@ -228,4 +227,4 @@ sciknoweval_dataset_physics = dict(
     infer_cfg=infer_cfg,
     eval_cfg=eval_cfg_physics,
 )
-sciknoweval_datasets = [sciknoweval_dataset_biology, sciknoweval_dataset_chemistry, sciknoweval_dataset_physics, sciknoweval_dataset_material]
+sciknoweval_datasets = [sciknoweval_dataset_biology, sciknoweval_dataset_chemistry, sciknoweval_dataset_physics, sciknoweval_dataset_material]
\ No newline at end of file
diff --git a/opencompass/datasets/SciKnowEval.py b/opencompass/datasets/SciKnowEval.py
index 867f2c7a..464b4178 100644
--- a/opencompass/datasets/SciKnowEval.py
+++ b/opencompass/datasets/SciKnowEval.py
@@ -106,3 +106,75 @@ def answer_cleansing(
             prediction[0] = prediction[0][:-1]
 
     return prediction[0]
+
+
+def _generic_llmjudge_postprocess(judgement: str):
+    match = re.search(r'(A|B)', judgement)
+    grade_letter = (match.group(0) if match else 'B'
+                    )  # Default to "INCORRECT" if no match
+    return grade_letter
+
+
+def SciKnowEval_llmjudge_postprocess(
+    output: dict,
+    output_path: str,
+    dataset: Dataset,
+) -> dict:
+    # Get the original dataset
+    original_dataset = dataset.reader.dataset['test']
+
+    judged_answers = []
+    original_responses = []
+    references = []
+    details = []
+
+    total_correct = 0
+    total_count = 0
+
+    for k, v in output.items():
+        idx = int(k)  # Convert key to integer for indexing
+        original_responses.append(v['prediction'])
+        processed_judge = _generic_llmjudge_postprocess(v['prediction'])
+
+        sample = original_dataset[idx]
+        # Record the judgment
+        if processed_judge is not None:
+            judged_answers.append(processed_judge)
+            try:
+                gold = v['gold']
+                references.append(gold)
+            except KeyError:
+                get_logger().warning(
+                    f'No gold answer for {k}, use empty string as reference!')
+                gold = ''
+                references.append('')
+
+        # Check if the answer is correct (A means correct)
+        is_correct = processed_judge == 'A'
+        total_count += 1
+
+        if is_correct:
+            total_correct += 1
+
+        # Add to details
+        details.append({
+            'id': k,
+            'question': sample['question'],
+            'origin_prompt': v['origin_prompt'],
+            'llm_judge': processed_judge,
+            'gold': gold,
+            'is_correct': is_correct,
+        })
+
+    # Calculate overall accuracy with two decimal places
+    overall_accuracy = (round(
+        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
+
+    # Initialize results dictionary
+    results = {
+        'accuracy': overall_accuracy,
+        'total_correct': total_correct,
+        'total_count': total_count,
+        'details': details,
+    }
+    return results
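
Note (not part of the diff): a minimal sketch of how the new SciKnowEval_llmjudge_postprocess could be exercised in isolation. The judge-output dict shape, the example verdict strings, and the stand-in dataset object are assumptions inferred from how the function indexes its inputs, not an official test.

from types import SimpleNamespace

from opencompass.datasets import SciKnowEval_llmjudge_postprocess

# Hypothetical judge output: keys are stringified sample indices; each value
# carries the judge model's verdict text plus the original prompt and gold answer.
judge_output = {
    '0': {
        'prediction': 'A. The candidate answer matches the gold reference.',
        'origin_prompt': '<judge prompt for sample 0>',
        'gold': 'helicase',
    },
    '1': {
        'prediction': 'B. The candidate answer is incorrect.',
        'origin_prompt': '<judge prompt for sample 1>',
        'gold': 'newton',
    },
}

# Stand-in dataset object: the postprocessor only reads
# dataset.reader.dataset['test'][idx]['question'].
fake_dataset = SimpleNamespace(reader=SimpleNamespace(dataset={'test': [
    {'question': 'Which enzyme unwinds the DNA double helix?'},
    {'question': 'What is the SI unit of force?'},
]}))

# output_path is accepted but unused by the function body shown in the diff.
results = SciKnowEval_llmjudge_postprocess(judge_output, 'unused.json', fake_dataset)
print(results['accuracy'], results['total_correct'], results['total_count'])
# Expected: 50.0 1 2  (one 'A' verdict out of two judged samples)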