# Mirror of https://github.com/open-compass/opencompass.git
# Synced 2025-05-30 16:03:24 +08:00
# 201 lines, 5.9 KiB, Python
"""Legacy version of post-processing LLM-generated Python code.

This sanitizer is implemented using regex and string manipulation.
You might want to use the latest tree-sitter-based sanitizer (evalplus.sanitize) instead.
"""
|
||
|
|
||
|
import os
|
||
|
import pathlib
|
||
|
import re
|
||
|
from typing import List, Optional
|
||
|
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from opencompass.datasets.evalplus.data import (
|
||
|
get_human_eval_plus,
|
||
|
get_mbpp_plus,
|
||
|
load_solutions,
|
||
|
write_directory,
|
||
|
write_jsonl,
|
||
|
)
|
||
|
from opencompass.datasets.evalplus.syncheck import syntax_check
|
||
|
|
||
|
def remove_unindented_lines(
    code: str, protect_before: str, execeptions: List[str], trim_tails: List[str]
) -> str:
    """Drop unindented lines from ``code`` and truncate it at a tail marker.

    Lines are scanned in order and handled as follows:

    * the first line that starts with ``protect_before`` is kept;
    * blank (whitespace-only) lines are kept;
    * lines that start with any prefix in ``execeptions`` are kept;
    * any other line without leading whitespace is removed;
    * as soon as a scanned line starts with one of ``trim_tails``, that line
      and every line after it are removed and scanning stops.

    Args:
        code: Source text to filter.
        protect_before: Prefix of the single line that is exempt from removal
            the first time it is seen.
        execeptions: Line prefixes that are never removed (the parameter name
            is kept as spelled because callers pass it by keyword).
        trim_tails: Line prefixes that cut off the rest of the text.

    Returns:
        The surviving lines joined with newline characters, without a
        trailing newline.
    """
    lines = code.splitlines()
    # A set keeps the membership test in the final filter O(1) per line.
    cut_idx = set()
    cut_enabled = False
    for i, line in enumerate(lines):
        if not cut_enabled and line.startswith(protect_before):
            # Only the first matching line is protected this way.
            cut_enabled = True
            continue
        if line.strip() == "":
            continue
        if any(line.startswith(e) for e in execeptions):
            continue

        # No leading whitespace means the line sits at top level.
        if len(line) == len(line.lstrip()):
            cut_idx.add(i)

        if any(line.rstrip().startswith(t) for t in trim_tails):
            # cut off everything behind
            cut_idx.update(range(i, len(lines)))
            break

    return "\n".join(line for i, line in enumerate(lines) if i not in cut_idx)
|
||
|
|
||
|
|
||
|
def to_four_space_indents(old_code: str) -> str:
    """Pad lines that have exactly three leading whitespace characters.

    Each such line gets one extra space prepended; every other line is left
    untouched.

    Args:
        old_code: Source text to normalize.

    Returns:
        The processed lines, each terminated by a newline character (so the
        result is empty for empty input and otherwise always ends with a
        newline).
    """
    fixed_lines = []
    for line in old_code.splitlines():
        lspace = len(line) - len(line.lstrip())
        fixed_lines.append(" " + line if lspace == 3 else line)
    # Join once instead of growing a string with += inside the loop.
    return "".join(line + "\n" for line in fixed_lines)
