OpenCompass/opencompass/datasets/multipl_e.py
2025-03-20 05:56:27 +00:00

92 lines
3.4 KiB
Python

import json
import os.path as osp
from datasets import Dataset
from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
from .base import BaseDataset
# currently supporting languages
_HUMANEVAL_LANGUAGE_ = [
'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
'swift', 'ts'
]
@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):
    """Loader for MultiPL-E ``{tag}-{language}.jsonl`` files, where ``tag``
    is ``'humaneval'`` or ``'mbpp'``."""

    @staticmethod
    def load(path: str,
             language: str,
             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load one MultiPL-E split, optionally repeated for pass@k mode.

        Note that you can use ``num_repeats > 1`` when your model does not
        support ``num_return_sequence`` in generation; otherwise use the raw
        dataset and set ``num_return_sequence`` in the model config to
        generate multiple responses for testing pass@k > 1.

        It is better to change your dataset abbr correspondingly if you
        change ``num_repeats`` to more than 1, otherwise the number in
        ``.cache/dataset_size.json`` might be inconsistent.

        Args:
            path (str): Dataset directory; resolved through
                ``get_data_path`` before use.
            language (str): Language code; must appear in
                ``_HUMANEVAL_LANGUAGE_`` or ``_MBPP_LANGUAGE_`` depending on
                ``tag``.
            num_repeats (int): Number of consecutive copies of each record,
                to get multiple responses in special cases. Defaults to 1.
            tag (str): Either ``'humaneval'`` or ``'mbpp'``. Defaults to
                ``'humaneval'``.
            local_mode (bool): Forwarded to ``get_data_path``.

        Returns:
            Dataset: Built with ``Dataset.from_list`` from the JSON objects
            in ``{path}/{tag}-{language}.jsonl``, each repeated
            ``num_repeats`` times. Blank lines in the file are skipped.

        Raises:
            AssertionError: If ``tag`` or ``language`` is not supported.
        """
        # Validate the configuration before resolving the data path.
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        supported = (_HUMANEVAL_LANGUAGE_
                     if tag == 'humaneval' else _MBPP_LANGUAGE_)
        assert language in supported, f'language must be in {supported}'

        path = get_data_path(path, local_mode=local_mode)
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # A blank line (e.g. a trailing newline at EOF) would make
                # json.loads('') raise; skip it instead.
                if not line:
                    continue
                # Parse each line once rather than once per repeat. The
                # repeated entries reference the same dict, which is only
                # consumed by Dataset.from_list below.
                record = json.loads(line)
                dataset.extend([record] * num_repeats)
        return Dataset.from_list(dataset)
class MultiplEEvaluator(CodeEvaluator):
    """``CodeEvaluator`` with MultiPL-E style completion post-processing.

    Each completion is passed through the inherited ``_extract_code`` and
    ``_remove_prefix`` helpers and then truncated at the earliest stop token
    listed in the test case.
    """

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Return the prefix of ``decoded_string`` that ends at the first
        occurrence of any stop token.

        WARNING: ``decoded_string`` *must not* include the prompt, which may
        contain stop tokens itself.

        Args:
            decoded_string (str): Generated text with the prompt removed.
            stop_tokens (Iterable[str]): Candidate stop sequences. Empty
                strings are ignored: ``str.find('')`` returns 0, so an empty
                token would otherwise erase the entire completion.

        Returns:
            str: ``decoded_string`` up to (not including) the earliest stop
            token, or the whole string if no stop token occurs.
        """
        cut = len(decoded_string)
        for stop_token in stop_tokens:
            if not stop_token:
                continue
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < cut:
                cut = stop_index
        return decoded_string[:cut]

    def _process_completions(self, test_case, completions):
        """Post-process the raw completions generated for one test case.

        Args:
            test_case (dict): Must provide ``'prompt'`` and ``'stop_tokens'``.
            completions (Iterable): Model outputs for this test case
                (presumably strings; each is passed to ``_extract_code``).

        Returns:
            list: One processed completion per input, in input order.
        """
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            # Remove the prompt *before* searching for stop tokens, because
            # the prompt itself may contain them.
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            post_comp = self._stop_at_stop_token(post_comp,
                                                 test_case['stop_tokens'])
            processed_completions.append(post_comp)
        return processed_completions