OpenCompass/opencompass/datasets/multipl_e.py
Dongsheng Zhu 2c79dc5227
[Dataset] Add human_eval/mbpp pro (#2092)
* add bench

* update

* bug fix

* time update

* add index

* fix repeat bug
2025-05-12 18:38:13 +08:00

139 lines
4.9 KiB
Python

import difflib
import json
import os.path as osp
from datasets import Dataset
from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
from .base import BaseDataset
# Currently supported target languages, as file-name suffixes: each entry
# `lang` selects a `{tag}-{lang}.jsonl` data file in MultiplEDataset.load.
_HUMANEVAL_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
    'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
    'sh', 'swift', 'ts'
]
# Same idea for the mbpp split; note it lacks 'dart' compared to humaneval.
_MBPP_LANGUAGE_ = [
    'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
    'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
    'swift', 'ts'
]
@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):
    """Loader for MultiPL-E benchmark jsonl files (humaneval/mbpp splits)."""

    @staticmethod
    def load(path: str,
             language: str,
             tag: str = 'humaneval',
             local_mode: bool = False) -> Dataset:
        """Load one language split of a MultiPL-E benchmark.

        Args:
            path (str): The path to the dataset root directory.
            language (str): The language suffix of the jsonl file to read;
                must be listed for the chosen ``tag``.
            tag (str): The benchmark family, 'humaneval' or 'mbpp'.
            local_mode (bool): Whether to load the dataset in local mode.

        Returns:
            Dataset: A HuggingFace ``datasets.Dataset`` built from the
            records of ``{tag}-{language}.jsonl``.
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        # One JSON record per line; blank-stripped before decoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset = [json.loads(line.strip()) for line in f]
        return Dataset.from_list(dataset)
class MultiplEEvaluator(CodeEvaluator):
def _stop_at_stop_token(self, decoded_string, stop_tokens):
    """Return the prefix of ``decoded_string`` that ends just before the
    earliest occurrence of any stop token.

    WARNING: the decoded_string *must not* include the prompt,
    which may have stop tokens itself.

    Args:
        decoded_string: A string generated by the model.
        stop_tokens: A list of strings, where each string is a stop token.

    Returns:
        The decoded_string, truncated at the first occurrence of a stop
        token; the full string when no stop token is present.
    """
    # Earliest hit position across all tokens; falls back to the full
    # string length when no token is found (find() returning -1).
    cut = min(
        (pos for pos in (decoded_string.find(tok) for tok in stop_tokens)
         if pos != -1),
        default=len(decoded_string))
    return decoded_string[:cut]
def _remove_prefix(self,
                   prompt: str,
                   completion: str,
                   threshold: float = 0.95) -> str:
    """Strip from ``completion`` everything up to (and including) the line
    that echoes the last line of ``prompt``.

    This converts chatbot-style output, which tends to repeat the prompt,
    back into pure completion-mode output. Matching is fuzzy: the first
    completion line whose similarity to the prompt's final line reaches
    ``threshold`` is treated as the echo.

    Args:
        prompt (str): The prompt text.
        completion (str): The completion text.
        threshold (float): Line similarity threshold in [0, 1].

    Returns:
        str: The completion with the echoed prefix removed, or the
        original completion when no sufficiently similar line exists.
    """
    ref_lines = prompt.splitlines()
    # Nothing to anchor on: an empty prompt cannot be echoed.
    if not ref_lines:
        return completion
    anchor = ref_lines[-1]
    gen_lines = completion.splitlines()
    for pos, line in enumerate(gen_lines):
        ratio = difflib.SequenceMatcher(None, anchor, line).ratio()
        if ratio >= threshold:
            # Drop the echo line itself along with everything before it.
            return '\n'.join(gen_lines[pos + 1:])
    return completion
def _process_completions(self, test_case, completion):
    """Post-process one model completion against its test case.

    Pipeline: extract the code part of the reply, cut any echoed prompt
    prefix, then truncate at the case's stop tokens.

    Args:
        test_case (dict): A test case providing 'prompt' and 'stop_tokens'.
        completion (str): The generated code completion.

    Returns:
        str: Processed code completion.
    """
    code = self._extract_code(completion)
    code = self._remove_prefix(test_case['prompt'], code)
    return self._stop_at_stop_token(code, test_case['stop_tokens'])