# OpenCompass/opencompass/evaluator/cascade_evaluator.py

import os
from typing import Any, Callable, Dict, List, Optional

import mmengine
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.logging import get_logger


@ICL_EVALUATORS.register_module()
class CascadeEvaluator(BaseEvaluator):
"""Cascade Evaluator.
First uses a rule-based method to judge predictions.
If a sample is marked as incorrect by the rule-based method,
then it uses an LLM judge to re-evaluate it.
Arguments:
llm_evaluator (dict): Configuration for the LLM evaluator.
rule_evaluator (Optional[dict]): Configuration for the
rule-based evaluator.
sample_score_fn (Optional[Callable]): A function to
score individual samples. If provided without rule_evaluator,
this function will be used directly.
parallel (bool): Whether to run in parallel mode.
"""
def __init__(
self,
llm_evaluator: Dict,
rule_evaluator: Optional[Dict] = None,
sample_score_fn: Optional[Callable] = None,
parallel: bool = True,
) -> None:
super().__init__()
self.logger = get_logger(__name__)
# Initialize the LLM evaluator
llm_evaluator_type = llm_evaluator.pop('type')
if isinstance(llm_evaluator_type, str):
llm_evaluator_type = ICL_EVALUATORS.get(llm_evaluator_type)
self.llm_evaluator = llm_evaluator_type(**llm_evaluator)
# Initialize the rule evaluator if provided
self.rule_evaluator = None
if rule_evaluator:
rule_evaluator_type = rule_evaluator.pop('type')
if isinstance(rule_evaluator_type, str):
rule_evaluator_type = ICL_EVALUATORS.get(rule_evaluator_type)
self.rule_evaluator = rule_evaluator_type(**rule_evaluator)
self.sample_score_fn = sample_score_fn
self.parallel = parallel
# At least one of rule_evaluator or sample_score_fn must be provided
if not self.rule_evaluator and not self.sample_score_fn:
raise ValueError(
'Either rule_evaluator or sample_score_fn must be provided')

    def sample_score(self,
                     prediction: str,
                     reference: str,
                     test_set=None) -> Dict[str, Any]:
        """Score a single sample using sample_score_fn or rule_evaluator.

        Args:
            prediction: The model's prediction.
            reference: The ground truth.
            test_set: The original test sample providing extra context for
                scoring, or None.

        Returns:
            Dict: A dictionary containing the score and other details.
        """
if self.sample_score_fn:
# Use user-provided function to evaluate a single sample
result = self.sample_score_fn(prediction, reference, test_set)
if not isinstance(result, dict):
# Ensure result is a dictionary with at least 'correct' field
result = {
'correct': bool(result),
'pred': prediction,
'answer': reference,
}
return result
else:
# Use rule_evaluator to evaluate a single sample by calling
# the score method with single-element lists
result = self.rule_evaluator.score([prediction], [reference],
[test_set])
if 'details' in result and len(result['details']) > 0:
return result['details'][0]
else:
# Fallback if rule_evaluator doesn't provide detailed results
return {
'correct': result.get('accuracy', 0) > 0,
'pred': prediction,
'answer': reference,
}

    def _get_llm_correctness(self, llm_detail):
        """Determine if the LLM judge considers the answer correct.

        Args:
            llm_detail: The evaluation details from the LLM judge.

        Returns:
            bool: Whether the answer is correct according to the LLM judge.
        """
if 'prediction' in llm_detail:
response = llm_detail['prediction'].strip().upper()
return response == 'A' or response.startswith('CORRECT')
elif 'correct' in llm_detail:
return llm_detail['correct']
elif 'score' in llm_detail:
return llm_detail['score'] > 0.5
return False

    def score(
        self,
        predictions: List[str],
        references: List[str],
        test_set: Optional[Dataset] = None,
    ) -> Dict[str, Any]:
        """Score predictions using cascade or parallel evaluation.

        Args:
            predictions: List of model predictions.
            references: List of ground truths.
            test_set: Huggingface Dataset containing original test samples.

        Returns:
            Dict: A dictionary containing the scores and details.
        """
self.logger.info(
f"Running {'parallel' if self.parallel else 'cascade'} evaluation")
# Step 1: Evaluate each sample individually using rule-based evaluation
details = []
failed_predictions = []
failed_references = []
failed_indices = []
for i, (pred, ref) in enumerate(zip(predictions, references)):
if test_set is not None:
test_item = test_set[i]
else:
test_item = None
            # Apply prediction postprocessing for each sample when a rule
            # evaluator is configured (sample_score_fn-only setups have no
            # rule-based postprocessor to apply).
            if self.rule_evaluator:
                [pred] = self.rule_evaluator.pred_postprocess([pred])
result = self.sample_score(pred, ref, test_item)
result['evaluation_method'] = 'rule'
details.append({'rule_evaluation': result})
# If the sample failed rule-based evaluation or in parallel
# mode, mark it for LLM evaluation
if not result.get('correct', False) or self.parallel:
failed_predictions.append(pred)
failed_references.append(ref)
failed_indices.append(i)
# Calculate initial accuracy based on rule evaluation
initial_correct = sum(
1 for detail in details
if detail['rule_evaluation'].get('correct', False))
initial_accuracy = (100 * initial_correct /
len(predictions) if predictions else 0)
self.logger.info(
f'Rule-based evaluation: {initial_correct}/{len(predictions)} '
f'correct ({initial_accuracy:.2f}%)')
eval_mode = ('parallel (all samples)'
if self.parallel else 'cascade (only failed samples)')
self.logger.info(f'Samples requiring LLM evaluation ({eval_mode}): '
f'{len(failed_indices)}')
        # Defaults for the case where no LLM evaluation runs (cascade mode
        # with every sample already correct, or no test_set provided), so the
        # final summary below is always well defined.
        final_correct = initial_correct
        final_accuracy = initial_accuracy
        llm_correct = 0
        llm_evaluated = 0
        llm_accuracy = 0.0

        # Step 2: If there are samples for LLM evaluation
