fix gpass compare error of list k

This commit is contained in:
jnanliu 2025-04-08 10:16:24 +00:00
parent a05f9da134
commit b63f998459

View File

@ -158,9 +158,10 @@ class BaseEvaluator:
can_calculate = True can_calculate = True
c += int(example['detail']['is_correct']) c += int(example['detail']['is_correct'])
if can_calculate and n > 1 and k > 1: k_list = [k] if isinstance(k, int) else k
if can_calculate and n > 1 and max(k_list) > 1:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0] thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in [k] if isinstance(k, int) else k: for _k in k_list:
for threshold in thresholds: for threshold in thresholds:
g_pass = compute_g_pass_at_k(n=n, g_pass = compute_g_pass_at_k(n=n,
c=c, c=c,