import gzip
import json
import os
import os.path as osp
import re
import shutil
import subprocess
import tempfile
import time

import numpy as np
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# currently supporting 19 languages
_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'cs': 'C#',
    'd': 'D',
    'go': 'Go',
    'java': 'Java',
    'jl': 'Julia',
    'js': 'JavaScript',
    'lua': 'Lua',
    'php': 'PHP',
    'pl': 'Perl',
    'py': 'Python',
    'r': 'R',
    'rb': 'Ruby',
    'rkt': 'Racket',
    'rs': 'Rust',
    'scala': 'Scala',
    'sh': 'Shell',
    'swift': 'Swift',
    'ts': 'TypeScript',
}


@LOAD_DATASET.register_module()
class HumanevalMultiDataset(BaseDataset):

    @staticmethod
    def load(path, language, version, num_repeats: int = 1, **kwargs):
        """Load humaneval dataset for pass k mode.

        Note that you can use num_repeats > 1 when your model does not
        support `num_return_sequence` in generation; otherwise use the raw
        humaneval dataset and set `num_return_sequence` in the model config
        to generate multiple responses for testing pass@k > 1.

        It is better to change your dataset abbr accordingly if you set
        num_repeats > 1, otherwise the number in
        `.cache/dataset_size.json` might be inconsistent.

        Args:
            num_repeats(int): Number of repetitions for this dataset to get
                multiple responses in special cases.
        """
        path = get_data_path(path, local_mode=True)
        assert language in _LANGUAGE_NAME_DICT.keys(), (
            f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
        assert version in [
            'keep', 'transform', 'reworded', 'remove'
        ], ('version must be in ["keep", "transform", "reworded", "remove"]')
        file_path = osp.join(path, f'humaneval-{language}-{version}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)
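
    # A minimal usage sketch (illustrative: the data root and version here
    # are assumptions about your local setup, not fixed values):
    #
    #   ds = HumanevalMultiDataset.load(path='./data/multipl_e',
    #                                   language='py',
    #                                   version='reworded',
    #                                   num_repeats=5)
    #   # each problem now appears 5 times, so one sampled response per row
    #   # yields 5 completions per problem for pass@k estimation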


class HumanevalMultiEvaluator(BaseEvaluator):

    def __init__(self,
                 language,
                 ip_address='localhost',
                 port=5000,
                 retry=2,
                 timeout=600) -> None:
        self.language = language
        self.ip_address = ip_address
        self.port = port
        self.retry = retry
        self.timeout = timeout
        super().__init__()

    def stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produces the prefix of decoded_string that ends at the first
        occurrence of a stop_token.

        WARNING: the decoded_string *must not* include the prompt,
        which may have stop tokens itself.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]
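
    # For example (illustrative values, not from the dataset):
    #   stop_at_stop_token('x = 1\n}\ntrailing', ['\n}']) -> 'x = 1'
    # When no stop token occurs, the string is returned unchanged.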

    def _code_eval_service(self, file_path):
        exec_result = subprocess.run([
            'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
            f'dataset=multipl-e/{self.language}',
            f'{self.ip_address}:{self.port}/evaluate'
        ],
                                     timeout=self.timeout,
                                     capture_output=True)

        if exec_result.returncode == 0 and re.match(
                r'"\{.*:.*\}"', exec_result.stdout.decode('utf-8')):
            return True, json.loads(exec_result.stdout.decode('utf-8'))
        else:
            if exec_result.stderr:
                try:
                    err = exec_result.stderr.decode()
                except Exception:
                    err = exec_result.stderr
            else:
                try:
                    err = exec_result.stdout.decode()
                except Exception:
                    err = exec_result.stdout
            return False, err
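
    # The subprocess call above amounts to, for example:
    #   curl -X POST -F file=@/tmp/archive.zip -F dataset=multipl-e/py \
    #       localhost:5000/evaluate
    # i.e. it posts the zipped predictions to a MultiPL-E style execution
    # service, which is expected to print a JSON object on stdout.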

    def estimator(self, n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
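
    # The product above is the numerically stable form of the unbiased
    # pass@k estimator from the HumanEval paper (Chen et al., 2021):
    #   pass@k = 1 - prod_{i=n-c+1}^{n} (1 - k / i)
    # e.g. estimator(n=10, c=3, k=1) == 0.3, matching c / n for k=1, while
    # never forming the (potentially huge) binomial coefficients directly.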

    def for_file(self, path):
        try:
            with gzip.open(path, 'rt') as f:
                data = json.load(f)
        except Exception:
            return None
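
        # Each .json.gz written by the eval service is expected to contain a
        # dict with a 'results' list, one entry per executed completion,
        # carrying at least 'status' and 'exit_code' fields.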
        n = len(data['results'])
        c = len([
            True for r in data['results']
            if r['status'] == 'OK' and r['exit_code'] == 0
        ])
        return {
            'pass@1': self.estimator(n, c, 1),
            'pass@10': self.estimator(n, c, 10),
            'pass@100': self.estimator(n, c, 100),
            'n': n,
            'c': c,
        }
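
    # Caveat: pass@10 / pass@100 are only meaningful when at least 10 / 100
    # completions per problem were generated; whenever n - c < k the
    # estimator above saturates at 1.0.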

    def score(self, predictions, references, test_set):
        stop_tokens = test_set['stop_tokens'][0]
        print(stop_tokens)

        # convert to original version
        test_set = test_set.to_pandas()
        test_set_origin = test_set.drop_duplicates(subset='name')
        num_repeats = int(len(test_set) / len(test_set_origin))
        print(num_repeats)

        # Create a temporary directory using the tempfile module
        with tempfile.TemporaryDirectory() as tmpdir:
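            # predictions are laid out problem by problem: load() repeated
            # each problem num_repeats times consecutively, so the i-th
            # problem owns the slice [i * num_repeats, (i + 1) * num_repeats)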
            for i in range(len(test_set_origin)):
                completions = predictions[i * num_repeats:(i + 1) *
                                          num_repeats]
                processed_completions = []
                for comp in completions:
                    comp = self.stop_at_stop_token(comp, stop_tokens)
                    processed_completions.append(comp)

                result_dict = {
                    'name': test_set_origin.iloc[i]['name'],
                    'language': test_set_origin.iloc[i]['language'],
                    'prompt': test_set_origin.iloc[i]['prompt'],
                    'tests': test_set_origin.iloc[i]['tests'],
                    'completions': processed_completions
                }

                json_str = json.dumps(result_dict)
                json_bytes = json_str.encode('utf-8')

                with gzip.GzipFile(
                        os.path.join(tmpdir,
                                     f'{result_dict["name"]}.json.gz'),
                        'w') as f:
                    f.write(json_bytes)

            # create a zip file containing all the generated .json.gz files
            zipname = os.path.join(tmpdir, 'archive')
            shutil.make_archive(zipname, 'zip', tmpdir)
            zipfile_path = f'{zipname}.zip'

            num_retry = 0
            while num_retry < self.retry:
                succeed, output = self._code_eval_service(
                    file_path=zipfile_path)
                if not succeed and isinstance(
                        output, str) and '(56) Recv failure' in output:
                    # only retry when the connection failed
                    num_retry += 1
                    # wait a minute in case the service load is too high
                    time.sleep(60)
                else:
                    break

            if succeed:
                if isinstance(output, str):
                    return json.loads(output)
                elif isinstance(output, dict):
                    return output
            # surface the failure instead of silently returning None
            raise RuntimeError(
                f'Call to code eval service failed: {output}')
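
# A minimal config sketch for wiring the evaluator into OpenCompass
# (illustrative: the language and service address are assumptions about
# your deployment, not fixed values):
#
#   humaneval_multi_eval_cfg = dict(evaluator=dict(
#       type=HumanevalMultiEvaluator,
#       language='py',
#       ip_address='localhost',
#       port=5000,
#   ))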