if failed_predictions and test_set is not None:
self.logger.info(f'Running LLM evaluation in {eval_mode} mode...')
# Create a subset of the test_set for LLM evaluation
failed_subset = test_set.select(failed_indices)
# Add prediction and reference columns to the dataset
failed_subset = failed_subset.add_column('prediction',
failed_predictions)
failed_subset = failed_subset.add_column('reference',
failed_references)
# Set a custom output path for LLM evaluation
original_out_dir = getattr(self.llm_evaluator, '_out_dir', None)
self.llm_evaluator._out_dir = f'{self._out_dir}_llm_judge'
            # Suffix the results file with the dataset replica index
llm_results_path = f'{self.llm_evaluator._out_dir}_replica{self.dataset_replica_idx}.json' # noqa
self.logger.info(f'LLM evaluation results will be saved at '
f'{llm_results_path}')
# Check if results already exist to avoid re-evaluation
if os.path.exists(llm_results_path):
self.logger.info(
f'Loading existing LLM evaluation results from '
f'{llm_results_path}')
llm_results = mmengine.load(llm_results_path)
# Extract details from loaded results
if llm_results.get('details', []):
loaded_details = llm_results['details']
else:
loaded_details = llm_results
# Strictly verify that the loaded results match
# the current evaluation needs
if len(loaded_details) != len(failed_indices):
error_msg = (
f'Error: Loaded LLM results contain '
f'{len(loaded_details)} samples, but current '
f'evaluation requires {len(failed_indices)} samples. '
f"The cached results at {llm_results_path} don't match"
f'the current evaluation needs. '
f'Please remove the cache file or fix the mismatch.')
self.logger.error(error_msg)
raise ValueError(error_msg)
else:
# Use GenericLLMEvaluator to evaluate samples
# unset dataset_cfg for GenericLLMEvaluator to
# directly use test_set
# self.llm_evaluator.output_path = llm_results_path
self.llm_evaluator._dataset_replica_idx = \
self._dataset_replica_idx
self.llm_evaluator.dataset_cfg = None
                # Apply prediction postprocessing for the LLM evaluator
failed_predictions = self.llm_evaluator.pred_postprocess(
failed_predictions)
llm_results = self.llm_evaluator.score(
predictions=failed_predictions,
references=failed_references,
test_set=failed_subset,
)
# Restore original output directory
if original_out_dir:
self.llm_evaluator._out_dir = original_out_dir
if llm_results.get('details', []):
llm_details = llm_results['details']
else:
llm_details = llm_results
# Initialize counters for accuracy calculation
final_correct = initial_correct if not self.parallel else 0
llm_correct = 0
llm_evaluated = 0
# Update the details for samples that were evaluated by LLM
for i, llm_detail in enumerate(llm_details.values()):
# Add dataset replica index to LLM evaluation result
llm_detail['dataset_replica_idx'] = self.dataset_replica_idx
original_index = failed_indices[i]
# Store original rule-based evaluation result
rule_result = details[original_index].copy()
rule_correct = rule_result['rule_evaluation'].get(
'correct', False)
# Add LLM evaluation details
details[original_index]['llm_evaluation'] = llm_detail
# Determine LLM correctness judgment and store it
is_correct = self._get_llm_correctness(llm_detail)
details[original_index]['llm_evaluation'][
'llm_correct'] = is_correct
# Count LLM evaluation statistics
llm_evaluated += 1
if is_correct:
llm_correct += 1
# Update final_correct counter based on evaluation mode
if self.parallel:
# In parallel mode, either rule-based or LLM evaluations
# should be correct
if rule_correct or is_correct:
final_correct += 1
else:
# In cascade mode, if rule was incorrect but LLM
# correct, increment
# (rule correct samples are already counted
# in initial_correct)
if not rule_correct and is_correct:
final_correct += 1
# Calculate final accuracy
final_accuracy = (100 * final_correct /
len(predictions) if predictions else 0)
llm_accuracy = (100 * llm_correct /
llm_evaluated if llm_evaluated else 0)
        self.logger.info(
            f'Final evaluation: {final_correct}/{len(predictions)} '
            f'correct ({final_accuracy:.2f}%)')
if llm_evaluated > 0:
self.logger.info(
f'LLM evaluation: {llm_correct}/{llm_evaluated} '
f'correct ({llm_accuracy:.2f}%)')
# Append cascade correctness flag to each sample
for item in details:
_rule_correct = item['rule_evaluation'].get('correct', False)
if 'llm_evaluation' in item:
_llm_correct = item['llm_evaluation'].get(
'llm_correct', False)
else:
_llm_correct = False
item['cascade_correct'] = _rule_correct or _llm_correct
result = {
'accuracy': final_accuracy,
'cascade_stats': {
'total_samples': len(predictions),
'rule_correct': initial_correct,
'rule_accuracy': initial_accuracy,
'llm_evaluated': llm_evaluated,
'llm_correct': llm_correct,
'llm_accuracy': llm_accuracy,
'final_correct': final_correct,
'final_accuracy': final_accuracy,
'parallel_mode': self.parallel,
},
'details': details,
}
return result