mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[feat] support multipl-e (#846)
* [feat] support humaneval_multipl-e * format --------- Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
a6c49f15ce
commit
444d8d9507
4
configs/datasets/humaneval_multi/humaneval_multi_gen.py
Normal file
4
configs/datasets/humaneval_multi/humaneval_multi_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .humaneval_multi_gen_82cf85 import humaneval_multi_datasets # noqa: F401, F403
|
@ -0,0 +1,46 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import HumanevalMultiDataset, HumanevalMultiEvaluator
|
||||
|
||||
humaneval_multi_reader_cfg = dict(input_columns=['prompt'], output_column='tests')
|
||||
|
||||
humaneval_multi_infer_cfg = dict(
|
||||
prompt_template=dict(type=PromptTemplate, template='{prompt}'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=1024),
|
||||
)
|
||||
|
||||
humaneval_multi_eval_cfg = {
|
||||
lang: dict(
|
||||
evaluator=dict(
|
||||
type=HumanevalMultiEvaluator,
|
||||
language=lang,
|
||||
ip_address='localhost', # replace to your code_eval_server ip_address, port
|
||||
port=5000,
|
||||
), # refer to https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html to launch a server
|
||||
pred_role='BOT',
|
||||
) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
|
||||
}
|
||||
|
||||
'''there are four versions of humaneval-{LANG}-{version}.jsonl:
|
||||
['keep', 'transform', 'reworded', 'remove']
|
||||
SRCDATA-LANG-keep is the same as SRCDATA-LANG, but the text of the prompt is totally unchanged. If the original prompt had Python doctests, they remain as Python instead of being translated to LANG. If the original prompt had Python-specific terminology, e.g., 'list', it remains 'list', instead of being translated, e.g., to 'vector' for C++.
|
||||
SRCDATA-LANG-transform transforms the doctests to LANG but leaves the natural language text of the prompt unchanged.
|
||||
SRCDATA-LANG-reworded transforms both the doctests and the natural language text of the prompt to LANG.
|
||||
SRCDATA-LANG-remove removes the doctests from the prompt.
|
||||
'''
|
||||
|
||||
humaneval_multi_datasets = [
|
||||
dict(
|
||||
type=HumanevalMultiDataset,
|
||||
abbr=f'humaneval_multiple-{lang}',
|
||||
language=lang,
|
||||
version='reworded', # choose from ['keep', 'transform', 'reworded', 'remove']
|
||||
num_repeats=1,
|
||||
path='./data/multi-data/humaneval_multipl-e/',
|
||||
reader_cfg=humaneval_multi_reader_cfg,
|
||||
infer_cfg=humaneval_multi_infer_cfg,
|
||||
eval_cfg=humaneval_multi_eval_cfg[lang],
|
||||
) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
|
||||
]
|
218
opencompass/datasets/humaneval_multi.py
Normal file
218
opencompass/datasets/humaneval_multi.py
Normal file
@ -0,0 +1,218 @@
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
# currently supporting 19 languages
|
||||
_LANGUAGE_NAME_DICT = {
|
||||
'cpp': 'CPP',
|
||||
'cs': 'C#',
|
||||
'd': 'D',
|
||||
'go': 'Go',
|
||||
'java': 'Java',
|
||||
'jl': 'Julia',
|
||||
'js': 'JavaScript',
|
||||
'lua': 'Lua',
|
||||
'php': 'PHP',
|
||||
'pl': 'Perl',
|
||||
'py': 'Python',
|
||||
'r': 'R',
|
||||
'rb': 'Ruby',
|
||||
'rkt': 'Racket',
|
||||
'rs': 'Rust',
|
||||
'scala': 'Scala',
|
||||
'sh': 'Shell',
|
||||
'swift': 'Swift',
|
||||
'ts': 'TypeScript',
|
||||
}
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class HumanevalMultiDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path, language, version, num_repeats: int = 1, **kwargs):
|
||||
"""Load humaneval dataset for pass k mode.
|
||||
|
||||
Note that you can use num_repeats > 1 when your model does not support
|
||||
`num_return_sequence` in generation, otherwise use the raw
|
||||
humaneval dataset and set `num_return_sequence` in model config to
|
||||
generate multiple responses for testing pass@k>1.
|
||||
|
||||
It better to change your dataset abbr correspondingly if you want to
|
||||
change num_repeats>1, otherwise the number in
|
||||
`.cache/dataset_size.json` might be inconsistent.
|
||||
|
||||
Args:
|
||||
num_repeats(int): Number of repetition for this dataset to get
|
||||
multiple responses in special cases.
|
||||
"""
|
||||
assert language in _LANGUAGE_NAME_DICT.keys(), (
|
||||
f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
|
||||
assert version in [
|
||||
'keep', 'transform', 'reworded', 'remove'
|
||||
], ('version must be in ["keep", "transform", "reworded", "remove"]')
|
||||
file_path = osp.join(path, f'humaneval-{language}-{version}.jsonl')
|
||||
dataset = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
dataset.extend(
|
||||
[json.loads(line.strip()) for _ in range(num_repeats)])
|
||||
return Dataset.from_list(dataset)
|
||||
|
||||
|
||||
class HumanevalMultiEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self,
|
||||
language,
|
||||
ip_address='localhost',
|
||||
port=5000,
|
||||
retry=2,
|
||||
timeout=600) -> None:
|
||||
self.language = language
|
||||
self.ip_address = ip_address
|
||||
self.port = port
|
||||
self.retry = retry
|
||||
self.timeout = timeout
|
||||
super().__init__()
|
||||
|
||||
def stop_at_stop_token(self, decoded_string, stop_tokens):
|
||||
"""Produces the prefix of decoded_string that ends at the first
|
||||
occurrence of a stop_token.
|
||||
|
||||
WARNING: the decoded_string *must not* include the prompt,
|
||||
which may have stop tokens itself.
|
||||
"""
|
||||
min_stop_index = len(decoded_string)
|
||||
for stop_token in stop_tokens:
|
||||
stop_index = decoded_string.find(stop_token)
|
||||
if stop_index != -1 and stop_index < min_stop_index:
|
||||
min_stop_index = stop_index
|
||||
return decoded_string[:min_stop_index]
|
||||
|
||||
def _code_eval_service(self, file_path):
|
||||
exec_result = subprocess.run([
|
||||
'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
|
||||
f'dataset=multipl-e/{self.language}',
|
||||
f'{self.ip_address}:{self.port}/evaluate'
|
||||
],
|
||||
timeout=self.timeout,
|
||||
capture_output=True)
|
||||
|
||||
if exec_result.returncode == 0 and re.match(
|
||||
"\"{.*:.*}\"", exec_result.stdout.decode('utf-8')):
|
||||
return True, json.loads(exec_result.stdout.decode('utf-8'))
|
||||
else:
|
||||
if exec_result.stderr:
|
||||
try:
|
||||
err = exec_result.stderr.decode()
|
||||
except Exception:
|
||||
err = exec_result.stderr
|
||||
else:
|
||||
try:
|
||||
err = exec_result.stdout.decode()
|
||||
except Exception:
|
||||
err = exec_result.stdout
|
||||
return False, err
|
||||
|
||||
def estimator(self, n: int, c: int, k: int) -> float:
|
||||
"""
|
||||
Calculates 1 - comb(n - c, k) / comb(n, k).
|
||||
"""
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
||||
|
||||
def for_file(self, path):
|
||||
|
||||
try:
|
||||
with gzip.open(path, 'rt') as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
n = len(data['results'])
|
||||
c = len([
|
||||
True for r in data['results']
|
||||
if r['status'] == 'OK' and r['exit_code'] == 0
|
||||
])
|
||||
return {
|
||||
'pass@1': self.estimator(n, c, 1),
|
||||
'pass@10': self.estimator(n, c, 10),
|
||||
'pass@100': self.estimator(n, c, 100),
|
||||
'n': n,
|
||||
'c': c,
|
||||
}
|
||||
|
||||
def score(self, predictions, references, test_set):
|
||||
|
||||
stop_tokens = test_set['stop_tokens'][0]
|
||||
print(stop_tokens)
|
||||
|
||||
# convert to original version
|
||||
test_set = test_set.to_pandas()
|
||||
test_set_origin = test_set.drop_duplicates(subset='name')
|
||||
num_repeats = int(len(test_set) / len(test_set_origin))
|
||||
print(num_repeats)
|
||||
|
||||
# Create a temporary directory using the tempfile module
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
for i in range(len(test_set_origin)):
|
||||
completions = predictions[i * num_repeats:(i + 1) *
|
||||
num_repeats]
|
||||
processed_completions = []
|
||||
for comp in completions:
|
||||
comp = self.stop_at_stop_token(comp, stop_tokens)
|
||||
processed_completions.append(comp)
|
||||
|
||||
result_dict = {
|
||||
'name': test_set_origin.iloc[i]['name'],
|
||||
'language': test_set_origin.iloc[i]['language'],
|
||||
'prompt': test_set_origin.iloc[i]['prompt'],
|
||||
'tests': test_set_origin.iloc[i]['tests'],
|
||||
'completions': processed_completions
|
||||
}
|
||||
|
||||
json_str = json.dumps(result_dict)
|
||||
json_bytes = json_str.encode('utf-8')
|
||||
|
||||
with gzip.GzipFile(
|
||||
os.path.join(tmpdir, f'{result_dict["name"]}.json.gz'),
|
||||
'w') as f:
|
||||
f.write(json_bytes)
|
||||
|
||||
# create a zip file containing all the generated .json.gz files
|
||||
zipname = os.path.join(tmpdir, 'archive')
|
||||
shutil.make_archive(zipname, 'zip', tmpdir)
|
||||
zipfile_path = f'{zipname}.zip'
|
||||
|
||||
num_retry = 0
|
||||
while num_retry < self.retry:
|
||||
succeed, output = self._code_eval_service(
|
||||
file_path=zipfile_path)
|
||||
if not succeed and '(56) Recv failure' in output:
|
||||
# only retry when connection failed
|
||||
num_retry += 1
|
||||
# wait a min in case the service load is too high
|
||||
time.sleep(60)
|
||||
else:
|
||||
break
|
||||
|
||||
if succeed:
|
||||
if isinstance(output, str):
|
||||
return json.loads(output)
|
||||
elif isinstance(output, dict):
|
||||
return output
|
Loading…
Reference in New Issue
Block a user