import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# Currently supported languages
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
    'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
    'sh', 'swift', 'ts'
]
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]
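
# Each (tag, language) pair maps to a single JSONL file named
# f'{tag}-{language}.jsonl', e.g. 'humaneval-rs.jsonl' or 'mbpp-lua.jsonl'
# (see MultiplEDataset.load below).
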

@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):
    """Dataset loader for MultiPL-E, the multi-language translation of
    HumanEval and MBPP."""

    @staticmethod
    def load(path: str,
             language: str,
             num_repeats: int = 1,
             tag: str = 'humaneval',
             local_mode: bool = False):
"""Load humaneval dataset for pass k mode.
|
|
|
|
Note that you can use num_repeats > 1 when your model does not support
|
|
`num_return_sequence` in generation, otherwise use the raw
|
|
humaneval dataset and set `num_return_sequence` in model config to
|
|
generate multiple responses for testing pass@k>1.
|
|
|
|
It better to change your dataset abbr correspondingly if you want to
|
|
change num_repeats>1, otherwise the number in
|
|
`.cache/dataset_size.json` might be inconsistent.
|
|
|
|
Args:
|
|
num_repeats(int): Number of repetition for this dataset to get
|
|
multiple responses in special cases.
|
|
"""
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Duplicate every sample `num_repeats` times so that pass@k
                # can be estimated from independent generations.
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)
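
    # A minimal usage sketch (illustrative only; the data path below is
    # hypothetical and depends on where the MultiPL-E files live):
    #
    #   ds = MultiplEDataset.load(path='./data/multipl_e', language='cpp',
    #                             num_repeats=5, tag='humaneval',
    #                             local_mode=True)
    #   # -> every humaneval-cpp sample is repeated 5 times, so pass@5 can be
    #   #    estimated from 5 independent generations per problem.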


class MultiplEEvaluator(CodeEvaluator):
    """Evaluator for MultiPL-E that post-processes model completions before
    they are evaluated by `CodeEvaluator`."""

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produce the prefix of `decoded_string` that ends at the first
        occurrence of any stop token.

        WARNING: `decoded_string` *must not* include the prompt,
        which may contain stop tokens itself.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]
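
    # Illustrative example (hypothetical values): with
    # stop_tokens = ['\ndef', '\n#'], the completion
    # '    return a + b\ndef helper():' is truncated to '    return a + b'.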

    def _process_completions(self, test_case, completions):
        """Post-process each raw completion: extract the code, strip the
        prompt prefix and truncate at the stop tokens of the test case."""
        processed_completions = []
        for comp in completions:
            comp = self._extract_code(comp)
            post_comp = self._remove_prefix(test_case['prompt'], comp)
            post_comp = self._stop_at_stop_token(post_comp,
                                                 test_case['stop_tokens'])
            processed_completions.append(post_comp)
        return processed_completions
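
# A rough sketch of the post-processing flow for one raw model output
# (illustrative; _extract_code and _remove_prefix are inherited from
# CodeEvaluator):
#
#   raw completion -> _extract_code        (keep only the code block)
#                  -> _remove_prefix       (drop the echoed prompt, if any)
#                  -> _stop_at_stop_token  (cut at the first stop token)
#
# The surviving prefix is what the evaluator then checks against the
# test case for the chosen language.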