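"""Evaluate generated code samples with EvalPlus-style base and "plus" test suites.

The `evaluate` entry point loads HumanEval(+) or MBPP(+) problems, computes (or
loads cached) ground-truth outputs, checks every sampled solution against the
base and extra test inputs in a process pool, and reports pass@k for both test
sets, writing detailed results to a JSON file next to the samples.
"""
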
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from opencompass.datasets.evalplus.codegen import run_codegen
from opencompass.datasets.evalplus.config import *
from opencompass.datasets.evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
    load_solutions,
)
from opencompass.datasets.evalplus.data.mbpp import mbpp_serialize_inputs
from opencompass.datasets.evalplus.data.utils import CACHE_DIR
from opencompass.datasets.evalplus.eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from opencompass.datasets.evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
from opencompass.datasets.evalplus.gen.util import trusted_exec

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]
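# Illustrative example (not a value produced here): a fully passing solution
# yields something like (PASS, [True, True, True]), while a failing one pairs
# its failure status with the per-input pass/fail flags.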


def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
    """Run the canonical solutions to collect expected outputs (cached by dataset hash)."""
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        print(f"Loading ground truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )

        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)
    return expected_output


def check_correctness(
    dataset: str,
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    expected_output: Dict[str, List],
    base_only=False,
    fast_check=False,
    identifier=None,
    min_time_limit: float = DEFAULT_MIN_TIME_LIMIT,
    gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR,
) -> Dict[str, Result]:  # {...}, "base" | "plus" -> (status, details)
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        dataset,
        solution,
        problem["base_input"],
        problem["entry_point"],
        expected=expected_output["base"],
        atol=problem["atol"],
        ref_time=expected_output["base_time"],
        fast_check=fast_check,
        min_time_limit=min_time_limit,
        gt_time_limit_factor=gt_time_limit_factor,
    )

    if not base_only:
        ret["plus"] = untrusted_check(
            dataset,
            solution,
            problem["plus_input"],
            problem["entry_point"],
            expected=expected_output["plus"],
            atol=problem["atol"],
            ref_time=expected_output["plus_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        )

    return ret
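# Standalone usage sketch (hypothetical `problem`, `solution`, and `expected`
# values; `evaluate` below builds them via load_solutions and get_groundtruth):
#   res = check_correctness("humaneval", 0, problem, solution, expected)
#   base_status, base_details = res["base"]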


def evaluate(
    dataset: str,
    samples: Optional[str] = None,
    base_only: bool = False,
    parallel: Optional[int] = None,
    i_just_wanna_run: bool = False,
    test_details: bool = False,
    min_time_limit: float = DEFAULT_MIN_TIME_LIMIT,
    gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR,
    mini: bool = False,
    noextreme: bool = False,
    version: str = "default",
    output_file: Optional[str] = None,
    lang: str = "ar",
    **model_kwargs,
):
    if model_kwargs:
        # To suppress the tokenizers parallelism warning
        os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
            "TOKENIZERS_PARALLELISM", "false"
        )
        samples = run_codegen(
            dataset=dataset,
            **model_kwargs,
        )
    assert samples is not None, "No samples provided"

    n_workers = parallel or max(1, multiprocessing.cpu_count() // 2)

    if os.path.isdir(samples):
        result_path = os.path.join(samples, "eval_results.json")
    else:
        assert samples.endswith(".jsonl")
        result_path = samples.replace(".jsonl", "_eval_results.json")

    if output_file is not None:
        result_path = output_file

    if os.path.isfile(result_path) and not i_just_wanna_run:
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)

        results = compatible_eval_result(results)
    else:
        if dataset == "humaneval":
            problems = get_human_eval_plus(
                mini=mini, noextreme=noextreme, version=version, lang=lang
            )
            dataset_hash = get_human_eval_plus_hash(
                mini=mini, noextreme=noextreme, version=version
            )
            expected_output = get_groundtruth(problems, dataset_hash, [])
        elif dataset == "mbpp":
            problems = get_mbpp_plus(mini=mini, noextreme=noextreme, version=version)
            dataset_hash = get_mbpp_plus_hash(
                mini=mini, noextreme=noextreme, version=version
            )
            expected_output = get_groundtruth(
                problems,
                dataset_hash,
                MBPP_OUTPUT_NOT_NONE_TASKS,
            )

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "hash": dataset_hash,
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(
                        f"Task {task_id} is found in the samples but not in the dataset"
                    )
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["prompt"] + sample["completion"]
                )
                remainings.add(sample["_identifier"])
                args = (
                    dataset,
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    expected_output[task_id],
                    base_only,
                    not test_details,  # fast_check
                    sample["_identifier"],
                    min_time_limit,
                    gt_time_limit_factor,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                while remainings:
                    last_size = len(remainings)
                    time.sleep(20)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    # Potentially stuck
                    warn("No samples have finished testing in the last 20s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # Sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:

                def get_failed_tests(stat, details, inputs) -> List[Any]:
                    if stat == PASS or not details:
                        return []

                    if test_details:
                        return [
                            inputs[i] for i in range(len(details)) if not details[i]
                        ]

                    # Otherwise, return only the last (failing) test
                    return [inputs[len(details) - 1]]

                base_stat, base_details = res["base"]
                base_fail_tests = get_failed_tests(
                    base_stat, base_details, problems[task_id]["base_input"]
                )

                # Initialize plus tests
                plus_stat = None
                plus_fail_tests = []

                # With plus tests
                if not base_only:
                    plus_stat, plus_details = res["plus"]
                    plus_fail_tests = get_failed_tests(
                        plus_stat, plus_details, problems[task_id]["plus_input"]
                    )

                if dataset == "mbpp":
                    base_fail_tests = mbpp_serialize_inputs(task_id, base_fail_tests)
                    plus_fail_tests = mbpp_serialize_inputs(task_id, plus_fail_tests)

                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "base_status": base_stat,
                        "plus_status": plus_stat,
                        "base_fail_tests": base_fail_tests,
                        "plus_fail_tests": plus_fail_tests,
                    }
                )

    # Calculate pass@k.
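    # estimate_pass_at_k is assumed to implement the standard unbiased estimator
    # pass@k = E[1 - C(n - c, k) / C(n, k)], with n samples per task (`total`)
    # and c correct ones (`base_correct` / `new_correct`).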
    total = np.array([len(r) for r in results["eval"].values()])
    base_correct = []
    new_correct = []

    for res in results["eval"].values():
        bc = sum([r["base_status"] == PASS for r in res])
        base_correct.append(bc)
        if not base_only:
            new_correct.append(
                sum(
                    [
                        res[i]["base_status"] == res[i]["plus_status"] == PASS
                        for i in range(len(res))
                    ]
                )
            )
    base_correct = np.array(base_correct)

    pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
        for k in [1, 10, 100]
        if total.min() >= k
    }
    cprint(f"{dataset} (base tests)", "red")
    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "red")
    results["pass_at_k"] = {"base": pass_at_k}

    if new_correct:
        cprint(f"{dataset}+ (base + extra tests)", "green")
        pass_at_k = {
            f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean()
            for k in [1, 10, 100]
            if (total >= k).all()
        }
        for k, v in pass_at_k.items():
            cprint(f"{k}:\t{v:.3f}", "green")
        results["pass_at_k"]["plus"] = pass_at_k

    # Save results
    if os.path.isfile(result_path) and i_just_wanna_run:
        decision = ""
        while decision.lower() not in ["y", "n"]:
            print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
            decision = input()

        if decision.lower() == "y":
            # Move the existing file to a backup
            new_path = result_path + ".bak"
            while os.path.isfile(new_path):
                new_path += ".bak"
            os.rename(result_path, new_path)
            print(f"Backed up {result_path} to {new_path}")

    if not os.path.isfile(result_path):
        with open(result_path, "w") as f:
            json.dump(results, f)


def main():
    from fire import Fire

    Fire(evaluate)
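# Example invocation (python-fire maps evaluate()'s keyword arguments to CLI
# flags; the actual entry-point name depends on how this module is installed):
#   python evaluate.py --dataset humaneval --samples samples.jsonl --parallel 4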


if __name__ == "__main__":
    main()