mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
119 lines
3.7 KiB
Python
119 lines
3.7 KiB
Python
"""This file checks two things:
|
|
1. Is the LLM's codegen complete for each benchmark?
|
|
2. Warn about code that is not compilable (this could be due to implementation issues).
|
|
"""
|
|
|
|
import ast
|
|
import traceback
|
|
|
|
from termcolor import colored
|
|
|
|
from opencompass.datasets.evalplus.data import load_solutions
|
|
|
|
|
|
def syntax_check(code, verbose=False):
    """Return whether ``code`` can be parsed as Python source.

    Args:
        code: Source text to hand to ``ast.parse``.
        verbose: When true, print the traceback of the parse failure.

    Returns:
        ``True`` if ``ast.parse`` accepts the text, ``False`` if parsing
        fails with one of the handled errors.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError, ValueError, RecursionError):
        # Besides SyntaxError/MemoryError, the parser reports null bytes as
        # ValueError on some Python versions and overly nested input as
        # RecursionError; such code is unparsable too, so it must be
        # reported as a failed check rather than abort the caller.
        if verbose:
            traceback.print_exc()
        return False
    return True
|
|
|
|
|
|
def script(
    samples: str, dataset: str, nsample_check: int = None, verbose: bool = False
):
    """Report missing/insufficient samples and samples that fail to parse.

    The report is printed to stdout in two sections: a completeness check
    (is every benchmark task covered, and with enough samples?) followed by
    a compilation check (does every sample parse as Python?).

    Args:
        samples: Location of the generated samples; passed unchanged to
            ``load_solutions``.
        dataset: Benchmark to check against, ``"humaneval"`` or ``"mbpp"``.
        nsample_check: Minimum number of samples expected per task. When
            ``None`` the per-task count is not checked and no completeness
            summary line is printed.
        verbose: When true, print the traceback for each sample that fails
            the syntax check.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # List[Dict{"task_id", "solution"}]
    solutions = load_solutions(samples)

    # Keep the benchmark data in its own local instead of rebinding the
    # ``dataset`` parameter, so the requested name stays available.
    if dataset == "humaneval":
        from evalplus.data import get_human_eval_plus

        problems = get_human_eval_plus()
        dataset_name = "HumanEval"
    elif dataset == "mbpp":
        from evalplus.data import get_mbpp_plus

        problems = get_mbpp_plus()
        dataset_name = "Mbpp"
    else:
        # Without this branch an unknown name fell through and crashed later
        # with a NameError on ``dataset_name``.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'humaneval' or 'mbpp'."
        )

    print(colored(f"Dataset: {dataset_name}", "blue"))

    # Group the samples by task, rebuilding the full solution text for
    # samples that only carry a completion.
    id2solutions = {}
    for solution in solutions:
        task_id = solution["task_id"]
        task_solutions = id2solutions.setdefault(task_id, [])
        if "solution" not in solution:
            assert "completion" in solution, "solution or completion must exist!"
            solution["solution"] = problems[task_id]["prompt"] + solution["completion"]
        task_solutions.append(solution)

    print(colored("==============================", "blue"))
    print(colored(" ::: Checking completeness... ", "blue"))
    print(colored(" ::::: All tasks complete? ", "blue"))
    ndone = 0

    task_ids = problems.keys()
    ntask = len(task_ids)
    for task_id in task_ids:
        if task_id not in id2solutions:
            print(colored(f" ⚠️ {task_id} is missing!", "red"))
            continue
        nfiles = len(id2solutions[task_id])

        # A task is complete once it has at least the expected number of
        # samples. The comparison used to be ``<=``, which counted
        # under-sampled tasks as done and flagged over-sampled ones.
        if nsample_check is None or nfiles >= nsample_check:
            ndone += 1
            continue

        print(
            colored(
                f" ⚠️ {task_id} only has {nfiles} samples! But {nsample_check} are expected.",
                "red",
            )
        )

    # check if there is enough number of samples here.
    if nsample_check is not None:
        if ntask != ndone:
            ntbd = ntask - ndone
            print(colored(f" ::::: ⚠️ {ntbd}/{ntask} tasks incomplete!", "red"))
        else:
            print(colored(f" ::::: All {ntask} tasks complete!", "green"))

    print(colored("==============================", "blue"))
    print(colored(" ::: Checking compilation... ", "blue"))
    print(colored(" ::::: All code compilable? ", "blue"))
    ncode = 0
    nwrong = 0
    for task_id in task_ids:
        # task_id must exist
        if task_id not in id2solutions:
            continue

        for solution in id2solutions[task_id]:
            ncode += 1
            code = solution["solution"]
            dbg_identifier = solution["_identifier"]
            if code.strip() == "":
                print(colored(f" ⚠️ {dbg_identifier} is empty!", "red"))
                nwrong += 1
            elif not syntax_check(code, verbose):
                print(colored(f" ⚠️ {dbg_identifier} is not compilable!", "red"))
                nwrong += 1

    if nwrong != 0:
        print(colored(f" ::::: ⚠️ {nwrong}/{ncode} code are not compilable!", "red"))
    else:
        print(colored(f" ::::: All {ncode} code are compilable!", "green"))
|
|
|
|
|
|
def main():
    """Expose ``script`` as a command-line interface through ``fire.Fire``."""
    # Imported lazily so that importing this module does not require fire.
    import fire

    fire.Fire(script)
|
|
|
|
|
|
# Start the command-line interface only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|