"""Select the most performance-exercising inputs from pe_inputs obtained from `sampling.py`.
|
|
"""
|
|
|
|
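# Example invocation (the file name below is a placeholder; adjust it to the
# actual script path in the repo). `Fire(script)` exposes the function
# parameters as CLI flags:
#   python select_pe_inputs.py \
#       --solutions solutions.jsonl \
#       --pe_inputs pe_inputs.jsonl \
#       --output_profiled_solutions profiled.jsonl
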
import json
from statistics import median

from tqdm import tqdm

from evalplus.config import PERF_CURATE_TIMEOUT_SECOND
from evalplus.data import get_human_eval_plus, get_mbpp_plus
from evalplus.data.mbpp import mbpp_deserialize_inputs, mbpp_serialize_inputs
from evalplus.perf.profile import are_profiles_broken, profile

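# Each line written to `output_profiled_solutions` is a JSON object of the form:
#   {"task_id": ..., "pe_input": ..., "solutions": ..., "counter_profile": ...}
# where "pe_input" and "counter_profile" are None for tasks that are skipped.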
def script(solutions: str, output_profiled_solutions: str, pe_inputs: str = None):
    assert solutions.endswith(".jsonl")
    assert pe_inputs is None or pe_inputs.endswith(".jsonl")
    assert output_profiled_solutions.endswith(".jsonl")

    evalplus = get_human_eval_plus(noextreme=True)
    mbppplus = get_mbpp_plus(noextreme=True)
    tasks = {**evalplus, **mbppplus}

    # Assume each line of `pe_inputs` has the format:
    # {
    #     "task_id": task's id,
    #     "inputs": a list of inputs,
    # }
    inputs_dict = None

    if pe_inputs is not None:
        print("Loading performance-exercising inputs...")
        with open(pe_inputs, "r") as f:
            inputs_dict = {
                task["task_id"]: task["inputs"] for l in f for task in [json.loads(l)]
            }

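    # Each line of `solutions` is expected to look like:
    #   {"task_id": ..., "solution": [list of solution strings]}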
    # Notably, the solutions are already validated and cleaned.
    with open(solutions, "r") as f:
        solutions = {}
        for l in f:
            solution = json.loads(l)
            solutions[solution["task_id"]] = solution["solution"]

    for task_id, task in tqdm(tasks.items()):
        if inputs_dict:
            inputs = (
                mbpp_deserialize_inputs(task_id, inputs_dict[task_id])
                if "Mbpp/" in task_id
                else inputs_dict[task_id]
            )
        else:
            inputs = task["base_input"] + list(task["plus_input"])

        input_costs = []

if task_id.startswith("HumanEval"):
|
|
canonical_solution = task["prompt"] + task["canonical_solution"]
|
|
else:
|
|
canonical_solution = task["canonical_solution"]
|
|
|
|
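        # Rank candidate inputs by how hard they exercise the canonical solution:
        # profile each input and record the median cost, skipping inputs whose
        # profiles are broken (e.g., exceed PERF_CURATE_TIMEOUT_SECOND).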
        for inp in inputs:
            costs = profile(
                canonical_solution,
                task["entry_point"],
                [inp],
                timeout_second_per_test=PERF_CURATE_TIMEOUT_SECOND,
            )
            if are_profiles_broken(costs):
                continue
            input_costs.append((median(costs), inp))
        input_costs.sort(reverse=True, key=lambda x: x[0])

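        # Walk inputs from most to least costly and accept the first one on which
        # every solution yields a clean (non-broken) profile.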
        for _, pe_input in input_costs:
            solution_costs = []

            for solution in solutions[task_id]:
                costs = profile(
                    solution,
                    task["entry_point"],
                    [pe_input],
                    timeout_second_per_test=PERF_CURATE_TIMEOUT_SECOND,
                )
                if not are_profiles_broken(costs):
                    solution_costs.append(costs)
                    continue

                # A broken profile means this input is too demanding for at least
                # one solution, so give up on it and try the next input.
                break

            # No timeouts happened for this input across all solutions, so use it:
            # being the first accepted one, it is also the most
            # performance-exercising.
            if len(solution_costs) == len(solutions[task_id]):
                break

        # If no satisfying input is found, we don't save any profiled data.
        if len(input_costs) == 0 or len(solution_costs) != len(solutions[task_id]):
            print(f"Skipping {task_id}...")
            pe_input = None
            solution_costs = None
        else:
            pe_input = (
                mbpp_serialize_inputs(task_id, [pe_input])
                if task_id.startswith("Mbpp/")
                else [pe_input]
            )

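        # Append one JSON record for this task (written even when the task is
        # skipped, with pe_input and counter_profile set to None).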
        with open(output_profiled_solutions, "a") as f:
            f.write(
                json.dumps(
                    {
                        "task_id": task_id,
                        "pe_input": pe_input,
                        "solutions": solutions[task_id],
                        "counter_profile": solution_costs,
                    }
                )
                + "\n"
            )


def main():
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()