OpenCompass/opencompass/datasets/evalplus/lecacy_sanitize.py
2025-02-19 04:46:42 +01:00

201 lines
5.9 KiB
Python

"""Legacy version of post-processing LLM-generated Python code.
This sanitizer is implemented using regex and string manipulation.
You might want to use the latest tree-sitter-based sanitizer (evalplus.sanitize) instead.
"""
import os
import pathlib
import re
from typing import List, Optional
from tqdm import tqdm
from opencompass.datasets.evalplus.data import (
get_human_eval_plus,
get_mbpp_plus,
load_solutions,
write_directory,
write_jsonl,
)
from opencompass.datasets.evalplus.syncheck import syntax_check
def remove_unindented_lines(
    code: str, protect_before: str, execeptions: List[str], trim_tails: List[str]
) -> str:
    """Drop unindented lines from ``code`` and truncate at the first tail marker.

    The lines of ``code`` are scanned in order:

    * The first line that starts with ``protect_before`` is always kept.
    * Blank lines and lines starting with any prefix in ``execeptions`` are
      kept and are not examined any further.
    * Any other line without leading whitespace is removed.
    * As soon as a (non-exempt) line starts with one of ``trim_tails`` after
      right-stripping, that line and everything after it is removed.

    NOTE: despite its name, ``protect_before`` only exempts the marker line
    itself; unindented lines that appear *before* the marker are removed by
    the same rules as lines after it.

    Args:
        code: Source text to filter.
        protect_before: Prefix of the line that is kept unconditionally
            (first occurrence only).
        execeptions: Line prefixes that are never removed.
        trim_tails: Line prefixes that cut off the rest of the text.

    Returns:
        The surviving lines joined with ``"\\n"`` (no trailing newline).
    """
    lines = code.splitlines()
    exempt_prefixes = tuple(execeptions)
    tail_prefixes = tuple(trim_tails)
    # A set keeps the final membership test O(1) per line instead of O(n).
    cut_idx = set()
    cut_enabled = False
    for i, line in enumerate(lines):
        if not cut_enabled and line.startswith(protect_before):
            cut_enabled = True
            continue
        if line.strip() == "":
            continue
        if line.startswith(exempt_prefixes):
            continue

        lspace = len(line) - len(line.lstrip())
        if lspace == 0:
            cut_idx.add(i)

        if line.rstrip().startswith(tail_prefixes):
            # cut off everything behind
            cut_idx.update(range(i, len(lines)))
            break

    return "\n".join(line for i, line in enumerate(lines) if i not in cut_idx)
def to_four_space_indents(old_code):
    """Widen every three-space-indented line of ``old_code`` by one space.

    Lines whose leading whitespace is exactly three characters long get one
    extra space prepended; all other lines are left as they are. Every line
    of the result is terminated with ``"\\n"``, so non-empty input always
    yields text ending in a newline and empty input yields ``""``.
    """

    def _widen(row):
        # Leading-whitespace width = total length minus left-stripped length.
        indent_width = len(row) - len(row.lstrip())
        return " " + row if indent_width == 3 else row

    return "".join(_widen(row) + "\n" for row in old_code.splitlines())
