From 6ea181d15c164a7d2727181b488537404164975a Mon Sep 17 00:00:00 2001 From: zhangsongyang Date: Wed, 14 May 2025 08:30:23 +0000 Subject: [PATCH] Update --- opencompass/evaluator/cascade_evaluator.py | 13 ++++++++++++- .../openicl/icl_evaluator/icl_base_evaluator.py | 4 ++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/opencompass/evaluator/cascade_evaluator.py b/opencompass/evaluator/cascade_evaluator.py index 44a765b8..6a898546 100644 --- a/opencompass/evaluator/cascade_evaluator.py +++ b/opencompass/evaluator/cascade_evaluator.py @@ -233,7 +233,8 @@ class CascadeEvaluator(BaseEvaluator): self.llm_evaluator.dataset_cfg = None # Apply prediction postprocessing to for LLM evaluator - failed_predictions = self.llm_evaluator.pred_postprocess(failed_predictions) + failed_predictions = self.llm_evaluator.pred_postprocess( + failed_predictions) llm_results = self.llm_evaluator.score( predictions=failed_predictions, @@ -308,6 +309,16 @@ class CascadeEvaluator(BaseEvaluator): f'LLM evaluation: {llm_correct}/{llm_evaluated} ' f'correct ({llm_accuracy:.2f}%)') + # Append cascade correctness flag to each sample + for item in details: + _rule_correct = item['rule_evaluation'].get('correct', False) + if 'llm_evaluation' in item: + _llm_correct = item['llm_evaluation'].get( + 'llm_correct', False) + else: + _llm_correct = False + item['cascade_correct'] = _rule_correct or _llm_correct + result = { 'accuracy': final_accuracy, 'cascade_stats': { diff --git a/opencompass/openicl/icl_evaluator/icl_base_evaluator.py b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py index 94c69f3a..dded48f9 100644 --- a/opencompass/openicl/icl_evaluator/icl_base_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py @@ -182,6 +182,10 @@ class BaseEvaluator: elif example['detail'].get('is_correct', None) is not None: can_calculate = True c += int(example['detail']['is_correct']) + elif example['detail'].get('cascade_correct', + None) is not None: + can_calculate = True + c += int(example['detail']['cascade_correct']) k_list = [k] if isinstance(k, int) else k if can_calculate and n > 1 and max(k_list) > 1: