"""This script evaluates a grader model on grading HealthBench rubrics. It
effectively evaluates the evaluator against physician opinion, so we call it a
meta-evaluation.
To run, use the following command (working directory should contain simple-
evals folder): `python -m simple-evals.simple_evals --eval=healthbench_meta
--model=gpt-4.1`
"""
import json
import random
from collections import defaultdict
from typing import Literal
import blobfile as bf
from . import common
from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
INPUT_PATH = 'https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl'
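# Metric keys look like 'pairwise_model_precision_pos'; the cluster variant
# prefixes the example's category.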
INDEX_STR_TEMPLATE = 'pairwise_{model_or_physician}_{metric}_{pred_str}'
CLUSTER_STR_TEMPLATE = '{cluster}: {index_str}'
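# Reuse the shared HTML report template, but drop the "Correct Answer" row and
# append the grader's explanation for its label.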
HEALTHBENCH_META_HTML_JINJA = (
    common.HTML_JINJA.replace(
        '<p>Correct Answer: {{ correct_answer }}</p>\n',
        '',
    ) + "<p>Explanation for grader's label: {{ explanation }}</p>"
)


class HealthBenchMetaEval(Eval):
def __init__(
self,
grader_model: SamplerBase,
num_examples: int | None = None,
n_threads: int = 120,
n_repeats: int = 1,
):
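        """Load the meta-eval examples and store the grading configuration."""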
with bf.BlobFile(INPUT_PATH, 'rb') as f:
examples = [json.loads(line) for line in f]
print(f'Loaded {len(examples)} examples from {INPUT_PATH}')
rng = random.Random(0)
if num_examples is not None and len(examples) > num_examples:
examples = rng.sample(examples, num_examples)
self.examples = examples * n_repeats
self.grader_model = grader_model
self.n_threads = n_threads

    def grade_sample(
self,
grading_response_dict: dict,
physician_labels: list[bool],
category: str,
) -> tuple[dict, bool | None, str]:
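        """Compare one grading response against the physician labels.

        Returns per-sample metrics, the grader's boolean label, and its
        explanation.
        """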
metrics = {
'num_physician_labels': len(physician_labels),
'percent_physician_pos':
sum(physician_labels) / len(physician_labels),
}
grader_label = grading_response_dict['criteria_met']
assert grader_label is True or grader_label is False
metrics['model_predicted_positive'] = grader_label
explanation = grading_response_dict.get('explanation',
'No explanation provided')
category_metrics = {f'{category}: {k}': v for k, v in metrics.items()}
metrics = {**metrics, **category_metrics}
return metrics, grader_label, explanation

    def __call__(self, sampler: SamplerBase) -> EvalResult:
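        # Grade one example: build the grader prompt, query the sampler,
        # and compare the predicted label against the physician labels.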
def fn(row: dict) -> tuple[SingleEvalResult, bool | None]:
convo_with_response = row['prompt'] + [
dict(content=row['completion'], role='assistant')
]
prompt_str = '\n\n'.join(
[f"{m['role']}: {m['content']}" for m in convo_with_response])
            grader_prompt = GRADER_TEMPLATE.replace(
                '<<conversation>>', prompt_str)
            grader_prompt = grader_prompt.replace(
                '<<rubric_item>>', row['rubric'])
grader_convo = [dict(content=grader_prompt, role='user')]
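            # Retry until the grader returns parseable JSON with a boolean
            # 'criteria_met' value.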
while True:
sampler_response = sampler(grader_convo)
response_text = sampler_response.response_text
actual_queried_grader_convo = (
sampler_response.actual_queried_message_list)
grading_response_dict = parse_json_to_dict(response_text)
if 'criteria_met' in grading_response_dict:
label = grading_response_dict['criteria_met']
if label is True or label is False:
break
print('Grading failed due to bad JSON output, retrying...')
metrics, grader_label, explanation = self.grade_sample(
grading_response_dict=grading_response_dict,
physician_labels=row['binary_labels'],
category=row['category'],
)
score = metrics['model_predicted_positive']
# Create HTML for each sample result
html = common.jinja_env.from_string(
HEALTHBENCH_META_HTML_JINJA).render(
prompt_messages=actual_queried_grader_convo,
next_message=dict(content=response_text, role='assistant'),
score=metrics['model_predicted_positive'],
extracted_answer=response_text,
explanation=explanation,
)
convo = actual_queried_grader_convo + [
dict(content=response_text, role='assistant')
]
return (
SingleEvalResult(html=html,
score=score,
convo=convo,
metrics=metrics),
grader_label,
)
# Run evaluation and collect results
all_outputs = common.map_with_progress(fn, self.examples,
self.n_threads)
results: list[SingleEvalResult]
grader_labels: list[bool]
results, grader_labels = zip(*all_outputs)
# model pairwise agreement metrics
model_agreement_metrics = compute_metrics_for_rater_by_class(
self_pred_list=grader_labels,
other_preds_list=[x['binary_labels'] for x in self.examples],
cluster_list=[x['category'] for x in self.examples],
model_or_physician='model',
)
        # physician pairwise agreement metrics: compare each physician's label
        # against the other physicians' labels on the same example
physician_rating_lists = defaultdict(lambda: ([], [], []))
for example in self.examples:
for i in range(len(example['binary_labels'])):
physician_id = example['anonymized_physician_ids'][i]
self_pred = example['binary_labels'][i]
other_preds = (example['binary_labels'][:i] +
example['binary_labels'][i + 1:])
cluster = example['category']
physician_rating_lists[physician_id][0].append(self_pred)
physician_rating_lists[physician_id][1].append(other_preds)
physician_rating_lists[physician_id][2].append(cluster)
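        # Compute agreement metrics separately for each physician, then group
        # them by metric key across physicians.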
physician_agreement_metric_lists = defaultdict(dict)
for physician_id, (
physician_rating_list,
other_preds_list,
cluster_list,
) in physician_rating_lists.items():
physician_agreement_metrics = compute_metrics_for_rater_by_class(
self_pred_list=physician_rating_list,
other_preds_list=other_preds_list,
cluster_list=cluster_list,
model_or_physician='physician',
)
for k, v in physician_agreement_metrics.items():
physician_agreement_metric_lists[k][physician_id] = v
# consolidate final metrics and add agreement metrics
final_metrics = common.aggregate_results(
results, default_stats=('mean', 'n_samples', 'bootstrap_std'))
model_agreement_metrics_condensed: dict[str, float] = {
k: v['value']
for k, v in model_agreement_metrics.items()
if v['value'] is not None
}
assert final_metrics.metrics is not None
final_metrics.metrics.update(model_agreement_metrics_condensed)
final_metrics.score = final_metrics.metrics[
'pairwise_model_f1_balanced']
final_metrics.metadata = {
'model_agreement_metrics': model_agreement_metrics,
'physician_agreement_metric_lists':
physician_agreement_metric_lists,
}
return final_metrics


def compute_metrics_for_rater_by_class(
self_pred_list: list[bool],
other_preds_list: list[list[bool]],
cluster_list: list[str],
model_or_physician: Literal['model', 'physician'],
) -> dict[str, dict[str, float | None]]:
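    """Compute pairwise agreement metrics for one rater against all others.

    Keys follow INDEX_STR_TEMPLATE / CLUSTER_STR_TEMPLATE; each value holds
    the sample count 'n' and the agreement 'value'.
    """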
# get all the metrics for each cluster
metric_lists = defaultdict(list)
for self_pred, other_preds, cluster in zip(self_pred_list,
other_preds_list,
cluster_list,
strict=True):
self_pred_str = 'pos' if self_pred else 'neg'
for other_pred in other_preds:
            # precision: bucketed by this rater's own label (self_pred) -
            # for the 'pos' bucket this is TP / (TP + FP),
            # so entries are recorded under the bucket given by self_pred
precision_index_str = INDEX_STR_TEMPLATE.format(
model_or_physician=model_or_physician,
metric='precision',
pred_str=self_pred_str,
)
metric_lists[precision_index_str].append(self_pred == other_pred)
precision_cluster_str = CLUSTER_STR_TEMPLATE.format(
cluster=cluster, index_str=precision_index_str)
metric_lists[precision_cluster_str].append(self_pred == other_pred)
            # recall: bucketed by the other rater's label (other_pred),
            # treated as ground truth - for the 'pos' bucket this is
            # TP / (TP + FN), so entries are recorded under the bucket
            # given by other_pred
other_pred_str = 'pos' if other_pred else 'neg'
recall_index_str = INDEX_STR_TEMPLATE.format(
model_or_physician=model_or_physician,
metric='recall',
pred_str=other_pred_str,
)
metric_lists[recall_index_str].append(self_pred == other_pred)
recall_cluster_str = CLUSTER_STR_TEMPLATE.format(
cluster=cluster, index_str=recall_index_str)
metric_lists[recall_cluster_str].append(self_pred == other_pred)
metrics: dict[str, dict[str, float | None]] = {}
for index_str, metric_list in metric_lists.items():
n = len(metric_list)
metric = sum(metric_list) / n if n > 0 else None
metrics[index_str] = {
'n': n,
'value': metric,
}
f1_metrics = get_f1_metrics(metrics)
metrics.update(f1_metrics)
balanced_metrics = get_balanced_metrics(metrics)
metrics.update(balanced_metrics)
return metrics


def get_f1_metrics(
metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
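    """Derive an F1 entry for each matching precision/recall key pair."""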
f1_metrics: dict[str, dict[str, float | None]] = {}
for precision_key_name in metrics:
if 'precision' in precision_key_name:
recall_key_name = precision_key_name.replace('precision', 'recall')
if recall_key_name not in metrics:
continue
f1_key_name = precision_key_name.replace('precision', 'f1')
assert f1_key_name not in metrics
f1_metrics[f1_key_name] = compute_f1_metric(
precision=metrics[precision_key_name],
recall=metrics[recall_key_name],
)
return f1_metrics


def compute_f1_metric(
precision: dict[str, float | None],
recall: dict[str, float | None],
) -> dict[str, float | None]:
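    """Combine a precision entry and a recall entry into an F1 entry.

    The reported 'n' is TP + FP + FN, recovered from the two inputs.
    """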
precision_n = precision['n']
recall_n = recall['n']
    assert precision_n is not None and recall_n is not None, (
        'precision_n or recall_n is None')
precision_metric = precision['value']
recall_metric = recall['value']
if precision_metric is None or recall_metric is None:
f1_metric = None
        # precision_metric is None iff precision_n == 0, and recall_metric is
        # None iff recall_n == 0, so if either count is zero this sum gives
        # TP + FP + FN without double counting
        n_f1 = precision_n + recall_n
elif precision_metric == 0 and recall_metric == 0:
f1_metric = 0.0
tp = precision_metric * precision_n # because precision = TP / (TP+FP)
n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP
else:
f1_metric = (2 * (precision_metric * recall_metric) /
(precision_metric + recall_metric))
tp = precision_metric * precision_n # because precision = TP / (TP+FP)
n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP
return {
'n': n_f1,
'value': f1_metric,
}


def get_balanced_metrics(
metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
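    """Derive a class-balanced entry for each matching pos/neg key pair."""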
balanced_metrics: dict[str, dict[str, float | None]] = {}
for pos_key_name in metrics:
if 'pos' in pos_key_name:
neg_key_name = pos_key_name.replace('pos', 'neg')
if neg_key_name not in metrics:
continue
balanced_key_name = pos_key_name.replace('pos', 'balanced')
assert balanced_key_name not in metrics
balanced_metrics[balanced_key_name] = compute_balanced_metric(
metric_pos=metrics[pos_key_name],
metric_neg=metrics[neg_key_name],
)
return balanced_metrics


def compute_balanced_metric(
metric_pos: dict[str, float | None],
metric_neg: dict[str, float | None],
) -> dict[str, float | None]:
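    """Average a 'pos' entry and a 'neg' entry into a balanced entry."""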
n_pos = metric_pos['n']
n_neg = metric_neg['n']
assert n_pos is not None and n_neg is not None, 'n_pos or n_neg is None'
pos_metric = metric_pos['value']
neg_metric = metric_neg['value']
if pos_metric is None or neg_metric is None:
metric = None
else:
metric = (pos_metric + neg_metric) / 2
    return {
        # note: this overcounts samples going towards the balanced F1
        'n': n_pos + n_neg,
        'value': metric,
    }