|
||
|
|
||
|
|
||
|
def sanitize(
    old_code: str,
    entry_point: str,
    rm_prefix_lines: Optional[str] = None,
    eofs: Optional[List[str]] = None,
) -> str:
    """Extract the implementation of ``entry_point`` from raw model output.

    The processing steps are, in order:

    1. Optionally drop every line that starts with ``rm_prefix_lines``.
    2. When the text is split by Markdown code fences, keep the first chunk
       that mentions ``def <entry_point>``.
    3. Split the text on every ``def <entry_point>(`` header and keep only the
       pieces that contain a ``return`` before their next ``def``.
    4. Normalize three-space indents, truncate at the first occurrence of each
       marker in ``eofs`` and strip unindented lines.
    5. Put back the text that preceded the first header and drop every other
       top-level function that is neither the entry point nor accepted by
       ``syntax_check``.

    Args:
        old_code: Raw generated code, possibly wrapped in chat-style Markdown.
        entry_point: Name of the function that must be kept.
        rm_prefix_lines: If given, lines starting with this prefix are removed
            before anything else happens.
        eofs: Markers at which the extracted function text is cut off;
            ``None`` means no markers.

    Returns:
        The sanitized code with surrounding whitespace stripped.
    """
    new_code = old_code
    if rm_prefix_lines is not None:
        new_code = "\n".join(
            line
            for line in old_code.splitlines()
            if not line.startswith(rm_prefix_lines)
        )

    # The leading newline lets a fence on the very first line match below.
    new_code = "\n" + new_code
    def_left = "def " + entry_point

    # basic handling of chat output
    new_code = new_code.replace("\n```python\n", "\n```\n")
    for chunk in new_code.split("\n```\n"):
        if def_left in chunk:
            new_code = chunk
            break

    # Escape the name so it is always matched literally inside the pattern.
    chunks = re.split(rf"def {re.escape(entry_point)}\s*\(", new_code)
    # TODO: having return does not mean this is complete
    # NOTE(review): the whitespace inside this literal may have been collapsed
    # by the source this file was copied from -- confirm against upstream.
    bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
    def_left = def_left + "("
    new_code = def_left + def_left.join(bodies) if bodies else ""  # fn + impl
    new_code = to_four_space_indents(new_code)

    for eof in eofs or []:
        new_code = new_code.split(eof)[0]

    # remove lines starting from the first unindented line after def_left
    new_code = remove_unindented_lines(
        new_code,
        protect_before=def_left,
        execeptions=["def ", "import ", "from "],
        trim_tails=['"""', "if", "print"],
    )
    # chunks[0] is whatever came before the first entry-point header.
    new_code = chunks[0] + new_code

    # cut all functions that are not syntactically correct && not the entry point
    parts = new_code.split("\ndef ")
    includes = [parts[0]]
    entry_prefixes = (entry_point + " ", entry_point + "(")
    for fn in parts[1:]:
        # presumably syntax_check returns a truthy value when the snippet
        # parses -- verify against opencompass.datasets.evalplus.syncheck
        if fn.strip().startswith(entry_prefixes) or syntax_check("\ndef " + fn):
            includes.append(fn)
    new_code = "\ndef ".join(includes)
    return new_code.strip()
|
||
|
|
||
|
|
||
|
def script(
    samples: str,
    eofs: Optional[List[str]] = None,
    inplace: bool = False,
    rm_prefix_lines: Optional[str] = None,
    debug_task: Optional[str] = None,
    mbpp_version: str = "default",
):
    """Sanitize every solution found under ``samples`` and write the results.

    Args:
        samples: Path to a directory of solutions or to a ``.jsonl`` file.
        eofs: Markers forwarded to :func:`sanitize`; ``None`` (the default)
            means no markers.
        inplace: When true, results are written back to ``samples``;
            otherwise a sibling path with a ``-sanitized`` suffix is used.
        rm_prefix_lines: Forwarded to :func:`sanitize`.
        debug_task: If given, only the solutions of this task id are
            processed and written.
        mbpp_version: Forwarded to ``get_mbpp_plus`` as ``version``.
    """
    dataset = {**get_human_eval_plus(), **get_mbpp_plus(version=mbpp_version)}
    # task_id -> entry_point
    entry_point = {
        task_id: problem["entry_point"] for task_id, problem in dataset.items()
    }

    # make a new folder with "-sanitized" suffix
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        if is_folder:
            new_name = target_path.name + "-sanitized"
        else:
            new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl")
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        dbg_identifier = solution["_identifier"]
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
        else:
            # A bare completion is evaluated together with the task prompt.
            assert "completion" in solution
            old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"]

        old_code = old_code.strip()

        new_code = sanitize(
            old_code=old_code,
            entry_point=entry_point[task_id],
            rm_prefix_lines=rm_prefix_lines,
            eofs=eofs,
        ).strip()

        # if changed, print the message
        if new_code != old_code:
            msg = "Sanitized: " + dbg_identifier
            if is_folder:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            print(msg)
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seems valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")
|
||
|
|
||
|
|
||
|
def main():
    """Expose :func:`script` as a command-line interface through python-fire."""
    # Imported inside the function, so `fire` is only needed when main() runs.
    import fire

    fire.Fire(script)
|
||
|
|
||
|
|
||
|
# Run the command-line entry point only when this file is executed directly.
if __name__ == "__main__":
    main()
|