def sanitize(
    old_code: str,
    entry_point: str,
    rm_prefix_lines: Optional[str] = None,
    eofs: Optional[List[str]] = None,
) -> str:
    """Extract the ``entry_point`` function from raw LLM output via string heuristics.

    Steps, in order:

    1. Optionally drop every line starting with ``rm_prefix_lines``.
    2. If the text contains fenced code blocks, keep the first fenced chunk
       that mentions ``def <entry_point>``.
    3. Split on ``def <entry_point>(`` and keep only the pieces whose own
       body (text up to the next ``"\\ndef"``) contains a return statement.
    4. Normalize three-space indents, cut at each marker in ``eofs``, and
       strip trailing unindented code.
    5. Re-attach the text that preceded the first definition, then drop any
       other function that is not the entry point and fails ``syntax_check``.

    Args:
        old_code: Raw model output (or prompt + completion).
        entry_point: Name of the function that must be kept.
        rm_prefix_lines: If given, lines starting with this prefix are removed.
        eofs: Optional markers; the code is truncated at the first occurrence
            of each one.

    Returns:
        The sanitized source, stripped of surrounding whitespace. May be
        empty when no usable definition of ``entry_point`` is found.
    """
    new_code = old_code
    if rm_prefix_lines is not None:
        new_code = "\n".join(
            line
            for line in old_code.splitlines()
            if not line.startswith(rm_prefix_lines)
        )

    # A leading newline lets a fence on the very first line match the
    # "\n```...\n" patterns used below.
    new_code = "\n" + new_code
    def_left = "def " + entry_point

    # basic handling of chat output
    new_code = new_code.replace("\n```python\n", "\n```\n")
    for chunk in new_code.split("\n```\n"):
        if def_left in chunk:
            new_code = chunk
            break

    # Escape the name so it is always matched literally inside the pattern.
    chunks = re.split(rf"def {re.escape(entry_point)}\s*\(", new_code)
    # TODO: having return does not mean this is complete
    bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
    def_left = def_left + "("
    new_code = def_left + def_left.join(bodies) if bodies else ""  # fn + impl
    new_code = to_four_space_indents(new_code)

    for eof in eofs or []:
        new_code = new_code.split(eof)[0]

    # remove lines starting from the first unindented line after def_left
    new_code = remove_unindented_lines(
        new_code,
        protect_before=def_left,
        execeptions=["def ", "import ", "from "],
        trim_tails=['"""', "if", "print"],
    )
    # chunks[0] is whatever preceded the first entry-point definition
    # (imports, helper functions, ...).
    new_code = chunks[0] + new_code

    # cut all functions that are not syntactically correct && not the entry point
    parts = new_code.split("\ndef ")
    includes = [parts[0]]
    entry_heads = (entry_point + " ", entry_point + "(")
    for fn in parts[1:]:
        # The entry point is always kept; other functions are kept only if
        # syntax_check accepts them (presumably: they parse as valid Python).
        if fn.strip().startswith(entry_heads) or syntax_check("\ndef " + fn):
            includes.append(fn)
    new_code = "\ndef ".join(includes)
    return new_code.strip()
def script(
    samples: str,
    eofs: Optional[List[str]] = None,
    inplace: bool = False,
    rm_prefix_lines: Optional[str] = None,
    debug_task: Optional[str] = None,
    mbpp_version: str = "default",
):
    """Sanitize every solution found at ``samples`` and write the results.

    Args:
        samples: Path to a directory of solutions or to a ``.jsonl`` file.
        eofs: Optional markers forwarded to ``sanitize``; code is truncated
            at each one. A single string is treated as one marker.
        inplace: When true, results are written back to ``samples``;
            otherwise to a sibling path carrying a ``-sanitized`` suffix.
        rm_prefix_lines: Forwarded to ``sanitize``; lines starting with this
            prefix are removed before processing.
        debug_task: If given, only the solution(s) with this task id are
            processed and written.
        mbpp_version: Version passed to ``get_mbpp_plus``.
    """
    if eofs is None:
        eofs = []
    elif isinstance(eofs, str):
        # A bare string would otherwise be iterated character by character.
        eofs = [eofs]

    # task_id -> entry_point
    entry_point = {}

    dataset = {**get_human_eval_plus(), **get_mbpp_plus(version=mbpp_version)}

    for task_id, problem in dataset.items():
        entry_point[task_id] = problem["entry_point"]

    # make a new folder with "-sanitized" suffix
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        if is_folder:
            new_name = target_path.name + "-sanitized"
        else:
            new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl")
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        dbg_identifier = solution["_identifier"]
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
        else:
            # Completion-only samples are prefixed with the task prompt.
            assert "completion" in solution
            old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"]

        old_code = old_code.strip()

        new_code = sanitize(
            old_code=old_code,
            entry_point=entry_point[task_id],
            rm_prefix_lines=rm_prefix_lines,
            eofs=eofs,
        ).strip()

        # if changed, print the message
        if new_code != old_code:
            msg = "Sanitized: " + dbg_identifier
            if is_folder:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            print(msg)
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seems valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")
def main():
    """Command-line entry point: expose ``script`` through python-fire."""
    # Imported lazily so that importing this module does not require fire.
    import fire

    fire.Fire(script)
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()