# Mirror of https://github.com/open-compass/opencompass.git
# Synced 2025-05-30 16:03:24 +08:00
# 139 lines · 4.9 KiB · Python
import difflib
|
|
import json
|
|
import os.path as osp
|
|
|
|
from datasets import Dataset
|
|
|
|
from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
|
|
from opencompass.registry import LOAD_DATASET
|
|
from opencompass.utils import get_data_path
|
|
|
|
from .base import BaseDataset
|
|
|
|
# currently supporting languages
|
|
_HUMANEVAL_LANGUAGE_ = [
|
|
'adb', 'clj', 'cpp', 'cs', 'd', 'dart', 'elixir', 'go', 'hs', 'java', 'jl',
|
|
'js', 'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala',
|
|
'sh', 'swift', 'ts'
|
|
]
|
|
_MBPP_LANGUAGE_ = [
|
|
'adb', 'clj', 'cpp', 'cs', 'd', 'elixir', 'go', 'hs', 'java', 'jl', 'js',
|
|
'lua', 'ml', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh',
|
|
'swift', 'ts'
|
|
]
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):
    """Loader for MultiPL-E translated HumanEval/MBPP code benchmarks."""

    @staticmethod
    def load(path: str,
             language: str,
             tag: str = 'humaneval',
             local_mode: bool = False) -> Dataset:
        """Load one MultiPL-E split as a HuggingFace ``Dataset``.

        Args:
            path (str): Root path of the dataset files.
            language (str): Target language tag, e.g. ``'py'``; must be one
                of the supported languages for the chosen ``tag``.
            tag (str): Source benchmark, ``'humaneval'`` or ``'mbpp'``.
            local_mode (bool): Whether to resolve ``path`` in local mode.

        Returns:
            Dataset: One example per line of ``{tag}-{language}.jsonl``.

        Raises:
            ValueError: If ``tag`` or ``language`` is unsupported.
        """
        path = get_data_path(path, local_mode=local_mode)
        # Explicit raises instead of assert: assertions are stripped when
        # Python runs with -O, which would silently skip validation.
        if tag not in ('humaneval', 'mbpp'):
            raise ValueError('tag must be in ["humaneval", "mbpp"]')
        supported = (_HUMANEVAL_LANGUAGE_
                     if tag == 'humaneval' else _MBPP_LANGUAGE_)
        if language not in supported:
            raise ValueError(f'language must be in {supported}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        # JSONL: one JSON object per line; skip blank lines (e.g. a
        # trailing newline at EOF) that json.loads would reject.
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset = [json.loads(line) for line in f if line.strip()]
        return Dataset.from_list(dataset)
class MultiplEEvaluator(CodeEvaluator):
    """Evaluator that post-processes MultiPL-E completions before scoring."""

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Return the prefix of ``decoded_string`` ending at the earliest
        occurrence of any stop token.

        WARNING: ``decoded_string`` *must not* include the prompt, which
        may contain stop tokens itself.

        Args:
            decoded_string: A string generated by the model.
            stop_tokens: A list of strings, each a stop token.

        Returns:
            ``decoded_string`` truncated at the first stop-token hit, or
            unchanged when no stop token occurs.
        """
        # Collect the positions of all stop tokens that actually occur,
        # then cut at the earliest one (or not at all).
        hits = [
            pos for pos in (decoded_string.find(tok) for tok in stop_tokens)
            if pos != -1
        ]
        cut = min(hits) if hits else len(decoded_string)
        return decoded_string[:cut]

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Strip any echoed prompt prefix from ``completion``.

        Finds the first completion line that closely matches the last line
        of the prompt and drops everything up to and including it. This
        converts chatbot-style output (which often repeats the prompt)
        into pure completion-mode output.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Line-similarity ratio required to treat a
                completion line as the prompt's last line.

        Returns:
            str: The completion with the echoed prefix removed, or the
            original completion when no matching line is found.
        """
        prompt_lines = prompt.splitlines()
        if not prompt_lines:
            return completion

        anchor = prompt_lines[-1]
        completion_lines = completion.splitlines()

        # Scan for the first line similar enough to the prompt's last
        # line; everything after it is the real completion.
        for idx, line in enumerate(completion_lines):
            ratio = difflib.SequenceMatcher(None, anchor, line).ratio()
            if ratio >= threshold:
                return '\n'.join(completion_lines[idx + 1:])
        return completion

    def _process_completions(self, test_case, completion):
        """Post-process one completion for a test case.

        Args:
            test_case (dict): Test case with ``prompt`` and ``stop_tokens``.
            completion (str): The generated code completion.

        Returns:
            str: Extracted code, with any echoed prompt prefix removed and
            truncated at the first stop token.
        """
        code = self._extract_code(completion)
        code = self._remove_prefix(test_case['prompt'], code)
        return self._stop_at_stop_token(code, test_case['stop_tokens'])