import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# Currently supported languages for each source benchmark. Data files are
# looked up as ``{tag}-{language}.jsonl`` under the dataset path.
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java',
    'jl', 'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs',
    'scala', 'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]


@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):

    @staticmethod
    def load(path: str,
             language: str,
             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load the dataset for pass@k mode.

        Every record of ``{tag}-{language}.jsonl`` is repeated
        ``num_repeats`` times consecutively, so several samples can be
        generated per problem.

        Args:
            path (str): The path to the dataset directory.
            language (str): The language of the dataset. Must be listed in
                ``_HUMANEVAL_LANGUAGE_`` (tag ``'humaneval'``) or
                ``_MBPP_LANGUAGE_`` (tag ``'mbpp'``).
            num_repeats (int): Number of times each record is repeated.
            tag (str): The source benchmark, ``'humaneval'`` or ``'mbpp'``.
            local_mode (bool): Whether to load the dataset in local mode.

        Returns:
            datasets.Dataset: A HuggingFace ``datasets.Dataset`` with one row
            per (record, repetition) pair.

        Raises:
            AssertionError: If ``tag`` or ``language`` is not supported.
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval', 'mbpp'], (
            'tag must be in ["humaneval", "mbpp"]')
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. trailing empty lines);
                # json.loads('') would raise JSONDecodeError.
                if not line:
                    continue
                # Parse each record once and reuse it for every repetition.
                # Dataset.from_list converts the rows into its own columnar
                # storage, so sharing one dict object here is safe.
                record = json.loads(line)
                dataset.extend([record] * num_repeats)
        return Dataset.from_list(dataset)


class MultiplEEvaluator(CodeEvaluator):
    """``CodeEvaluator`` that truncates completions at per-case stop tokens."""

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Return the prefix of ``decoded_string`` before the first stop token.

        WARNING: the decoded_string *must not* include the prompt, which may
        have stop tokens itself.

        Empty stop tokens are ignored. ``str.find('')`` returns 0, so an
        empty token would otherwise discard the whole completion.

        Args:
            decoded_string: A string generated by the model.
            stop_tokens: A list of strings, where each string is a stop token.

        Returns:
            The decoded_string, truncated at the earliest occurrence of any
            stop token, or unchanged if no stop token occurs.
        """
        found = (decoded_string.find(token) for token in stop_tokens if token)
        cut = min((idx for idx in found if idx != -1),
                  default=len(decoded_string))
        return decoded_string[:cut]

    def _process_completions(self, test_case, completions):
        """Post-process raw completions for a single test case.

        For each completion: extract the code, remove the prompt prefix,
        then truncate at the test case's stop tokens.

        Args:
            test_case: A test case; its ``'prompt'`` and ``'stop_tokens'``
                entries are read.
            completions: A list of completions generated by the model.

        Returns:
            A list of processed completions, in the same order as the input.
        """
        processed_completions = []
        for comp in completions:
            # _extract_code / _remove_prefix are not defined in this class;
            # presumably inherited from CodeEvaluator.
            comp = self._extract_code(comp)
            # Remove the prompt before stop-token truncation, because the
            # prompt itself may contain stop tokens.
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            post_comp = self._stop_at_stop_token(post_comp,
                                                 test_case['stop_tokens'])
            processed_completions.append(post_comp)
        return processed_completions