import difflib
import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# currently supported languages
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java',
    'jl', 'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs',
    'scala', 'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]


@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):

    @staticmethod
    def load(path: str,
             language: str,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load dataset for pass@k evaluation.

        Args:
            path (str): The path to the dataset.
            language (str): The language of the dataset.
            tag (str): The tag of the dataset, either 'humaneval' or 'mbpp'.
            local_mode (bool): Whether to load the dataset in local mode.

        Returns:
            Dataset: A HuggingFace dataset.
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                dataset.append(json.loads(line.strip()))
        return Dataset.from_list(dataset)


class MultiplEEvaluator(CodeEvaluator):

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produce the prefix of decoded_string that ends at the first
        occurrence of a stop token.

        WARNING: the decoded_string *must not* include the prompt, which may
        itself contain stop tokens.

        Args:
            decoded_string (str): A string generated by the model.
            stop_tokens (list): A list of strings, each of which is a stop
                token.

        Returns:
            str: The decoded string, truncated at the first occurrence of a
                stop token.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Find the line in the completion that matches the last line of the
        prompt, remove everything up to and including that line, and return
        the remainder. This converts chatbot-style output, which typically
        echoes the prompt, into completion-style output.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Line similarity threshold.

        Returns:
            str: The completion string after removing the prefix.
        """
        prompt_lines = prompt.splitlines()
        completion_lines = completion.splitlines()

        if not prompt_lines:
            return completion

        last_prompt_line = prompt_lines[-1]
        cut_index = -1

        for i, completion_line in enumerate(completion_lines):
            similarity = difflib.SequenceMatcher(None, last_prompt_line,
                                                 completion_line).ratio()
            if similarity >= threshold:
                cut_index = i
                break

        if cut_index != -1:
            return '\n'.join(completion_lines[cut_index + 1:])
        else:
            return completion
    def _process_completions(self, test_case, completion):
        """Process a completion against its test case: extract the code
        block, strip any echoed prompt prefix, and truncate at the first
        stop token.

        Args:
            test_case (dict): A test case containing the prompt and stop
                tokens.
            completion (str): The generated code completion.

        Returns:
            str: The processed code completion.
        """
        post_comp = self._extract_code(completion)
        post_comp = self._remove_prefix(test_case['prompt'], post_comp)
        post_comp = self._stop_at_stop_token(post_comp,
                                             test_case['stop_tokens'])
        return post_comp
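

# --- Usage sketch (illustrative only, not part of the library) ---
# A minimal demonstration of the post-processing helpers above on a made-up
# prompt/completion pair; the `add`/`main` snippets and the '\ndef ' stop
# token are assumptions for illustration. The evaluator is created via
# `__new__` to sidestep `CodeEvaluator.__init__`, whose signature is not
# shown in this file; the two helpers touch no instance state, so this is
# safe for a demo.
if __name__ == '__main__':
    evaluator = MultiplEEvaluator.__new__(MultiplEEvaluator)
    prompt = 'def add(a, b):\n    """Return the sum of a and b."""\n'
    completion = ('def add(a, b):\n'
                  '    """Return the sum of a and b."""\n'
                  '    return a + b\n'
                  '\n'
                  'def main():\n'
                  '    pass\n')
    # Drop the echoed prompt, then truncate at the first stop token.
    body = evaluator._remove_prefix(prompt, completion)
    body = evaluator._stop_at_stop_token(body, ['\ndef '])
    print(repr(body))  # -> '    return a + b\n'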