[feat] support multipl-e (#846)

* [feat] support humaneval_multipl-e

* format

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Connor-Shen 2024-02-06 23:30:28 +08:00 committed by GitHub
parent a6c49f15ce
commit 444d8d9507
3 changed files with 268 additions and 0 deletions


@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .humaneval_multi_gen_82cf85 import humaneval_multi_datasets  # noqa: F401, F403

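For reference, a top-level evaluation config can pull in this shorthand the same way; a minimal sketch (both import paths are hypothetical and depend on where your configs live):

from mmengine.config import read_base

with read_base():
    from .datasets.humaneval_multi.humaneval_multi_gen import humaneval_multi_datasets
    from .models.hf_internlm.hf_internlm_7b import models  # hypothetical model config

datasets = humaneval_multi_datasets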

@@ -0,0 +1,46 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalMultiDataset, HumanevalMultiEvaluator

humaneval_multi_reader_cfg = dict(input_columns=['prompt'], output_column='tests')

humaneval_multi_infer_cfg = dict(
    prompt_template=dict(type=PromptTemplate, template='{prompt}'),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=1024),
)

# refer to https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html
# to launch a code_eval_server
humaneval_multi_eval_cfg = {
    lang: dict(
        evaluator=dict(
            type=HumanevalMultiEvaluator,
            language=lang,
            ip_address='localhost',  # replace with your code_eval_server address
            port=5000,  # and port
        ),
        pred_role='BOT',
    ) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
}
'''There are four versions of humaneval-{LANG}-{version}.jsonl:
['keep', 'transform', 'reworded', 'remove']

SRCDATA-LANG-keep is the same as SRCDATA-LANG, but the text of the prompt is
totally unchanged. If the original prompt had Python doctests, they remain as
Python instead of being translated to LANG. If the original prompt had
Python-specific terminology, e.g., 'list', it remains 'list' instead of being
translated, e.g., to 'vector' for C++.
SRCDATA-LANG-transform transforms the doctests to LANG but leaves the natural
language text of the prompt unchanged.
SRCDATA-LANG-reworded transforms both the doctests and the natural language
text of the prompt to LANG.
SRCDATA-LANG-remove removes the doctests from the prompt.
'''
humaneval_multi_datasets = [
    dict(
        type=HumanevalMultiDataset,
        abbr=f'humaneval_multiple-{lang}',
        language=lang,
        version='reworded',  # choose from ['keep', 'transform', 'reworded', 'remove']
        num_repeats=1,
        path='./data/multi-data/humaneval_multipl-e/',
        reader_cfg=humaneval_multi_reader_cfg,
        infer_cfg=humaneval_multi_infer_cfg,
        eval_cfg=humaneval_multi_eval_cfg[lang],
    ) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
]
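
To evaluate only a subset of languages, the generated list can simply be filtered after the fact; a minimal sketch:

humaneval_multi_datasets = [
    d for d in humaneval_multi_datasets if d['language'] in ('cpp', 'rs')
]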


@@ -0,0 +1,218 @@
import gzip
import json
import os
import os.path as osp
import re
import shutil
import subprocess
import tempfile
import time

import numpy as np
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET

from .base import BaseDataset
# currently supporting 19 languages
_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'cs': 'C#',
    'd': 'D',
    'go': 'Go',
    'java': 'Java',
    'jl': 'Julia',
    'js': 'JavaScript',
    'lua': 'Lua',
    'php': 'PHP',
    'pl': 'Perl',
    'py': 'Python',
    'r': 'R',
    'rb': 'Ruby',
    'rkt': 'Racket',
    'rs': 'Rust',
    'scala': 'Scala',
    'sh': 'Shell',
    'swift': 'Swift',
    'ts': 'TypeScript',
}

@LOAD_DATASET.register_module()
class HumanevalMultiDataset(BaseDataset):

    @staticmethod
    def load(path, language, version, num_repeats: int = 1, **kwargs):
        """Load the HumanEval MultiPL-E dataset for pass@k evaluation.

        Note that you can use num_repeats > 1 when your model does not support
        `num_return_sequence` in generation; otherwise, use the raw dataset and
        set `num_return_sequence` in the model config to generate multiple
        responses for testing pass@k with k > 1.

        It is better to change your dataset abbr accordingly if you set
        num_repeats > 1, otherwise the number in `.cache/dataset_size.json`
        might be inconsistent.

        Args:
            num_repeats(int): Number of repetitions for this dataset to get
                multiple responses in special cases.
        """
        assert language in _LANGUAGE_NAME_DICT.keys(), (
            f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
        assert version in [
            'keep', 'transform', 'reworded', 'remove'
        ], ('version must be in ["keep", "transform", "reworded", "remove"]')
        file_path = osp.join(path, f'humaneval-{language}-{version}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)
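
# A quick sanity check of the loader (data path as configured in the dataset
# config above; adjust to your local layout):
#
#   ds = HumanevalMultiDataset.load(
#       path='./data/multi-data/humaneval_multipl-e/',
#       language='rs', version='reworded', num_repeats=1)
#   print(len(ds), ds[0]['name'])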

class HumanevalMultiEvaluator(BaseEvaluator):

    def __init__(self,
                 language,
                 ip_address='localhost',
                 port=5000,
                 retry=2,
                 timeout=600) -> None:
        self.language = language
        self.ip_address = ip_address
        self.port = port
        self.retry = retry
        self.timeout = timeout
        super().__init__()

    def stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produces the prefix of decoded_string that ends at the first
        occurrence of a stop_token.

        WARNING: the decoded_string *must not* include the prompt, which may
        itself contain stop tokens.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]
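
    # Illustration (stop tokens come from the dataset's `stop_tokens` column;
    # the values below are made up for the example):
    #
    #   stop_at_stop_token('    return x + 1\ndef g():\n    pass',
    #                      ['\ndef ', '\nclass '])
    #   -> '    return x + 1'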

    def _code_eval_service(self, file_path):
        exec_result = subprocess.run([
            'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
            f'dataset=multipl-e/{self.language}',
            f'{self.ip_address}:{self.port}/evaluate'
        ],
                                     timeout=self.timeout,
                                     capture_output=True)
        if exec_result.returncode == 0 and re.match(
                r'"{.*:.*}"', exec_result.stdout.decode('utf-8')):
            return True, json.loads(exec_result.stdout.decode('utf-8'))
        else:
            if exec_result.stderr:
                try:
                    err = exec_result.stderr.decode()
                except Exception:
                    err = exec_result.stderr
            else:
                try:
                    err = exec_result.stdout.decode()
                except Exception:
                    err = exec_result.stdout
            return False, err
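
    # The subprocess call above is equivalent to this shell invocation
    # (host, port and archive path illustrative):
    #
    #   curl -X POST -F 'file=@/tmp/archive.zip' \
    #        -F 'dataset=multipl-e/py' localhost:5000/evaluate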

    def estimator(self, n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
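
    # This is the unbiased pass@k estimator from the Codex paper
    # (Chen et al., 2021), computed in a numerically stable product form.
    # Worked example: with n=10 samples of which c=3 pass,
    #   pass@1  = 1 - C(7,1)/C(10,1) = 1 - 7/10 = 0.3
    #   pass@10 = 1.0, since n - c = 7 < k = 10 guarantees a passing sample.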

    def for_file(self, path):
        try:
            with gzip.open(path, 'rt') as f:
                data = json.load(f)
        except Exception:
            return None
        n = len(data['results'])
        c = len([
            True for r in data['results']
            if r['status'] == 'OK' and r['exit_code'] == 0
        ])
        return {
            'pass@1': self.estimator(n, c, 1),
            'pass@10': self.estimator(n, c, 10),
            'pass@100': self.estimator(n, c, 100),
            'n': n,
            'c': c,
        }
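
    # The per-problem result file read above is a gzipped JSON whose shape,
    # inferred from this reader (other fields may exist), is roughly:
    #
    #   {"results": [{"status": "OK", "exit_code": 0, ...},
    #                {"status": "Exception", "exit_code": 1, ...}]}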

    def score(self, predictions, references, test_set):
        stop_tokens = test_set['stop_tokens'][0]
        print(stop_tokens)

        # convert to the original (deduplicated) version
        test_set = test_set.to_pandas()
        test_set_origin = test_set.drop_duplicates(subset='name')
        num_repeats = int(len(test_set) / len(test_set_origin))
        print(num_repeats)

        # create a temporary directory using the tempfile module
        with tempfile.TemporaryDirectory() as tmpdir:
            for i in range(len(test_set_origin)):
                completions = predictions[i * num_repeats:(i + 1) *
                                          num_repeats]
                processed_completions = []
                for comp in completions:
                    comp = self.stop_at_stop_token(comp, stop_tokens)
                    processed_completions.append(comp)

                result_dict = {
                    'name': test_set_origin.iloc[i]['name'],
                    'language': test_set_origin.iloc[i]['language'],
                    'prompt': test_set_origin.iloc[i]['prompt'],
                    'tests': test_set_origin.iloc[i]['tests'],
                    'completions': processed_completions
                }

                json_str = json.dumps(result_dict)
                json_bytes = json_str.encode('utf-8')

                with gzip.GzipFile(
                        os.path.join(tmpdir, f'{result_dict["name"]}.json.gz'),
                        'w') as f:
                    f.write(json_bytes)

            # create a zip file containing all the generated .json.gz files
            zipname = os.path.join(tmpdir, 'archive')
            shutil.make_archive(zipname, 'zip', tmpdir)
            zipfile_path = f'{zipname}.zip'

            num_retry = 0
            while num_retry < self.retry:
                succeed, output = self._code_eval_service(
                    file_path=zipfile_path)
                if not succeed and '(56) Recv failure' in output:
                    # only retry when the connection failed
                    num_retry += 1
                    # wait a minute in case the service load is too high
                    time.sleep(60)
                else:
                    break

            if succeed:
                if isinstance(output, str):
                    return json.loads(output)
                elif isinstance(output, dict):
                    return output
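
With the code evaluation service running, the whole pipeline can be driven through the standard OpenCompass entry point; a minimal sketch (the model name is a placeholder, not part of this commit):

python run.py --datasets humaneval_multi_gen_82cf85 --models hf_internlm_7b --debug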