"""Legacy version of post-processing LLM-generated Python code. This sanitizer is implemented using regex and string manipulation. You might want to use the latest tree-sitter-based sanitizer (evalplus.sanitize) instead. """ import os import pathlib import re from typing import List, Optional from tqdm import tqdm from opencompass.datasets.evalplus.data import ( get_human_eval_plus, get_mbpp_plus, load_solutions, write_directory, write_jsonl, ) from opencompass.datasets.evalplus.syncheck import syntax_check def remove_unindented_lines( code: str, protect_before: str, execeptions: List[str], trim_tails: List[str] ) -> str: lines = code.splitlines() cut_idx = [] cut_enabled = False for i, line in enumerate(lines): if not cut_enabled and line.startswith(protect_before): cut_enabled = True continue if line.strip() == "": continue if any(line.startswith(e) for e in execeptions): continue lspace = len(line) - len(line.lstrip()) if lspace == 0: cut_idx.append(i) if any(line.rstrip().startswith(t) for t in trim_tails): # cut off everything behind cut_idx.extend(list(range(i, len(lines)))) break return "\n".join([line for i, line in enumerate(lines) if i not in cut_idx]) def to_four_space_indents(old_code): new_code = "" for line in old_code.splitlines(): lspace = len(line) - len(line.lstrip()) if lspace == 3: new_code += " " new_code += line + "\n" return new_code def sanitize( old_code: str, entry_point: str, rm_prefix_lines: Optional[str] = None, eofs: List = None, ): new_code = old_code if rm_prefix_lines is not None: new_code = "\n".join( [ line for line in old_code.splitlines() if not line.startswith(rm_prefix_lines) ] ) new_code = "\n" + new_code def_left = "def " + entry_point # basic handling of chat output new_code = new_code.replace("\n```python\n", "\n```\n") for chunk in new_code.split("\n```\n"): if def_left in chunk: new_code = chunk break chunks = [chunk for chunk in re.split(f"{def_left}\\s*\\(", new_code)] # TODO: having return does not mean this is complete bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]] def_left = def_left + "(" new_code = def_left + def_left.join(bodies) if len(bodies) > 0 else "" # fn + impl new_code = to_four_space_indents(new_code) for eof in eofs or []: new_code = new_code.split(eof)[0] # remove lines starting from the first unindented line after def_left new_code = remove_unindented_lines( new_code, protect_before=def_left, execeptions=["def ", "import ", "from "], trim_tails=['"""', "if", "print"], ) new_code = chunks[0] + new_code # cut all functions that are not syntactically correct && not the entry point parts = new_code.split("\ndef ") includes = [parts[0]] for fn in new_code.split("\ndef ")[1:]: if ( fn.strip().startswith(entry_point + " ") or fn.strip().startswith(entry_point + "(") or syntax_check("\ndef " + fn) ): includes.append(fn) new_code = "\ndef ".join(includes) return new_code.strip() def script( samples: str, eofs: List[str] = [], inplace: bool = False, rm_prefix_lines: str = None, debug_task: str = None, mbpp_version: str = "default", ): # task_id -> entry_point entry_point = {} dataset = {**get_human_eval_plus(), **get_mbpp_plus(version=mbpp_version)} for task_id, problem in dataset.items(): entry_point[task_id] = problem["entry_point"] # make a new folder with "-sanitized" suffix is_folder = os.path.isdir(samples) target_path = pathlib.Path(samples) if not inplace: if is_folder: new_name = target_path.name + "-sanitized" else: new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl") target_path = target_path.parent / new_name target_path = str(target_path) nsan = 0 ntotal = 0 new_solutions = [] for solution in tqdm(load_solutions(samples)): task_id = solution["task_id"] dbg_identifier = solution["_identifier"] if debug_task is not None and task_id != debug_task: continue ntotal += 1 if "solution" in solution: old_code = solution["solution"] else: assert "completion" in solution old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"] old_code = old_code.strip() new_code = sanitize( old_code=old_code, entry_point=entry_point[task_id], rm_prefix_lines=rm_prefix_lines, eofs=eofs, ).strip() # if changed, print the message if new_code != old_code: msg = "Sanitized: " + dbg_identifier if is_folder: msg += " -> " + dbg_identifier.replace(samples, target_path) print(msg) nsan += 1 new_solutions.append({"task_id": task_id, "solution": new_code}) if is_folder: write_directory(target_path, new_solutions) else: write_jsonl(target_path, new_solutions) if nsan > 0: print(f"Sanitized {nsan} out of {ntotal} files.") else: print(f"All files seems valid -- no files are sanitized.") print(f"Check the sanitized files at {target_path}") def main(): from fire import Fire Fire(script) if __name__ == "__main__": main()