OpenCompass/opencompass/datasets/healthbench/healthbench_meta_eval_test.py

166 lines
6.5 KiB
Python
Raw Normal View History

2025-05-15 16:50:05 +08:00
import math

from . import healthbench_meta_eval
def test_compute_agreement_for_rater_by_class():
    """Check the metrics from ``compute_metrics_for_rater_by_class``.

    The fixture has three examples. Each example has the rater's own boolean
    prediction, a list of other raters' boolean predictions, and a cluster
    label ('a', 'a', 'b').

    The returned mapping is looked up with keys built from
    ``INDEX_STR_TEMPLATE`` (overall metrics) and ``CLUSTER_STR_TEMPLATE``
    (per-cluster metrics). Every entry checked here exposes a ``'value'``
    and, for precision/recall, a sample count ``'n'``.

    Expected values are written as sums with one term per example, in order,
    so each example's contribution stays visible. Floating-point values are
    compared with a tolerance; counts and key presence are compared exactly.
    """
    self_pred_list = [True, False, True]
    other_preds_list = [[True, True, False], [True, False], [False]]
    cluster_list = ['a', 'a', 'b']
    model_or_physician = 'model'
    metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class(
        self_pred_list, other_preds_list, cluster_list, model_or_physician
    )

    def index_key(metric, pred_str):
        """Build the key of an overall metric entry."""
        return healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
            model_or_physician=model_or_physician,
            metric=metric,
            pred_str=pred_str,
        )

    def cluster_key(cluster, index_str):
        """Build the key of a per-cluster metric entry from an overall key."""
        return healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
            cluster=cluster, index_str=index_str
        )

    def assert_close(actual, expected):
        """Assert two floats agree up to rounding error."""
        # Exact ``==`` on floats would tie the test to the precise order of
        # arithmetic operations used by the implementation.
        assert math.isclose(actual, expected, rel_tol=1e-9, abs_tol=1e-12), (
            'expected {!r}, got {!r}'.format(expected, actual)
        )

    # precision overall
    index_str_pos_precision = index_key('precision', 'pos')
    index_str_neg_precision = index_key('precision', 'neg')
    overall_pos_precision = metrics[index_str_pos_precision]
    overall_neg_precision = metrics[index_str_neg_precision]
    expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1)
    expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0)
    assert_close(overall_pos_precision['value'], expected_overall_pos_precision)
    assert_close(overall_neg_precision['value'], expected_overall_neg_precision)
    assert overall_pos_precision['n'] == 4
    assert overall_neg_precision['n'] == 2

    # recall overall
    index_str_pos_recall = index_key('recall', 'pos')
    index_str_neg_recall = index_key('recall', 'neg')
    overall_pos_recall = metrics[index_str_pos_recall]
    overall_neg_recall = metrics[index_str_neg_recall]
    expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0)
    expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1)
    assert_close(overall_pos_recall['value'], expected_overall_pos_recall)
    assert_close(overall_neg_recall['value'], expected_overall_neg_recall)
    assert overall_pos_recall['n'] == 3
    assert overall_neg_recall['n'] == 3

    # f1 overall
    index_str_pos_f1 = index_key('f1', 'pos')
    index_str_neg_f1 = index_key('f1', 'neg')
    overall_pos_f1 = metrics[index_str_pos_f1]
    overall_neg_f1 = metrics[index_str_neg_f1]
    expected_overall_pos_f1 = (
        2
        * expected_overall_pos_precision
        * expected_overall_pos_recall
        / (expected_overall_pos_precision + expected_overall_pos_recall)
    )
    expected_overall_neg_f1 = (
        2
        * expected_overall_neg_precision
        * expected_overall_neg_recall
        / (expected_overall_neg_precision + expected_overall_neg_recall)
    )
    assert_close(overall_pos_f1['value'], expected_overall_pos_f1)
    assert_close(overall_neg_f1['value'], expected_overall_neg_f1)

    # balanced f1: mean of the positive-class and negative-class f1
    index_str_balanced_f1 = index_key('f1', 'balanced')
    balanced_f1 = metrics[index_str_balanced_f1]
    expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2
    assert_close(balanced_f1['value'], expected_balanced_f1)

    # by cluster
    # cluster A (examples 1 and 2)
    # precision
    cluster_a_pos_precision = metrics[cluster_key('a', index_str_pos_precision)]
    cluster_a_neg_precision = metrics[cluster_key('a', index_str_neg_precision)]
    # example 1, 2 in order
    assert_close(cluster_a_pos_precision['value'], (2 + 0) / (3 + 0))
    # example 1, 2 in order
    assert_close(cluster_a_neg_precision['value'], (0 + 1) / (0 + 2))
    assert cluster_a_pos_precision['n'] == 3
    assert cluster_a_neg_precision['n'] == 2

    # recall
    cluster_a_pos_recall = metrics[cluster_key('a', index_str_pos_recall)]
    cluster_a_neg_recall = metrics[cluster_key('a', index_str_neg_recall)]
    # example 1, 2 in order
    assert_close(cluster_a_pos_recall['value'], (2 + 0) / (2 + 1))
    # example 1, 2 in order
    assert_close(cluster_a_neg_recall['value'], (0 + 1) / (1 + 1))
    assert cluster_a_pos_recall['n'] == 3
    assert cluster_a_neg_recall['n'] == 2

    # cluster B (example 3 only)
    # precision: only the positive-class entry is expected to exist
    cluster_b_str_pos_precision = cluster_key('b', index_str_pos_precision)
    cluster_b_str_neg_precision = cluster_key('b', index_str_neg_precision)
    cluster_b_pos_precision = metrics[cluster_b_str_pos_precision]
    assert cluster_b_str_neg_precision not in metrics
    # example 3 only
    assert_close(cluster_b_pos_precision['value'], 0 / 1)
    assert cluster_b_pos_precision['n'] == 1

    # recall: only the negative-class entry is expected to exist
    cluster_b_str_pos_recall = cluster_key('b', index_str_pos_recall)
    cluster_b_str_neg_recall = cluster_key('b', index_str_neg_recall)
    assert cluster_b_str_pos_recall not in metrics
    cluster_b_neg_recall = metrics[cluster_b_str_neg_recall]
    # example 3 only
    assert_close(cluster_b_neg_recall['value'], 0 / 1)
    assert cluster_b_neg_recall['n'] == 1

    # f1: no f1 entry of any kind is expected for cluster B
    cluster_b_str_pos_f1 = cluster_key('b', index_str_pos_f1)
    cluster_b_str_neg_f1 = cluster_key('b', index_str_neg_f1)
    cluster_b_str_balanced_f1 = cluster_key('b', index_str_balanced_f1)
    assert cluster_b_str_pos_f1 not in metrics
    assert cluster_b_str_neg_f1 not in metrics
    assert cluster_b_str_balanced_f1 not in metrics
# Script entry point: run the single test directly, without a test runner.
# NOTE(review): this module uses a relative import (``from . import ...``),
# so direct execution presumably only works via ``python -m <package path>``;
# confirm against how the package is laid out.
if __name__ == '__main__':
    test_compute_agreement_for_rater_by_class()