import os
from typing import Any, Callable, Dict, List, Optional

import mmengine
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.logging import get_logger


@ICL_EVALUATORS.register_module()
class CascadeEvaluator(BaseEvaluator):
    """Cascade Evaluator.

    First uses a rule-based method to judge predictions. If a sample is
    marked as incorrect by the rule-based method, it is re-evaluated by an
    LLM judge.

    Arguments:
        llm_evaluator (dict): Configuration for the LLM evaluator.
        rule_evaluator (Optional[dict]): Configuration for the rule-based
            evaluator.
        sample_score_fn (Optional[Callable]): A function to score individual
            samples. If provided without rule_evaluator, this function will
            be used directly.
        parallel (bool): If True, every sample is sent to the LLM judge in
            addition to the rule-based check, and a sample counts as correct
            when either evaluator accepts it. If False (cascade mode), only
            samples marked incorrect by the rule-based check are re-evaluated
            by the LLM judge.
    """

    def __init__(
        self,
        llm_evaluator: Dict,
        rule_evaluator: Optional[Dict] = None,
        sample_score_fn: Optional[Callable] = None,
        parallel: bool = True,
    ) -> None:
        super().__init__()
        self.logger = get_logger(__name__)

        # Initialize the LLM evaluator
        llm_evaluator_type = llm_evaluator.pop('type')
        if isinstance(llm_evaluator_type, str):
            llm_evaluator_type = ICL_EVALUATORS.get(llm_evaluator_type)
        self.llm_evaluator = llm_evaluator_type(**llm_evaluator)

        # Initialize the rule evaluator if provided
        self.rule_evaluator = None
        if rule_evaluator:
            rule_evaluator_type = rule_evaluator.pop('type')
            if isinstance(rule_evaluator_type, str):
                rule_evaluator_type = ICL_EVALUATORS.get(rule_evaluator_type)
            self.rule_evaluator = rule_evaluator_type(**rule_evaluator)

        self.sample_score_fn = sample_score_fn
        self.parallel = parallel

        # At least one of rule_evaluator or sample_score_fn must be provided
        if not self.rule_evaluator and not self.sample_score_fn:
            raise ValueError(
                'Either rule_evaluator or sample_score_fn must be provided')

    def sample_score(self,
                     prediction: str,
                     reference: str,
                     test_set=None) -> Dict[str, Any]:
        """Score a single sample using sample_score_fn or rule_evaluator.

        Args:
            prediction: The model's prediction.
            reference: The ground truth.
            test_set: The original test sample this pair belongs to, if any.

        Returns:
            Dict: A dictionary containing the score and other details.
        """
        if self.sample_score_fn:
            # Use user-provided function to evaluate a single sample
            result = self.sample_score_fn(prediction, reference, test_set)
            if not isinstance(result, dict):
                # Ensure result is a dictionary with at least 'correct' field
                result = {
                    'correct': bool(result),
                    'pred': prediction,
                    'answer': reference,
                }
            return result
        else:
            # Use rule_evaluator to evaluate a single sample by calling
            # the score method with single-element lists
            result = self.rule_evaluator.score([prediction], [reference],
                                               [test_set])
            if 'details' in result and len(result['details']) > 0:
                return result['details'][0]
            else:
                # Fallback if rule_evaluator doesn't provide detailed results
                return {
                    'correct': result.get('accuracy', 0) > 0,
                    'pred': prediction,
                    'answer': reference,
                }
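    # Illustrative shape of a user-supplied ``sample_score_fn`` (the
    # exact-match rule below is only an example, not something shipped
    # with the evaluator):
    #
    #     def exact_match_fn(prediction, reference, test_item=None):
    #         return {
    #             'correct': prediction.strip() == reference.strip(),
    #             'pred': prediction,
    #             'answer': reference,
    #         }
    #
    # Any non-dict return value is wrapped by ``sample_score`` into
    # ``{'correct': bool(result), 'pred': prediction, 'answer': reference}``.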
""" if 'prediction' in llm_detail: response = llm_detail['prediction'].strip().upper() return response == 'A' or response.startswith('CORRECT') elif 'correct' in llm_detail: return llm_detail['correct'] elif 'score' in llm_detail: return llm_detail['score'] > 0.5 return False def score( self, predictions: List[str], references: List[str], test_set: Optional[Dataset] = None, ) -> Dict[str, Any]: """Score predictions using cascade or parallel evaluation. Args: predictions: List of model predictions. references: List of ground truths. test_set: Huggingface Dataset containing original test samples. Returns: Dict: A dictionary containing the scores and details. """ self.logger.info( f"Running {'parallel' if self.parallel else 'cascade'} evaluation") # Step 1: Evaluate each sample individually using rule-based evaluation details = [] failed_predictions = [] failed_references = [] failed_indices = [] for i, (pred, ref) in enumerate(zip(predictions, references)): if test_set is not None: test_item = test_set[i] else: test_item = None # Apply prediction postprocessing for each sample [pred] = self.rule_evaluator.pred_postprocess([pred]) result = self.sample_score(pred, ref, test_item) result['evaluation_method'] = 'rule' details.append({'rule_evaluation': result}) # If the sample failed rule-based evaluation or in parallel # mode, mark it for LLM evaluation if not result.get('correct', False) or self.parallel: failed_predictions.append(pred) failed_references.append(ref) failed_indices.append(i) # Calculate initial accuracy based on rule evaluation initial_correct = sum( 1 for detail in details if detail['rule_evaluation'].get('correct', False)) initial_accuracy = (100 * initial_correct / len(predictions) if predictions else 0) self.logger.info( f'Rule-based evaluation: {initial_correct}/{len(predictions)} ' f'correct ({initial_accuracy:.2f}%)') eval_mode = ('parallel (all samples)' if self.parallel else 'cascade (only failed samples)') self.logger.info(f'Samples requiring LLM evaluation ({eval_mode}): ' f'{len(failed_indices)}') # Step 2: If there are samples for LLM evaluation if failed_predictions and test_set is not None: self.logger.info(f'Running LLM evaluation in {eval_mode} mode...') # Create a subset of the test_set for LLM evaluation failed_subset = test_set.select(failed_indices) # Add prediction and reference columns to the dataset failed_subset = failed_subset.add_column('prediction', failed_predictions) failed_subset = failed_subset.add_column('reference', failed_references) # Set a custom output path for LLM evaluation original_out_dir = getattr(self.llm_evaluator, '_out_dir', None) self.llm_evaluator._out_dir = f'{self._out_dir}_llm_judge' # Generate random hash suffix llm_results_path = f'{self.llm_evaluator._out_dir}_replica{self.dataset_replica_idx}.json' # noqa self.logger.info(f'LLM evaluation results will be saved at ' f'{llm_results_path}') # Check if results already exist to avoid re-evaluation if os.path.exists(llm_results_path): self.logger.info( f'Loading existing LLM evaluation results from ' f'{llm_results_path}') llm_results = mmengine.load(llm_results_path) # Extract details from loaded results if llm_results.get('details', []): loaded_details = llm_results['details'] else: loaded_details = llm_results # Strictly verify that the loaded results match # the current evaluation needs if len(loaded_details) != len(failed_indices): error_msg = ( f'Error: Loaded LLM results contain ' f'{len(loaded_details)} samples, but current ' f'evaluation requires {len(failed_indices)} 
samples. ' f"The cached results at {llm_results_path} don't match" f'the current evaluation needs. ' f'Please remove the cache file or fix the mismatch.') self.logger.error(error_msg) raise ValueError(error_msg) else: # Use GenericLLMEvaluator to evaluate samples # unset dataset_cfg for GenericLLMEvaluator to # directly use test_set # self.llm_evaluator.output_path = llm_results_path self.llm_evaluator._dataset_replica_idx = \ self._dataset_replica_idx self.llm_evaluator.dataset_cfg = None # Apply prediction postprocessing to for LLM evaluator failed_predictions = self.llm_evaluator.pred_postprocess( failed_predictions) llm_results = self.llm_evaluator.score( predictions=failed_predictions, references=failed_references, test_set=failed_subset, ) # Restore original output directory if original_out_dir: self.llm_evaluator._out_dir = original_out_dir if llm_results.get('details', []): llm_details = llm_results['details'] else: llm_details = llm_results # Initialize counters for accuracy calculation final_correct = initial_correct if not self.parallel else 0 llm_correct = 0 llm_evaluated = 0 # Update the details for samples that were evaluated by LLM for i, llm_detail in enumerate(llm_details.values()): # Add dataset replica index to LLM evaluation result llm_detail['dataset_replica_idx'] = self.dataset_replica_idx original_index = failed_indices[i] # Store original rule-based evaluation result rule_result = details[original_index].copy() rule_correct = rule_result['rule_evaluation'].get( 'correct', False) # Add LLM evaluation details details[original_index]['llm_evaluation'] = llm_detail # Determine LLM correctness judgment and store it is_correct = self._get_llm_correctness(llm_detail) details[original_index]['llm_evaluation'][ 'llm_correct'] = is_correct # Count LLM evaluation statistics llm_evaluated += 1 if is_correct: llm_correct += 1 # Update final_correct counter based on evaluation mode if self.parallel: # In parallel mode, either rule-based or LLM evaluations # should be correct if rule_correct or is_correct: final_correct += 1 else: # In cascade mode, if rule was incorrect but LLM # correct, increment # (rule correct samples are already counted # in initial_correct) if not rule_correct and is_correct: final_correct += 1 # Calculate final accuracy final_accuracy = (100 * final_correct / len(predictions) if predictions else 0) llm_accuracy = (100 * llm_correct / llm_evaluated if llm_evaluated else 0) self.logger.info( f'Final evaluation: {final_correct}/{len(predictions)}' f'correct ({final_accuracy:.2f}%)') if llm_evaluated > 0: self.logger.info( f'LLM evaluation: {llm_correct}/{llm_evaluated} ' f'correct ({llm_accuracy:.2f}%)') # Append cascade correctness flag to each sample for item in details: _rule_correct = item['rule_evaluation'].get('correct', False) if 'llm_evaluation' in item: _llm_correct = item['llm_evaluation'].get( 'llm_correct', False) else: _llm_correct = False item['cascade_correct'] = _rule_correct or _llm_correct result = { 'accuracy': final_accuracy, 'cascade_stats': { 'total_samples': len(predictions), 'rule_correct': initial_correct, 'rule_accuracy': initial_accuracy, 'llm_evaluated': llm_evaluated, 'llm_correct': llm_correct, 'llm_accuracy': llm_accuracy, 'final_correct': final_correct, 'final_accuracy': final_accuracy, 'parallel_mode': self.parallel, }, 'details': details, } return result