This commit is contained in:
zhangsongyang 2025-05-14 08:30:23 +00:00
parent bfe693cc6f
commit 6ea181d15c
2 changed files with 16 additions and 1 deletions

View File

@ -233,7 +233,8 @@ class CascadeEvaluator(BaseEvaluator):
self.llm_evaluator.dataset_cfg = None self.llm_evaluator.dataset_cfg = None
# Apply prediction postprocessing to for LLM evaluator # Apply prediction postprocessing to for LLM evaluator
failed_predictions = self.llm_evaluator.pred_postprocess(failed_predictions) failed_predictions = self.llm_evaluator.pred_postprocess(
failed_predictions)
llm_results = self.llm_evaluator.score( llm_results = self.llm_evaluator.score(
predictions=failed_predictions, predictions=failed_predictions,
@ -308,6 +309,16 @@ class CascadeEvaluator(BaseEvaluator):
f'LLM evaluation: {llm_correct}/{llm_evaluated} ' f'LLM evaluation: {llm_correct}/{llm_evaluated} '
f'correct ({llm_accuracy:.2f}%)') f'correct ({llm_accuracy:.2f}%)')
# Append cascade correctness flag to each sample
for item in details:
_rule_correct = item['rule_evaluation'].get('correct', False)
if 'llm_evaluation' in item:
_llm_correct = item['llm_evaluation'].get(
'llm_correct', False)
else:
_llm_correct = False
item['cascade_correct'] = _rule_correct or _llm_correct
result = { result = {
'accuracy': final_accuracy, 'accuracy': final_accuracy,
'cascade_stats': { 'cascade_stats': {

View File

@ -182,6 +182,10 @@ class BaseEvaluator:
elif example['detail'].get('is_correct', None) is not None: elif example['detail'].get('is_correct', None) is not None:
can_calculate = True can_calculate = True
c += int(example['detail']['is_correct']) c += int(example['detail']['is_correct'])
elif example['detail'].get('cascade_correct',
None) is not None:
can_calculate = True
c += int(example['detail']['cascade_correct'])
k_list = [k] if isinstance(k, int) else k k_list = [k] if isinstance(k, int) else k
if can_calculate and n > 1 and max(k_list) > 1: if can_calculate and n > 1 and max(k_list) > 1: