# Mirror of https://github.com/open-compass/opencompass.git
# (synced 2025-05-30 16:03:24 +08:00 -- 296 lines, 10 KiB, Python)
import json
import os
import re
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from copy import deepcopy
from traceback import format_exc
from typing import Any, List, Optional, Tuple

from pympler.asizeof import asizeof
from rich.syntax import Syntax
from termcolor import colored

from evalplus.config import PERF_CURATE_TIMEOUT_SECOND, PERF_RAM_GB_PER_PROC
from evalplus.data import get_human_eval_plus, get_mbpp_plus
from evalplus.data.mbpp import mbpp_serialize_inputs
from evalplus.eval.utils import TimeoutException, reliability_guard, time_limit
from evalplus.sanitize import syntax_check
from evalplus.utils import progress
|
|
|
|
|
|
# this is more of a hack... rather than a "verified" implementation
def insert_contract(entry_point: str, code: str, contract: str):
    """Insert ``contract`` right after the ``def {entry_point}(...)`` line of ``code``.

    The contract and the function body may be indented differently, so
    whichever side is less indented is padded to match the other. Tabs are
    expanded to 4 spaces, and blank lines are dropped from the result.

    Args:
        entry_point: name of the top-level function that receives the contract.
        code: source code containing a column-0 ``def {entry_point}(`` line.
        contract: statements to place at the top of that function's body.

    Returns:
        The merged source code, joined with ``"\\n"``. A whitespace-only
        contract leaves the (blank-line-stripped) code unchanged.

    Raises:
        ValueError: if ``code`` has no column-0 ``def {entry_point}(`` line.
    """

    # why is this so complicated? because the contract might be mis-indented...
    def get_first_indent_size(source: str, body_char_start_idx: int) -> int:
        # Count the spaces immediately preceding ``body_char_start_idx``.
        # Stop at the start of the string: a negative index would silently
        # wrap around to the END of ``source`` and inflate the count.
        indent_size = 0
        while (
            body_char_start_idx - indent_size - 1 >= 0
            and source[body_char_start_idx - indent_size - 1] == " "
        ):
            indent_size += 1
        return indent_size

    code = code.replace("\t", " " * 4)
    contract = contract.replace("\t", " " * 4)

    lines = [line for line in code.split("\n") if line.strip()]

    # Match the exact function name: a bare prefix test on
    # f"def {entry_point}" would also hit "def {entry_point}_helper(...)".
    def_match = re.search(
        rf"^def {re.escape(entry_point)}[ \t]*\(.*$", code, flags=re.MULTILINE
    )
    if def_match is None:
        raise ValueError(f"cannot find `def {entry_point}(` in code")
    fn_def_line = def_match.group(0)
    def_line_idx = lines.index(fn_def_line)

    if not contract.strip():
        # Nothing to insert; keep the same normalization (no blank lines).
        return "\n".join(lines)

    # The body starts at the first non-whitespace character after the def line.
    after_def = code[def_match.end() :]
    body_start_idx = def_match.end() + len(after_def) - len(after_def.lstrip())

    source_indent: int = get_first_indent_size(code, body_start_idx)
    contract_indent: int = get_first_indent_size(
        contract, len(contract) - len(contract.lstrip())
    )
    contract_pad = " " * max(0, source_indent - contract_indent)
    source_pad = " " * max(0, contract_indent - source_indent)
    return "\n".join(
        lines[: def_line_idx + 1]
        + [contract_pad + cline for cline in contract.split("\n") if cline.strip()]
        + [source_pad + sline for sline in lines[def_line_idx + 1 :]]
    )
|
|
|
|
|
|
def post_process(text: str) -> Optional[str]:
    """Post-process the LLM generated text to make it valid.

    Extracts the contents of the first ```` ```python ```` (or
    ```` ```python3 ````) fenced block, up to the next closing fence or the
    end of the text, stripped of surrounding whitespace.

    Returns:
        The extracted code, or ``None`` if the text has no such fenced block
        or the extracted code fails ``syntax_check``.
    """
    if "\n```" not in text:
        return None

    # The opening ```python / ```python3 fence may start the text or follow a
    # newline. Other fences (```js, bare ```, ...) are not parsed.
    opening = re.search(r"(?:^|\n)```python3?\n", text)
    if opening is None:
        return None
    text = text[opening.end() :]
    text = text.split("\n```")[0].strip()

    # perform syntax check
    if not syntax_check(text):
        print(colored("⚠️ Syntax check failed for the code below:", "red"))
        print(text[:256], "..." if len(text) > 256 else "")
        return None

    return text
|
|
|
|
|
|
# returns:
# 1. generated and validated (by the contract) inputs
# 2. whether the generator stops in a well-defined manner
#    -- if False, we might want to try another generator
def sample_one_input(
    ref_code_with_contract: str,
    entry_point: str,
    generator_code: str,
    timeout_second: float = PERF_CURATE_TIMEOUT_SECOND + 1,
) -> Tuple[List[Any], bool]:
    """Run one input generator at growing scales and keep the largest valid input.

    ``generator_code`` must define ``perf_input_gen(scale)``. It is called with
    ``scale = 2**1, 2**2, ..., 2**26`` until generating an input, or running
    the reference function ``entry_point`` (from ``ref_code_with_contract``)
    on it, times out, runs out of memory, or raises. It also stops once an
    input holds an int outside the signed 64-bit range (directly or inside a
    list argument), or is larger than 64MB.

    ``reliability_guard`` is installed (with a ``PERF_RAM_GB_PER_PROC`` memory
    cap) before any given code runs. ``main`` calls this in a fresh worker
    process each time. The ``os``/``shutil`` functions needed for temp-dir
    cleanup are restored before returning.

    Returns:
        ``(inputs, well_defined_exit)``. ``inputs`` holds at most one argument
        tuple: the one from the largest scale that the reference function
        finished on within ``timeout_second``. ``well_defined_exit`` is False
        when the generator could not be loaded or raised an unexpected
        exception, or the reference function did.

    Exceptions raised while executing ``ref_code_with_contract`` itself are
    not caught.
    """
    # These system calls are needed when cleaning up tempdir.
    import os
    import shutil

    rmtree = shutil.rmtree
    rmdir = os.rmdir
    chdir = os.chdir

    # Disable functionalities that can make destructive changes to the test.
    # Limit memory usages.
    maximum_memory_bytes = PERF_RAM_GB_PER_PROC * 1024 * 1024 * 1024
    reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

    def out_of_int64(value: Any) -> bool:
        return isinstance(value, int) and not (-(2**63) <= value < 2**63)

    try:
        exec_globals = {}

        # eval the func def with contract
        exec(ref_code_with_contract, exec_globals)
        fn = exec_globals[entry_point]

        # eval the generator
        generator_code = "from typing import *\n" + generator_code
        try:
            exec(generator_code, exec_globals)
            generator = exec_globals["perf_input_gen"]
        except Exception:
            print(colored(f"⚠️ [GEN EVAL] Exception ~ {entry_point}:", "red"))
            print(colored(format_exc(), "red"))
            return [], False

        well_defined_exit = True
        return_inputs = []

        for fac in range(1, 27):
            scale = 2**fac
            print(f"[INPUT GEN] scale=2**{fac}")
            try:
                with time_limit(timeout_second):
                    test_input = generator(scale)
                if not isinstance(test_input, tuple):
                    test_input = (test_input,)
                # integers (also those inside any list argument) should stay
                # in the range of 64-bit; also safe for zero-argument inputs
                if any(
                    out_of_int64(arg)
                    or (isinstance(arg, list) and any(out_of_int64(v) for v in arg))
                    for arg in test_input
                ):
                    print(colored("[INPUT GEN] Int overflow against 64bit", "yellow"))
                    break
                # stop here if the input is larger than 64MB.
                INPUT_LIMIT_MB = 64
                if asizeof(test_input) > 1024 * 1024 * INPUT_LIMIT_MB:
                    print(colored(f"[INPUT GEN] Size > {INPUT_LIMIT_MB}MB", "yellow"))
                    break
            except TimeoutException:
                print(colored(f"[INPUT GEN] TimeoutException at scale=2**{fac}", "yellow"))
                break
            except MemoryError:
                print(colored(f"[INPUT GEN] MemoryError at scale=2**{fac}", "yellow"))
                break
            except Exception:
                print(colored(f"⚠️ [INPUT GEN] Exception at scale=2**{fac}", "red"))
                print(colored(format_exc(), "red"))
                well_defined_exit = False
                break

            try:
                with time_limit(timeout_second):
                    # deepcopy in case fn modifies the input
                    fn(*deepcopy(test_input))
                    return_inputs = [test_input]  # only keep one input
            except TimeoutException:
                print(colored(f"[Testing] Timeout at scale=2**{fac}", "yellow"))
                break
            except MemoryError:
                print(colored(f"[Testing] MemoryError at scale=2**{fac}", "yellow"))
                break
            except Exception:
                print(colored(f"⚠️ [Testing] Exception ~ {entry_point}", "red"))
                print(colored(format_exc(), "red"))
                well_defined_exit = False
                break

        return return_inputs, well_defined_exit
    finally:
        # Needed for cleaning up -- on every exit path, including the early
        # generator-load failure and exceptions escaping this function.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
|
|
|
|
|
|
def main(input: str, output: str):
    """In the synthesizer file, each line includes a set of input generators for a task.

    The goal of this script is to use these generators to sample inputs for each task.
    The generated inputs are expected to be valid.

    Args:
        input: JSONL file. Each record has ``task_id``, ``ref_code`` and a
            ``synthesizers`` list of raw LLM responses. Each response is parsed
            with ``post_process`` and deduplicated (blank lines and column-0
            ``#`` comments ignored) before use.
        output: ``.jsonl`` file that ``{"task_id", "inputs"}`` records are
            appended to. Tasks already present in it are skipped (resume mode).
    """
    assert output.endswith(".jsonl"), "output must be a .jsonl file"

    id2task = {}
    for dataset in (get_human_eval_plus(), get_mbpp_plus()):
        for task_id, task in dataset.items():
            id2task[task_id] = task

    # loading the synthesizers (skipping blank lines, which json cannot parse)
    with open(input, "r", encoding="utf-8") as f:
        synthesizers = [json.loads(line) for line in f if line.strip()]

    n_total = n_parsed = n_dedup = 0
    for item in synthesizers:
        parsed = [post_process(s) for s in item["synthesizers"]]
        n_total += len(parsed)
        parsed = [s for s in parsed if s is not None]
        n_parsed += len(parsed)

        # Dedup on the code without blank lines and column-0 comments.
        # dict.fromkeys keeps first-seen order; a set would make the order in
        # which generators are tried depend on per-process hash randomization.
        item["synthesizers"] = list(
            dict.fromkeys(
                "\n".join(
                    l for l in s.splitlines() if l.strip() and not l.startswith("#")
                )
                for s in parsed
            )
        )
        n_dedup += len(item["synthesizers"])

    # Guard the ratios: an empty input file (or nothing parsed) must not crash.
    syntax_err_pct = 100 * (1 - n_parsed / n_total) if n_total else 0.0
    dup_pct = 100 * (1 - n_dedup / n_parsed) if n_parsed else 0.0
    print(
        colored(
            f"#Total {n_total} with {n_parsed} parsed => {syntax_err_pct:.1f}% syntax err",
            "green",
        )
    )
    print(
        colored(
            f"#Parsed {n_parsed} with {n_dedup} dedup => {dup_pct:.1f}% duplicate",
            "green",
        )
    )

    # resume mode check finished tasks
    finished_tasks = set()
    if os.path.isfile(output):
        with open(output, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    finished_tasks.add(json.loads(line)["task_id"])

    print("Resumed finished tasks:", finished_tasks)
    with open(output, "ab+") as f:
        with progress() as p:
            for item in p.track(synthesizers):
                task_id = item["task_id"]
                entry_point = id2task[task_id]["entry_point"]
                if task_id in finished_tasks:
                    p.console.print(f"{task_id}: {entry_point} ~ Resumed")
                    continue

                ref_code_with_contract = insert_contract(
                    entry_point, item["ref_code"], id2task[task_id]["contract"]
                )
                p.console.print(f"{task_id}: PE input generation...")
                p.console.print(Syntax(ref_code_with_contract.strip(), "python"))

                results = []
                for i, generator_code in enumerate(item["synthesizers"]):
                    p.console.print(
                        f"Using generator {i+1}/{len(item['synthesizers'])}:"
                    )
                    p.console.print(Syntax(generator_code, "python"))
                    args = (
                        ref_code_with_contract,
                        entry_point,
                        generator_code,
                    )
                    # A fresh single-worker pool per generator, presumably so
                    # that reliability_guard's process-wide patches applied in
                    # sample_one_input never leak between generators.
                    try:
                        with ProcessPoolExecutor(max_workers=1) as executor:
                            tmp_results, status = executor.submit(
                                sample_one_input, *args
                            ).result()
                    except BrokenProcessPool:
                        # The worker died abruptly (e.g. killed on a resource
                        # limit): treat it like a generator that failed.
                        p.console.print(
                            f"⚠️ Generator {i+1} crashed its worker process"
                        )
                        tmp_results, status = [], False

                    results.extend(tmp_results)

                    # if the func returns in a well-defined manner, we can stop here.
                    if status:
                        break

                p.console.print("Serializing and storing results...")

                if "Mbpp/" in task_id:
                    results = mbpp_serialize_inputs(task_id, results)

                # task_id => list of inputs
                to_write = {"task_id": item["task_id"], "inputs": results}
                f.write((json.dumps(to_write) + "\n").encode("utf-8"))
                f.flush()
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose `main` as a command line interface (arguments: input, output).
    from fire import Fire

    Fire(main)
|