Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
[Feature] Optimize Evaluation Speed of SciCode (#1489)
* update scicode
* update comments
* remove redundant variable
* Update

---------

Co-authored-by: tonysy <sy.zhangbuaa@gmail.com>
parent 00fc8da5be
commit faf5260155
@@ -1,3 +1,4 @@
+import concurrent.futures
 import json
 import os
 import os.path as osp
@@ -95,10 +96,10 @@ def process_hdf5_datagroup(group):
 
 def process_hdf5_to_tuple(step_id, test_num):
 
-    H5PY_FILE_FOLDER = './data/scicode/'
+    H5PY_FILE_FOLDER = './data/scicode/test_data'
     H5PY_FILE_FOLDER = get_data_path(H5PY_FILE_FOLDER, local_mode=True)
     data_lst = []
-    H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, 'test_data.h5')
+    H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, f'{step_id}.h5')
     assert os.path.exists(
         H5PY_FILE
     ), f"Please manually download 'test_data.h5' from https://github.com/open-compass/storage/releases/download/v0.1.0/scicode_test_data.zip and put the file in {H5PY_FILE}"  # noqa: E501
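The change above is what unlocks the parallelism later in this commit: instead of every test script opening one shared `test_data.h5`, each step now reads its own `<step_id>.h5` under `./data/scicode/test_data`, so concurrently running scripts never contend for a single HDF5 file. A minimal sketch of the per-step read with `h5py` (the dataset layout inside each file is a hypothetical stand-in; the real parsing is `process_hdf5_datagroup`):

```python
import os

import h5py


def load_step_targets(step_id, test_num, folder='./data/scicode/test_data'):
    # One file per step: workers evaluating different steps open
    # different files instead of the former monolithic test_data.h5.
    h5py_file = os.path.join(folder, f'{step_id}.h5')
    assert os.path.exists(h5py_file), f'missing {h5py_file}'
    targets = []
    with h5py.File(h5py_file, 'r') as f:
        for i in range(test_num):
            # hypothetical layout: one dataset per test case
            targets.append(f[f'target_{i}'][()])
    return targets
```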
@@ -217,7 +218,7 @@ def cmp_tuple_or_list(var1, var2):
 @ICL_EVALUATORS.register_module()
 class SciCodeEvaluator(BaseEvaluator):
 
-    def __init__(self, dataset_path, with_bg, testcode_path='./tmp/scicode'):
+    def __init__(self, dataset_path, with_bg):
         super().__init__()
         test_data = []
         dataset_path = get_data_path(dataset_path, local_mode=True)
@@ -229,8 +230,6 @@ class SciCodeEvaluator(BaseEvaluator):
         with open(file_path, 'r', encoding='utf-8') as file:
             test_data = json.load(file)
         self.dataset = Dataset.from_list(test_data)
-        self.testcode_path = testcode_path
-        H5PY_FILE = osp.join(dataset_path, 'test_data.h5')  # noqa: F841
 
     def extract_python_script(self, response: str):
         start_marker = '```python'
@@ -271,25 +270,20 @@ class SciCodeEvaluator(BaseEvaluator):
             return 2
 
     def score(self, predictions, references):
-        correct, sub_correct = 0, 0
-        count, sub_count = 0, 0
-        details = []
-
-        # generate all python test codes and than test
+        # generate all python test codes
         for idx, prediction_list in enumerate(predictions):
             # traverse each test sample
             problem_id = self.dataset[idx]['id']
             num_of_subproblems = len(prediction_list)
 
             # create dir for each test sample
-            testdir_path = os.path.join(self.testcode_path, str(problem_id))
+            testdir_path = os.path.join(self._out_dir, str(problem_id))
             os.makedirs(testdir_path, exist_ok=True)
 
             python_code = ''
             # add import statement
             python_code += self.dataset[idx]['import']
 
-            is_all_correct = True
             for sub_idx in range(num_of_subproblems):
                 # extract code
                 response = prediction_list[sub_idx]
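For orientation, the rewritten loop above only *writes* test scripts into `self._out_dir` (one subdirectory per problem); execution is deferred to the concurrent stage in the next hunk. A sketch of the emitted script, assuming the `{problem_id}-{sub_idx + 1}.py` naming that the aggregation step later splits on, and a hypothetical step-id format for `process_hdf5_to_tuple`:

```python
import os


def write_test_script(out_dir, problem_id, sub_idx, python_code, test_lst):
    # one directory per problem, one script per sub-problem
    testdir_path = os.path.join(out_dir, str(problem_id))
    os.makedirs(testdir_path, exist_ok=True)
    testfile_path = os.path.join(testdir_path,
                                 f'{problem_id}-{sub_idx + 1}.py')
    with open(testfile_path, 'w', encoding='utf-8') as f:
        f.write(python_code + '\n')
        f.write('from opencompass.datasets.scicode '
                'import process_hdf5_to_tuple\n')
        # hypothetical step-id format; it must match the per-step
        # .h5 filenames described above
        f.write(f"targets = process_hdf5_to_tuple('{problem_id}."
                f"{sub_idx + 1}', {len(test_lst)})\n")
        for idx2 in range(len(test_lst)):
            f.write(f'target = targets[{idx2}]\n\n')
            f.write(test_lst[idx2] + '\n')
    return testfile_path
```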
@@ -319,30 +313,50 @@ from opencompass.datasets.scicode import process_hdf5_to_tuple
                             '\n')
                     for idx2 in range(len(test_lst)):
                         f.write(f'target = targets[{idx2}]\n\n')
                         for line in test_lst[idx2].split('\n'):
                             f.write(line + '\n')
 
-                # test
-                ret = self.run_script(testfile_path)
-                msg = {'problem': f'{problem_id}-{sub_idx + 1}'}
-                if ret == 0:  # correct
-                    sub_correct += 1
-                    msg['is_correct'] = True
-                elif ret == 1:  # error
-                    is_all_correct = False
-                    msg['is_correct'] = False
-                else:  # time out
-                    is_all_correct = False
-                    msg['is_correct'] = False
-                sub_count += 1
-                details.append(msg)
+        # find all scripts
+        python_scripts = []
+        for root, dirs, files in os.walk(self._out_dir):
+            for file in files:
+                if file.endswith('.py'):
+                    python_scripts.append(os.path.join(root, file))
 
-            correct += is_all_correct
+        # Use ThreadPoolExecutor to concurrently execute scripts
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            # Submit task and obtain Future object
+            futures = [
+                executor.submit(self.run_script, script)
+                for script in python_scripts
+            ]
+
+            results = []
+            for future in concurrent.futures.as_completed(futures):
+                result = future.result()
+                results.append(result)
+
+        all_results = {}
+        for script_path, result in zip(python_scripts, results):
+            basename = os.path.basename(script_path)
+            main_id = basename.split('-')[0]
+            if all_results.get(main_id):
+                all_results[main_id].append(result)
+            else:
+                all_results[main_id] = [result]
+
+        correct, sub_correct = 0, 0
+        count, sub_count = 0, 0
+
+        for main_id in all_results:
+            correct += sum(all_results[main_id]) == 0
             count += 1
+            for sub in all_results[main_id]:
+                sub_correct += sub == 0
+                sub_count += 1
 
         result = {
             'accuracy': 100 * correct / count,
             'sub_accuracy': 100 * sub_correct / sub_count,
-            'details': details
         }
         return result
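This hunk is the heart of the speed-up: generated scripts are executed concurrently via `concurrent.futures.ThreadPoolExecutor` instead of one by one inside the generation loop. One caveat worth knowing when reading the new code: `as_completed` yields futures in completion order, not submission order, so the sketch below keys each result by its future rather than zipping a completion-ordered list against `python_scripts`. It is self-contained under the return-code convention visible in the removed lines (0 = correct, 1 = error, 2 = timeout); the interpreter command and timeout are assumptions, not the actual `run_script`:

```python
import concurrent.futures
import os
import subprocess
from collections import defaultdict


def run_script(script_path, timeout=300):
    # assumed convention from the removed code:
    # 0 = correct, 1 = error, 2 = timeout
    try:
        proc = subprocess.run(['python', script_path],
                              capture_output=True, timeout=timeout)
        return 0 if proc.returncode == 0 else 1
    except subprocess.TimeoutExpired:
        return 2


def score_scripts(python_scripts):
    # run every '<main_id>-<sub_idx>.py' script concurrently
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_script = {
            executor.submit(run_script, script): script
            for script in python_scripts
        }
        script_results = {
            future_to_script[future]: future.result()
            for future in concurrent.futures.as_completed(future_to_script)
        }

    # group sub-problem results under their main problem id
    all_results = defaultdict(list)
    for script_path, ret in script_results.items():
        main_id = os.path.basename(script_path).split('-')[0]
        all_results[main_id].append(ret)

    # a problem counts as correct only if every sub-problem returned 0
    correct = sum(sum(rets) == 0 for rets in all_results.values())
    sub_correct = sum(r == 0 for rets in all_results.values() for r in rets)
    count = len(all_results)
    sub_count = sum(len(rets) for rets in all_results.values())
    return {
        'accuracy': 100 * correct / count,
        'sub_accuracy': 100 * sub_correct / sub_count,
    }
```

Threads (rather than processes) suffice here because each task blocks in a child `python` process, so the GIL is released while the real work runs out-of-process.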
@@ -370,9 +370,9 @@ DATASETS_URL = {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ruler.zip",
         "md5": "c60bdfff3d02358067104cc1dea7c0f7",
     },
-    "scicode/": {
+    "/scicode": {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/scicode.zip",
-        "md5": "06f64edad6680072e5bca3f0ce892d0c",
+        "md5": "9c6c64b8c70edc418f713419ea39989c",
     },
     "/commonsenseqa": {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/commonsenseqa.zip",
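The key rename from `"scicode/"` to `"/scicode"` matters because entries in `DATASETS_URL` are matched against the dataset path being resolved; `"/scicode"` presumably covers both `./data/scicode` (the problem JSON) and `./data/scicode/test_data` (the new per-step HDF5 folder), and the md5 changes because the archive was repacked around the split files. A hypothetical sketch of a substring-based lookup, not the actual `get_data_path` implementation:

```python
def lookup_download_info(dataset_path: str, datasets_url: dict):
    # hypothetical helper: first entry whose key occurs in the path wins
    for key, info in datasets_url.items():
        if key in dataset_path:
            return info['url'], info['md5']
    raise KeyError(f'no DATASETS_URL entry matches {dataset_path!r}')


# Both of the paths touched by this commit would resolve to the same
# scicode archive:
#   lookup_download_info('./data/scicode', DATASETS_URL)
#   lookup_download_info('./data/scicode/test_data', DATASETS_URL)
```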