"""Post-processing LLM-generated Python code implemented using tree-sitter.""" import os import pathlib from typing import Dict, Generator, List, Optional, Set, Tuple import tree_sitter_python from tqdm import tqdm from tree_sitter import Language, Node, Parser from opencompass.datasets.evalplus.data import ( get_human_eval_plus, get_mbpp_plus, load_solutions, write_directory, write_jsonl, ) from opencompass.datasets.evalplus.syncheck import syntax_check CLASS_TYPE = "class_definition" FUNCTION_TYPE = "function_definition" IMPORT_TYPE = ["import_statement", "import_from_statement"] IDENTIFIER_TYPE = "identifier" ATTRIBUTE_TYPE = "attribute" RETURN_TYPE = "return_statement" EXPRESSION_TYPE = "expression_statement" ASSIGNMENT_TYPE = "assignment" def code_extract(text: str) -> str: lines = text.split("\n") longest_line_pair = (0, 0) longest_so_far = 0 for i in range(len(lines)): for j in range(i + 1, len(lines)): current_lines = "\n".join(lines[i : j + 1]) if syntax_check(current_lines): current_length = sum(1 for line in lines[i : j + 1] if line.strip()) if current_length > longest_so_far: longest_so_far = current_length longest_line_pair = (i, j) return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: def dfs_get_deps(node: Node, deps: Set[str]) -> None: for child in node.children: if child.type == IDENTIFIER_TYPE: deps.add(child.text.decode("utf8")) else: dfs_get_deps(child, deps) name2deps = {} for name, node in nodes: deps = set() dfs_get_deps(node, deps) name2deps[name] = deps return name2deps def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: queue = [entrypoint] visited = {entrypoint} while queue: current = queue.pop(0) if current not in call_graph: continue for neighbour in call_graph[current]: if not (neighbour in visited): visited.add(neighbour) queue.append(neighbour) return visited def get_definition_name(node: Node) -> str: for child in node.children: if child.type == IDENTIFIER_TYPE: return child.text.decode("utf8") def traverse_tree(node: Node) -> Generator[Node, None, None]: cursor = node.walk() depth = 0 visited_children = False while True: if not visited_children: yield cursor.node if not cursor.goto_first_child(): depth += 1 visited_children = True elif cursor.goto_next_sibling(): visited_children = False elif not cursor.goto_parent() or depth == 0: break else: depth -= 1 def has_return_statement(node: Node) -> bool: traverse_nodes = traverse_tree(node) for node in traverse_nodes: if node.type == RETURN_TYPE: return True return False def extract_target_code_or_empty(code: str, entrypoint: Optional[str] = None) -> str: code = code_extract(code) code_bytes = bytes(code, "utf8") parser = Parser(Language(tree_sitter_python.language())) tree = parser.parse(code_bytes) class_names = set() function_names = set() variable_names = set() root_node = tree.root_node import_nodes = [] definition_nodes = [] for child in root_node.children: if child.type in IMPORT_TYPE: import_nodes.append(child) elif child.type == CLASS_TYPE: name = get_definition_name(child) if not ( name in class_names or name in variable_names or name in function_names ): definition_nodes.append((name, child)) class_names.add(name) elif child.type == FUNCTION_TYPE: name = get_definition_name(child) if not ( name in function_names or name in variable_names or name in class_names ) and has_return_statement(child): definition_nodes.append((name, child)) function_names.add(get_definition_name(child)) elif ( child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE ): subchild = child.children[0] name = get_definition_name(subchild) if not ( name in variable_names or name in function_names or name in class_names ): definition_nodes.append((name, subchild)) variable_names.add(name) if entrypoint: name2deps = get_deps(definition_nodes) reacheable = get_function_dependency(entrypoint, name2deps) sanitized_output = b"" for node in import_nodes: sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" for pair in definition_nodes: name, node = pair if entrypoint and not (name in reacheable): continue sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" return sanitized_output[:-1].decode("utf8") def sanitize(code: str, entrypoint: Optional[str] = None) -> str: sanitized_code = extract_target_code_or_empty(code, entrypoint).strip() if not sanitized_code: return code_extract(code) return sanitized_code def script( samples: str, inplace: bool = False, debug_task: str = None, mbpp_version="default" ): # task_id -> entry_point entry_point = {} # merge two datasets dataset = {**get_human_eval_plus(), **get_mbpp_plus(version=mbpp_version)} for task_id, problem in dataset.items(): entry_point[task_id] = problem["entry_point"] # make a new folder with "-sanitized" suffix is_folder = os.path.isdir(samples) target_path = pathlib.Path(samples) if not inplace: if is_folder: new_name = target_path.name + "-sanitized" else: new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl") target_path = target_path.parent / new_name target_path = str(target_path) nsan = 0 ntotal = 0 new_solutions = [] for solution in tqdm(load_solutions(samples)): task_id = solution["task_id"] if task_id not in dataset: print( f"Skiping {task_id} as it does not existing in the latest EvalPlus dataset." ) continue function_name = entry_point[task_id] if task_id in entry_point else None dbg_identifier = solution["_identifier"] if debug_task is not None and task_id != debug_task: continue ntotal += 1 if "solution" in solution: old_code = solution["solution"] else: assert "completion" in solution old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"] new_code = sanitize(code=old_code, entrypoint=function_name) # if changed, print the message if new_code != old_code: msg = "Sanitized: " + dbg_identifier if is_folder: msg += " -> " + dbg_identifier.replace(samples, target_path) print(msg) nsan += 1 new_solutions.append({"task_id": task_id, "solution": new_code}) if is_folder: write_directory(target_path, new_solutions) else: write_jsonl(target_path, new_solutions) if nsan > 0: print(f"Sanitized {nsan} out of {ntotal} files.") else: print(f"All files seems valid -- no files are sanitized.") print(f"Check the sanitized files at {target_path}") def main(): from fire import Fire Fire(script) if __name__ == "__main__": main()