"""This script evaluates a grader model on grading HealthBench rubrics. It effectively evaluates the evaluator against physician opinion, so we call it a meta-evaluation. To run, use the following command (working directory should contain simple- evals folder): `python -m simple-evals.simple_evals --eval=healthbench_meta --model=gpt-4.1` """ import json import random from collections import defaultdict from typing import Literal import blobfile as bf from . import common from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict from .types import Eval, EvalResult, SamplerBase, SingleEvalResult INPUT_PATH = 'https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl' INDEX_STR_TEMPLATE = 'pairwise_{model_or_physician}_{metric}_{pred_str}' CLUSTER_STR_TEMPLATE = '{cluster}: {index_str}' HEALTHBENCH_META_HTML_JINJA = (common.HTML_JINJA.replace( '

Correct Answer: {{ correct_answer }}

\n', '', ) + "

Explanation for grader's label: {{ explanation }}

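
# Illustration of the metric key naming produced below (values here are
# hypothetical): INDEX_STR_TEMPLATE.format(model_or_physician='model',
# metric='precision', pred_str='pos') yields 'pairwise_model_precision_pos',
# and wrapping it with CLUSTER_STR_TEMPLATE for a cluster such as
# 'example_category' yields 'example_category: pairwise_model_precision_pos'.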
") class HealthBenchMetaEval(Eval): def __init__( self, grader_model: SamplerBase, num_examples: int | None = None, n_threads: int = 120, n_repeats: int = 1, ): with bf.BlobFile(INPUT_PATH, 'rb') as f: examples = [json.loads(line) for line in f] print(f'Loaded {len(examples)} examples from {INPUT_PATH}') rng = random.Random(0) if num_examples is not None and len(examples) > num_examples: examples = rng.sample(examples, num_examples) self.examples = examples * n_repeats self.grader_model = grader_model self.n_threads = n_threads def grade_sample( self, grading_response_dict: dict, physician_labels: list[bool], category: str, ) -> tuple[dict, bool | None, str]: metrics = { 'num_physician_labels': len(physician_labels), 'percent_physician_pos': sum(physician_labels) / len(physician_labels), } grader_label = grading_response_dict['criteria_met'] assert grader_label is True or grader_label is False metrics['model_predicted_positive'] = grader_label explanation = grading_response_dict.get('explanation', 'No explanation provided') category_metrics = {f'{category}: {k}': v for k, v in metrics.items()} metrics = {**metrics, **category_metrics} return metrics, grader_label, explanation def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict) -> tuple[SingleEvalResult, bool | None]: convo_with_response = row['prompt'] + [ dict(content=row['completion'], role='assistant') ] prompt_str = '\n\n'.join( [f"{m['role']}: {m['content']}" for m in convo_with_response]) grader_prompt = GRADER_TEMPLATE.replace('<>', prompt_str) grader_prompt = grader_prompt.replace('<>', row['rubric']) grader_convo = [dict(content=grader_prompt, role='user')] while True: sampler_response = sampler(grader_convo) response_text = sampler_response.response_text actual_queried_grader_convo = ( sampler_response.actual_queried_message_list) grading_response_dict = parse_json_to_dict(response_text) if 'criteria_met' in grading_response_dict: label = grading_response_dict['criteria_met'] if label is True or label is False: break print('Grading failed due to bad JSON output, retrying...') metrics, grader_label, explanation = self.grade_sample( grading_response_dict=grading_response_dict, physician_labels=row['binary_labels'], category=row['category'], ) score = metrics['model_predicted_positive'] # Create HTML for each sample result html = common.jinja_env.from_string( HEALTHBENCH_META_HTML_JINJA).render( prompt_messages=actual_queried_grader_convo, next_message=dict(content=response_text, role='assistant'), score=metrics['model_predicted_positive'], extracted_answer=response_text, explanation=explanation, ) convo = actual_queried_grader_convo + [ dict(content=response_text, role='assistant') ] return ( SingleEvalResult(html=html, score=score, convo=convo, metrics=metrics), grader_label, ) # Run evaluation and collect results all_outputs = common.map_with_progress(fn, self.examples, self.n_threads) results: list[SingleEvalResult] grader_labels: list[bool] results, grader_labels = zip(*all_outputs) # model pairwise agreement metrics model_agreement_metrics = compute_metrics_for_rater_by_class( self_pred_list=grader_labels, other_preds_list=[x['binary_labels'] for x in self.examples], cluster_list=[x['category'] for x in self.examples], model_or_physician='model', ) # physicians: physician_rating_lists = defaultdict(lambda: ([], [], [])) for example in self.examples: for i in range(len(example['binary_labels'])): physician_id = example['anonymized_physician_ids'][i] self_pred = example['binary_labels'][i] other_preds = 
        physician_rating_lists = defaultdict(lambda: ([], [], []))
        for example in self.examples:
            for i in range(len(example['binary_labels'])):
                physician_id = example['anonymized_physician_ids'][i]
                self_pred = example['binary_labels'][i]
                other_preds = (
                    example['binary_labels'][:i]
                    + example['binary_labels'][i + 1:]
                )
                cluster = example['category']
                physician_rating_lists[physician_id][0].append(self_pred)
                physician_rating_lists[physician_id][1].append(other_preds)
                physician_rating_lists[physician_id][2].append(cluster)

        physician_agreement_metric_lists = defaultdict(dict)
        for physician_id, (
            physician_rating_list,
            other_preds_list,
            cluster_list,
        ) in physician_rating_lists.items():
            physician_agreement_metrics = compute_metrics_for_rater_by_class(
                self_pred_list=physician_rating_list,
                other_preds_list=other_preds_list,
                cluster_list=cluster_list,
                model_or_physician='physician',
            )
            for k, v in physician_agreement_metrics.items():
                physician_agreement_metric_lists[k][physician_id] = v

        # consolidate final metrics and add agreement metrics
        final_metrics = common.aggregate_results(
            results, default_stats=('mean', 'n_samples', 'bootstrap_std'))
        model_agreement_metrics_condensed: dict[str, float] = {
            k: v['value']
            for k, v in model_agreement_metrics.items()
            if v['value'] is not None
        }
        assert final_metrics.metrics is not None
        final_metrics.metrics.update(model_agreement_metrics_condensed)
        final_metrics.score = final_metrics.metrics[
            'pairwise_model_f1_balanced']
        final_metrics.metadata = {
            'model_agreement_metrics': model_agreement_metrics,
            'physician_agreement_metric_lists': (
                physician_agreement_metric_lists),
        }
        return final_metrics


def compute_metrics_for_rater_by_class(
    self_pred_list: list[bool],
    other_preds_list: list[list[bool]],
    cluster_list: list[str],
    model_or_physician: Literal['model', 'physician'],
) -> dict[str, dict[str, float | None]]:
    # get all the metrics for each cluster
    metric_lists = defaultdict(list)
    for self_pred, other_preds, cluster in zip(
            self_pred_list, other_preds_list, cluster_list, strict=True):
        self_pred_str = 'pos' if self_pred else 'neg'
        for other_pred in other_preds:
            # precision. based on the grader's labels -
            # i.e., calculated as TP / (TP + FP)
            # so a prediction should be recorded whenever self_pred is True
            precision_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='precision',
                pred_str=self_pred_str,
            )
            metric_lists[precision_index_str].append(self_pred == other_pred)
            precision_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=precision_index_str)
            metric_lists[precision_cluster_str].append(self_pred == other_pred)
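
            # Illustrative example (hypothetical values): with
            # model_or_physician='model', self_pred=True and other_pred=False,
            # this pair appends False to 'pairwise_model_precision_pos' above
            # and False to 'pairwise_model_recall_neg' below.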
            # recall. based on the ground truth labels -
            # i.e., calculated as TP / (TP + FN)
            # so a prediction should be recorded whenever other_pred is True
            other_pred_str = 'pos' if other_pred else 'neg'
            recall_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='recall',
                pred_str=other_pred_str,
            )
            metric_lists[recall_index_str].append(self_pred == other_pred)
            recall_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=recall_index_str)
            metric_lists[recall_cluster_str].append(self_pred == other_pred)

    metrics: dict[str, dict[str, float | None]] = {}
    for index_str, metric_list in metric_lists.items():
        n = len(metric_list)
        metric = sum(metric_list) / n if n > 0 else None
        metrics[index_str] = {
            'n': n,
            'value': metric,
        }

    f1_metrics = get_f1_metrics(metrics)
    metrics.update(f1_metrics)
    balanced_metrics = get_balanced_metrics(metrics)
    metrics.update(balanced_metrics)
    return metrics


def get_f1_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    f1_metrics: dict[str, dict[str, float | None]] = {}
    for precision_key_name in metrics:
        if 'precision' in precision_key_name:
            recall_key_name = precision_key_name.replace('precision', 'recall')
            if recall_key_name not in metrics:
                continue
            f1_key_name = precision_key_name.replace('precision', 'f1')
            assert f1_key_name not in metrics
            f1_metrics[f1_key_name] = compute_f1_metric(
                precision=metrics[precision_key_name],
                recall=metrics[recall_key_name],
            )
    return f1_metrics


def compute_f1_metric(
    precision: dict[str, float | None],
    recall: dict[str, float | None],
) -> dict[str, float | None]:
    precision_n = precision['n']
    recall_n = recall['n']
    assert precision_n is not None and recall_n is not None, 'n_pos or n_neg is None'
    precision_metric = precision['value']
    recall_metric = recall['value']
    if precision_metric is None or recall_metric is None:
        f1_metric = None
        # precision_metric is None iff precision_n = 0, and recall_metric is
        # None iff recall_n = 0, so if either is zero this gives
        # TP + FN + FP without double counting
        n_f1 = precision_n + recall_n
    elif precision_metric == 0 and recall_metric == 0:
        f1_metric = 0.0
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP
    else:
        f1_metric = (2 * (precision_metric * recall_metric)
                     / (precision_metric + recall_metric))
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP
    return {
        'n': n_f1,
        'value': f1_metric,
    }
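
# Worked example (hypothetical numbers, not from the dataset): if precision is
# 0.8 over precision_n = 10 pairs (TP + FP = 10, so TP = 8) and recall is 0.8
# over recall_n = 10 pairs (TP + FN = 10, so FN = 2), then
# F1 = 2 * 0.8 * 0.8 / (0.8 + 0.8) = 0.8 and
# n_f1 = 10 + 10 - 8 = 12 = TP + FP + FN.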

def get_balanced_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    balanced_metrics: dict[str, dict[str, float | None]] = {}
    for pos_key_name in metrics:
        if 'pos' in pos_key_name:
            neg_key_name = pos_key_name.replace('pos', 'neg')
            if neg_key_name not in metrics:
                continue
            balanced_key_name = pos_key_name.replace('pos', 'balanced')
            assert balanced_key_name not in metrics
            balanced_metrics[balanced_key_name] = compute_balanced_metric(
                metric_pos=metrics[pos_key_name],
                metric_neg=metrics[neg_key_name],
            )
    return balanced_metrics


def compute_balanced_metric(
    metric_pos: dict[str, float | None],
    metric_neg: dict[str, float | None],
) -> dict[str, float | None]:
    n_pos = metric_pos['n']
    n_neg = metric_neg['n']
    assert n_pos is not None and n_neg is not None, 'n_pos or n_neg is None'
    pos_metric = metric_pos['value']
    neg_metric = metric_neg['value']
    if pos_metric is None or neg_metric is None:
        metric = None
    else:
        metric = (pos_metric + neg_metric) / 2
    return {
        # note: this overcounts samples going towards the balanced F1
        'n': n_pos + n_neg,
        'value': metric,
    }
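
if __name__ == '__main__':
    # Minimal sanity-check sketch (not part of the eval pipeline): exercises
    # the pairwise agreement helpers on hypothetical toy labels so the metric
    # key naming and the F1/balanced aggregation can be inspected without
    # running the full HealthBench meta-evaluation.
    toy_metrics = compute_metrics_for_rater_by_class(
        self_pred_list=[True, False, True],
        other_preds_list=[[True, True], [False, True], [False, False]],
        cluster_list=['toy_cluster_a', 'toy_cluster_a', 'toy_cluster_b'],
        model_or_physician='model',
    )
    print(json.dumps(toy_metrics, indent=2))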