OpenCompass/opencompass/datasets/evalplus/perf/select_pe_inputs.py

"""Select the most performance-exercising inputs from pe_inputs obtained from `sampling.py`.
"""
import json
from statistics import median

from tqdm import tqdm

from evalplus.config import PERF_CURATE_TIMEOUT_SECOND
from evalplus.data import get_human_eval_plus, get_mbpp_plus
from evalplus.data.mbpp import mbpp_deserialize_inputs, mbpp_serialize_inputs
from evalplus.perf.profile import are_profiles_broken, profile


def script(solutions: str, output_profiled_solutions: str, pe_inputs: str = None):
    assert solutions.endswith(".jsonl")
    assert pe_inputs is None or pe_inputs.endswith(".jsonl")
    assert output_profiled_solutions.endswith(".jsonl")

    evalplus = get_human_eval_plus(noextreme=True)
    mbppplus = get_mbpp_plus(noextreme=True)
    tasks = {**evalplus, **mbppplus}

    # Assume each line of the pe_inputs file has the format: {
    #     "task_id": <task's id>,
    #     "inputs": <a list of inputs>,
    # }
    inputs_dict = None
    if pe_inputs is not None:
        print("Loading performance-exercising inputs...")
        with open(pe_inputs, "r") as f:
            inputs_dict = {
                task["task_id"]: task["inputs"] for l in f for task in [json.loads(l)]
            }
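        # For example, one line (values illustrative) such as
        #   {"task_id": "HumanEval/0", "inputs": [[[1.0, 2.0, 3.9], 0.3]]}
        # yields the entry {"HumanEval/0": [[[1.0, 2.0, 3.9], 0.3]]}.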

    # Notably, the solutions are already validated and cleaned.
    with open(solutions, "r") as f:
        solutions = {}
        for l in f:
            solution = json.loads(l)
            solutions[solution["task_id"]] = solution["solution"]
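    # `solutions` now maps each task_id to a list of candidate solution
    # strings, all of which are profiled per task below.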

    for task_id, task in tqdm(tasks.items()):
        if inputs_dict:
            inputs = (
                mbpp_deserialize_inputs(task_id, inputs_dict[task_id])
                if "Mbpp/" in task_id
                else inputs_dict[task_id]
            )
        else:
            inputs = task["base_input"] + list(task["plus_input"])

        input_costs = []
        if task_id.startswith("HumanEval"):
            canonical_solution = task["prompt"] + task["canonical_solution"]
        else:
            canonical_solution = task["canonical_solution"]
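        # (HumanEval canonical solutions are function bodies that continue the
        # prompt's signature, so the prompt is prepended above; MBPP canonical
        # solutions are complete programs and can run as-is.)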
        for inp in inputs:
            costs = profile(
                canonical_solution,
                task["entry_point"],
                [inp],
                timeout_second_per_test=PERF_CURATE_TIMEOUT_SECOND,
            )
            if are_profiles_broken(costs):
                continue
            input_costs.append((median(costs), inp))
        input_costs.sort(reverse=True, key=lambda x: x[0])
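        # Try candidate inputs from the most to the least expensive, measured
        # by the median profiling cost on the canonical solution.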
        for _, pe_input in input_costs:
            solution_costs = []
            for solution in solutions[task_id]:
                costs = profile(
                    solution,
                    task["entry_point"],
                    [pe_input],
                    timeout_second_per_test=PERF_CURATE_TIMEOUT_SECOND,
                )
                if not are_profiles_broken(costs):
                    solution_costs.append(costs)
                    continue
                # This input broke (e.g., timed out) on a solution; move on to
                # the next most performance-exercising input.
                break
            # No solution's profile broke on this input, so keep it: it is the
            # most performance-exercising input that every solution can handle.
            if len(solution_costs) == len(solutions[task_id]):
                break

        # If no satisfying input was found, don't save any profiled data.
        # (The `len(input_costs) == 0` check also guards the case where the
        # loop above never ran and `solution_costs` is unbound.)
        if len(input_costs) == 0 or len(solution_costs) != len(solutions[task_id]):
            print(f"Skipping {task_id}...")
            pe_input = None
            solution_costs = None
        else:
            pe_input = (
                mbpp_serialize_inputs(task_id, [pe_input])
                if task_id.startswith("Mbpp/")
                else [pe_input]
            )
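        # Append one JSONL record per task; "pe_input" and "counter_profile"
        # are null when the task was skipped.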
        with open(output_profiled_solutions, "a") as f:
            f.write(
                json.dumps(
                    {
                        "task_id": task_id,
                        "pe_input": pe_input,
                        "solutions": solutions[task_id],
                        "counter_profile": solution_costs,
                    }
                )
                + "\n"
            )


def main():
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()