import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from opencompass.datasets.evalplus.codegen import run_codegen
from opencompass.datasets.evalplus.config import *
from opencompass.datasets.evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
    load_solutions,
)
from opencompass.datasets.evalplus.data.mbpp import mbpp_serialize_inputs
from opencompass.datasets.evalplus.data.utils import CACHE_DIR
from opencompass.datasets.evalplus.eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from opencompass.datasets.evalplus.eval._special_oracle import (
    MBPP_OUTPUT_NOT_NONE_TASKS,
)
from opencompass.datasets.evalplus.gen.util import trusted_exec

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]


def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
    """Compute (or load from cache) the expected outputs of the canonical
    solutions on the base and plus inputs."""
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        print(f"Loading ground truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)

    return expected_output


def check_correctness(
    dataset: str,
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    expected_output: Dict[str, List],
    base_only=False,
    fast_check=False,
    identifier=None,
    min_time_limit: float = DEFAULT_MIN_TIME_LIMIT,
    gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR,
) -> Dict[str, Result]:  # {...}, "base" | "plus" -> (status, details)
    """Run one candidate solution against the base (and optionally plus)
    inputs via untrusted_check and return the per-set results."""
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        dataset,
        solution,
        problem["base_input"],
        problem["entry_point"],
        expected=expected_output["base"],
        atol=problem["atol"],
        ref_time=expected_output["base_time"],
        fast_check=fast_check,
        min_time_limit=min_time_limit,
        gt_time_limit_factor=gt_time_limit_factor,
    )

    if not base_only:
        ret["plus"] = untrusted_check(
            dataset,
            solution,
            problem["plus_input"],
            problem["entry_point"],
            expected=expected_output["plus"],
            atol=problem["atol"],
            ref_time=expected_output["plus_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        )

    return ret
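

# Illustrative sketch (not used by this module): `estimate_pass_at_k`, imported
# above, is assumed to implement the standard unbiased pass@k estimator from the
# HumanEval paper, pass@k = E[1 - C(n - c, k) / C(n, k)], where n is the number
# of samples per task and c the number of correct ones. The helper name below is
# hypothetical and exists only to document that formula.
def _pass_at_k_sketch(n_samples: int, n_correct: int, k: int) -> float:
    if n_samples - n_correct < k:
        return 1.0
    return 1.0 - float(
        np.prod(1.0 - k / np.arange(n_samples - n_correct + 1, n_samples + 1))
    )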


def evaluate(
    dataset: str,
    samples: Optional[str] = None,
    base_only: bool = False,
    parallel: Optional[int] = None,
    i_just_wanna_run: bool = False,
    test_details: bool = False,
    min_time_limit: float = DEFAULT_MIN_TIME_LIMIT,
    gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR,
    mini: bool = False,
    noextreme: bool = False,
    version: str = "default",
    output_file: Optional[str] = None,
    lang: str = "ar",
    **model_kwargs,
):
    """Evaluate generated samples against the base and plus test inputs and
    report pass@k."""
    if model_kwargs:
        print('model_kwargs', model_kwargs)
        # To suppress the warning of tokenizers
        os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
            "TOKENIZERS_PARALLELISM", "false"
        )
        samples = run_codegen(
            dataset=dataset,
            **model_kwargs,
        )

    assert samples is not None, "No samples provided"

    n_workers = parallel or max(1, multiprocessing.cpu_count() // 2)

    if os.path.isdir(samples):
        result_path = os.path.join(samples, "eval_results.json")
    else:
        assert samples.endswith(".jsonl")
        result_path = samples.replace(".jsonl", "_eval_results.json")

    if output_file is not None:
        result_path = output_file

    if os.path.isfile(result_path) and not i_just_wanna_run:
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)
            results = compatible_eval_result(results)
    else:
        if dataset == "humaneval":
            print('get_human_eval_plus')
            problems = get_human_eval_plus(
                mini=mini, noextreme=noextreme, version=version, lang=lang
            )
            dataset_hash = get_human_eval_plus_hash(
                mini=mini, noextreme=noextreme, version=version
            )
            expected_output = get_groundtruth(problems, dataset_hash, [])
        elif dataset == "mbpp":
            problems = get_mbpp_plus(mini=mini, noextreme=noextreme, version=version)
            dataset_hash = get_mbpp_plus_hash(
                mini=mini, noextreme=noextreme, version=version
            )
            expected_output = get_groundtruth(
                problems,
                dataset_hash,
                MBPP_OUTPUT_NOT_NONE_TASKS,
            )

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "hash": dataset_hash,
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of sample results
            remainings = set()

            print("Reading samples...")
            print('samples', samples)
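            # Each sample yielded by load_solutions is expected to carry a
            # "task_id", an "_identifier", and either a full "solution" or a
            # "completion" that gets appended to the prompt, roughly
            # (illustrative values only):
            #   {"task_id": "HumanEval/0", "solution": "def ..."}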
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(
                        f"Task {task_id} is found in the samples but not in the dataset"
                    )
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["prompt"] + sample["completion"]
                )
                remainings.add(sample["_identifier"])
                args = (
                    dataset,
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    expected_output[task_id],
                    base_only,
                    not test_details,  # fast_check
                    sample["_identifier"],
                    min_time_limit,
                    gt_time_limit_factor,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                while remainings:
                    last_size = len(remainings)
                    time.sleep(20)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    # Potential stucking
                    warn("No samples have finished testing in the last 20s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                print('result 253', result)
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:

                def get_failed_tests(stat, details, inputs) -> List[Any]:
                    if stat == PASS or not details:
                        return []

                    if test_details:
                        return [
                            inputs[i] for i in range(len(details)) if not details[i]
                        ]

                    # else => simply return the only and the last failed test
                    return [inputs[len(details) - 1]]

                base_stat, base_details = res["base"]
                base_fail_tests = get_failed_tests(
                    base_stat, base_details, problems[task_id]["base_input"]
                )

                # initialize plus tests
                plus_stat = None
                plus_fail_tests = []

                # with plus tests
                if not base_only:
                    plus_stat, plus_details = res["plus"]
                    plus_fail_tests = get_failed_tests(
                        plus_stat, plus_details, problems[task_id]["plus_input"]
                    )

                if dataset == "mbpp":
                    base_fail_tests = mbpp_serialize_inputs(task_id, base_fail_tests)
                    plus_fail_tests = mbpp_serialize_inputs(task_id, plus_fail_tests)

                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "base_status": base_stat,
                        "plus_status": plus_stat,
                        "base_fail_tests": base_fail_tests,
                        "plus_fail_tests": plus_fail_tests,
                    }
                )

    # Calculate pass@k.
    total = np.array([len(r) for r in results["eval"].values()])
    base_correct = []
    new_correct = []

    for res in results["eval"].values():
        bc = sum([r["base_status"] == PASS for r in res])
        base_correct.append(bc)
        if not base_only:
            new_correct.append(
                sum(
                    [
                        res[i]["base_status"] == res[i]["plus_status"] == PASS
                        for i in range(len(res))
                    ]
                )
            )
    base_correct = np.array(base_correct)

    pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
        for k in [1, 10, 100]
        if total.min() >= k
    }
    cprint(f"{dataset} (base tests)", "red")
    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "red")
    results["pass_at_k"] = {"base": pass_at_k}

    if new_correct:
        cprint(f"{dataset}+ (base + extra tests)", "green")
        pass_at_k = {
            f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean()
            for k in [1, 10, 100]
            if (total >= k).all()
        }
        for k, v in pass_at_k.items():
            cprint(f"{k}:\t{v:.3f}", "green")
        results["pass_at_k"]["plus"] = pass_at_k

    # save results
    if os.path.isfile(result_path) and i_just_wanna_run:
        decision = ""
        while decision.lower() not in ["y", "n"]:
            print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
            decision = input()
        if decision.lower() == "y":
            # move the existing file to a backup
            new_path = result_path + ".bak"
            while os.path.isfile(new_path):
                new_path += ".bak"
            os.rename(result_path, new_path)
            print(f"Backed up {result_path} to {new_path}")

    if not os.path.isfile(result_path):
        with open(result_path, "w") as f:
            json.dump(results, f)


def main():
    from fire import Fire

    Fire(evaluate)


if __name__ == "__main__":
    main()
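
# Example invocation (illustrative only; python-fire maps the keyword arguments
# of `evaluate` to command-line flags, the script name assumes this file is run
# directly, and the sample path is hypothetical):
#
#   python evaluate.py --dataset humaneval --samples /path/to/samples.jsonl \
#       --parallel 4 --test_details
#
# or, from Python:
#
#   evaluate(dataset="mbpp", samples="/path/to/samples.jsonl", base_only=True)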