from . import healthbench_meta_eval


def test_compute_agreement_for_rater_by_class():
    # Three examples: the rater's own label, the other raters' labels for the
    # same example, and the cluster each example belongs to.
    self_pred_list = [True, False, True]
    other_preds_list = [[True, True, False], [True, False], [False]]
    cluster_list = ['a', 'a', 'b']
    model_or_physician = 'model'
    metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class(
        self_pred_list, other_preds_list, cluster_list, model_or_physician
    )
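    # The returned metrics dict is keyed by formatted index strings; each
    # entry holds a 'value' and a sample count 'n', as checked below.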

    # precision overall
    index_str_pos_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='pos'
    )
    index_str_neg_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='neg'
    )
    overall_pos_precision = metrics[index_str_pos_precision]
    overall_neg_precision = metrics[index_str_neg_precision]
    expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1)
    expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0)
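    # Reading of the expected numbers above (a sketch of the assumed counting):
    # each (self, other-rater) pair counts once, so positive precision is the
    # fraction of other raters who also said True when the self prediction was
    # True -- example 1: 2 of 3, example 2: n/a, example 3: 0 of 1.  Negative
    # precision mirrors this over self=False predictions (example 2 only: 1 of 2).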
    assert overall_pos_precision['value'] == expected_overall_pos_precision
    assert overall_neg_precision['value'] == expected_overall_neg_precision
    assert overall_pos_precision['n'] == 4
    assert overall_neg_precision['n'] == 2

    # recall overall
    index_str_pos_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='pos'
    )
    index_str_neg_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='neg'
    )
    overall_pos_recall = metrics[index_str_pos_recall]
    overall_neg_recall = metrics[index_str_neg_recall]
    expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0)
    expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1)
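    # Correspondingly, positive recall is assumed to be the fraction of the
    # other raters' True labels that the self prediction also marked True
    # (example 1: 2 of 2, example 2: 0 of 1, example 3: n/a), and negative
    # recall the same over the other raters' False labels (one per example,
    # matched only in example 2).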
    assert overall_pos_recall['value'] == expected_overall_pos_recall
    assert overall_neg_recall['value'] == expected_overall_neg_recall
    assert overall_pos_recall['n'] == 3
    assert overall_neg_recall['n'] == 3

    # f1 overall
    index_str_pos_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='pos'
    )
    index_str_neg_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='neg'
    )
    overall_pos_f1 = metrics[index_str_pos_f1]
    overall_neg_f1 = metrics[index_str_neg_f1]
    expected_overall_pos_f1 = (
        2
        * expected_overall_pos_precision
        * expected_overall_pos_recall
        / (expected_overall_pos_precision + expected_overall_pos_recall)
    )
    expected_overall_neg_f1 = (
        2
        * expected_overall_neg_precision
        * expected_overall_neg_recall
        / (expected_overall_neg_precision + expected_overall_neg_recall)
    )
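    # F1 is the harmonic mean of precision and recall for each class; the
    # 'balanced' variant checked below is the unweighted mean of the positive
    # and negative class F1 scores (a macro average over the two classes).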
    assert overall_pos_f1['value'] == expected_overall_pos_f1
    assert overall_neg_f1['value'] == expected_overall_neg_f1

    # balanced f1
    index_str_balanced_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='balanced'
    )
    balanced_f1 = metrics[index_str_balanced_f1]
    expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2
    assert balanced_f1['value'] == expected_balanced_f1

    # by cluster
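    # Per-cluster metrics restrict the same sums to the examples in each
    # cluster: 'a' covers examples 1 and 2, 'b' only example 3.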
    # precision
    cluster_a_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_precision
    )
    cluster_a_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_precision
    )
    cluster_a_pos_precision = metrics[cluster_a_str_pos_precision]
    cluster_a_neg_precision = metrics[cluster_a_str_neg_precision]
    assert cluster_a_pos_precision['value'] == (
        # example 1, 2 in order
        (2 + 0) / (3 + 0)
    )
    assert cluster_a_neg_precision['value'] == (
        # example 1, 2 in order
        (0 + 1) / (0 + 2)
    )
    assert cluster_a_pos_precision['n'] == 3
    assert cluster_a_neg_precision['n'] == 2

    # recall
    cluster_a_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_recall
    )
    cluster_a_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_recall
    )
    cluster_a_pos_recall = metrics[cluster_a_str_pos_recall]
    cluster_a_neg_recall = metrics[cluster_a_str_neg_recall]
    assert cluster_a_pos_recall['value'] == (
        # example 1, 2 in order
        (2 + 0) / (2 + 1)
    )
    assert cluster_a_neg_recall['value'] == (
        # example 1, 2 in order
        (0 + 1) / (1 + 1)
    )
    assert cluster_a_pos_recall['n'] == 3
    assert cluster_a_neg_recall['n'] == 2
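
    # Cluster 'b' contains only example 3 (self True, single other rater False),
    # so there are no negative self-predictions and no positive other-labels;
    # the corresponding keys are expected to be absent, on the assumption that
    # the implementation skips metrics whose denominator would be empty.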
    # cluster B
    # precision
    cluster_b_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_precision
    )
    cluster_b_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_precision
    )
    cluster_b_pos_precision = metrics[cluster_b_str_pos_precision]
    assert cluster_b_str_neg_precision not in metrics
    assert cluster_b_pos_precision['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_pos_precision['n'] == 1

    # recall
    cluster_b_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_recall
    )
    cluster_b_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_recall
    )
    assert cluster_b_str_pos_recall not in metrics
    cluster_b_neg_recall = metrics[cluster_b_str_neg_recall]
    assert cluster_b_neg_recall['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_neg_recall['n'] == 1

    # f1
    cluster_b_str_pos_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_f1
    )
    cluster_b_str_neg_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_f1
    )
    cluster_b_str_balanced_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_balanced_f1
    )
    assert cluster_b_str_pos_f1 not in metrics
    assert cluster_b_str_neg_f1 not in metrics
    assert cluster_b_str_balanced_f1 not in metrics


if __name__ == '__main__':
    test_compute_agreement_for_rater_by_class()