Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Optimize Evaluation Speed of SciCode (#1489)
* update scicode
* update comments
* remove redundant variable
* Update

---------

Co-authored-by: tonysy <sy.zhangbuaa@gmail.com>
parent 00fc8da5be
commit faf5260155
--- a/opencompass/datasets/scicode.py
+++ b/opencompass/datasets/scicode.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import json
 import os
 import os.path as osp
@@ -95,10 +96,10 @@ def process_hdf5_datagroup(group):
 
 
 def process_hdf5_to_tuple(step_id, test_num):
-    H5PY_FILE_FOLDER = './data/scicode/'
+    H5PY_FILE_FOLDER = './data/scicode/test_data'
     H5PY_FILE_FOLDER = get_data_path(H5PY_FILE_FOLDER, local_mode=True)
     data_lst = []
-    H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, 'test_data.h5')
+    H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, f'{step_id}.h5')
     assert os.path.exists(
         H5PY_FILE
     ), f"Please manually download 'test_data.h5' from https://github.com/open-compass/storage/releases/download/v0.1.0/scicode_test_data.zip and put the file in {H5PY_FILE}"  # noqa: E501
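The per-step lookup above implies that the refreshed scicode archive ships a test_data/ directory holding one HDF5 file per step id, instead of a single monolithic test_data.h5. For anyone migrating a locally cached copy by hand, here is a minimal sketch, assuming the old file's top-level groups are keyed by step id (the diff does not show the internal HDF5 layout, so treat that as an assumption):

import os

import h5py

# Hypothetical one-off migration: split the old monolithic test_data.h5
# into one file per top-level group, matching the new f'{step_id}.h5' lookup.
src_path = './data/scicode/test_data.h5'   # old layout
dst_folder = './data/scicode/test_data'    # new layout
os.makedirs(dst_folder, exist_ok=True)

with h5py.File(src_path, 'r') as src:
    for step_id in src.keys():             # assumes group names are step ids
        with h5py.File(os.path.join(dst_folder, f'{step_id}.h5'), 'w') as dst:
            src.copy(step_id, dst)         # deep-copies the whole group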
@@ -217,7 +218,7 @@ def cmp_tuple_or_list(var1, var2):
 @ICL_EVALUATORS.register_module()
 class SciCodeEvaluator(BaseEvaluator):
 
-    def __init__(self, dataset_path, with_bg, testcode_path='./tmp/scicode'):
+    def __init__(self, dataset_path, with_bg):
         super().__init__()
         test_data = []
         dataset_path = get_data_path(dataset_path, local_mode=True)
@@ -229,8 +230,6 @@ class SciCodeEvaluator(BaseEvaluator):
         with open(file_path, 'r', encoding='utf-8') as file:
             test_data = json.load(file)
         self.dataset = Dataset.from_list(test_data)
-        self.testcode_path = testcode_path
-        H5PY_FILE = osp.join(dataset_path, 'test_data.h5')  # noqa: F841
 
     def extract_python_script(self, response: str):
         start_marker = '```python'
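With testcode_path removed from the constructor, generated test scripts now land under the evaluator's own output directory (self._out_dir, supplied by BaseEvaluator) rather than a user-chosen folder. A hedged config fragment showing the surviving arguments; the surrounding dataset config is assumed, not shown in this diff:

from opencompass.datasets import SciCodeEvaluator

# Hypothetical eval-config fragment: note there is no testcode_path
# argument any more; only dataset_path and with_bg remain.
scicode_eval_cfg = dict(
    evaluator=dict(
        type=SciCodeEvaluator,
        dataset_path='./data/scicode',
        with_bg=False,
    ),
)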
@@ -271,25 +270,20 @@ class SciCodeEvaluator(BaseEvaluator):
             return 2
 
     def score(self, predictions, references):
-        correct, sub_correct = 0, 0
-        count, sub_count = 0, 0
-        details = []
-
-        # generate all python test codes and than test
+        # generate all python test codes
         for idx, prediction_list in enumerate(predictions):
             # traverse each test sample
             problem_id = self.dataset[idx]['id']
             num_of_subproblems = len(prediction_list)
 
             # create dir for each test sample
-            testdir_path = os.path.join(self.testcode_path, str(problem_id))
+            testdir_path = os.path.join(self._out_dir, str(problem_id))
             os.makedirs(testdir_path, exist_ok=True)
 
             python_code = ''
             # add import statement
             python_code += self.dataset[idx]['import']
 
-            is_all_correct = True
             for sub_idx in range(num_of_subproblems):
                 # extract code
                 response = prediction_list[sub_idx]
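Note that the subproblem loop above now only writes test scripts; nothing is executed inside it any more. Pieced together from the f.write calls in the next hunk, a generated script would look roughly like the sketch below. The step id, test count, file name, and import line are illustrative placeholders, not values taken from the diff:

# Hypothetical contents of <out_dir>/13/13-1.py as assembled by score():
import numpy as np  # whatever dataset[idx]['import'] provides

from opencompass.datasets.scicode import process_hdf5_to_tuple

# signature from process_hdf5_to_tuple(step_id, test_num) above
targets = process_hdf5_to_tuple('13.1', 2)

target = targets[0]
# ...assertion lines copied from test_lst[0]...

target = targets[1]
# ...assertion lines copied from test_lst[1]...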
@@ -319,30 +313,50 @@ from opencompass.datasets.scicode import process_hdf5_to_tuple
                             '\n')
                     for idx2 in range(len(test_lst)):
                         f.write(f'target = targets[{idx2}]\n\n')
                         for line in test_lst[idx2].split('\n'):
                             f.write(line + '\n')
 
-                # test
-                ret = self.run_script(testfile_path)
-                msg = {'problem': f'{problem_id}-{sub_idx + 1}'}
-                if ret == 0:  # correct
-                    sub_correct += 1
-                    msg['is_correct'] = True
-                elif ret == 1:  # error
-                    is_all_correct = False
-                    msg['is_correct'] = False
-                else:  # time out
-                    is_all_correct = False
-                    msg['is_correct'] = False
-                sub_count += 1
-                details.append(msg)
+        # find all scripts
+        python_scripts = []
+        for root, dirs, files in os.walk(self._out_dir):
+            for file in files:
+                if file.endswith('.py'):
+                    python_scripts.append(os.path.join(root, file))
 
-            correct += is_all_correct
+        # Use ThreadPoolExecutor to concurrently execute scripts
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            # Submit task and obtain Future object
+            futures = [
+                executor.submit(self.run_script, script)
+                for script in python_scripts
+            ]
+
+            results = []
+            for future in concurrent.futures.as_completed(futures):
+                result = future.result()
+                results.append(result)
+
+        all_results = {}
+        for script_path, result in zip(python_scripts, results):
+            basename = os.path.basename(script_path)
+            main_id = basename.split('-')[0]
+            if all_results.get(main_id):
+                all_results[main_id].append(result)
+            else:
+                all_results[main_id] = [result]
+
+        correct, sub_correct = 0, 0
+        count, sub_count = 0, 0
+
+        for main_id in all_results:
+            correct += sum(all_results[main_id]) == 0
+            count += 1
+            for sub in all_results[main_id]:
+                sub_correct += sub == 0
+                sub_count += 1
 
         result = {
             'accuracy': 100 * correct / count,
             'sub_accuracy': 100 * sub_correct / sub_count,
-            'details': details
         }
         return result
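One subtlety in the new scoring path: as_completed yields futures in completion order, while zip(python_scripts, results) pairs each result back with submission order, so when per-script runtimes differ a result can be attributed to the wrong script and shift the per-problem accuracy. A standalone order-preserving sketch, with a stand-in run_script modeled on the 0/1/2 return convention visible in the removed code (the timeout value and example paths are assumptions):

import concurrent.futures
import subprocess

def run_script(script_path: str) -> int:
    """Stand-in for SciCodeEvaluator.run_script: 0 = pass, 1 = error, 2 = timeout."""
    try:
        subprocess.run(['python', script_path],
                       timeout=300, check=True, capture_output=True)
        return 0
    except subprocess.TimeoutExpired:
        return 2
    except subprocess.CalledProcessError:
        return 1

python_scripts = ['tmp/scicode/13/13-1.py', 'tmp/scicode/13/13-2.py']  # hypothetical

# Iterating `futures` in submission order keeps `results` index-aligned with
# `python_scripts`, so a later zip() cannot be shuffled by completion order
# the way as_completed() allows.
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(run_script, s) for s in python_scripts]
    results = [f.result() for f in futures]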
--- a/opencompass/utils/datasets_info.py
+++ b/opencompass/utils/datasets_info.py
@@ -370,9 +370,9 @@ DATASETS_URL = {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ruler.zip",
         "md5": "c60bdfff3d02358067104cc1dea7c0f7",
     },
-    "scicode/": {
+    "/scicode": {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/scicode.zip",
-        "md5": "06f64edad6680072e5bca3f0ce892d0c",
+        "md5": "9c6c64b8c70edc418f713419ea39989c",
     },
     "/commonsenseqa": {
         "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/commonsenseqa.zip",
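Because the archive was re-packaged for the per-step layout, a cached scicode.zip from before this commit will no longer match the registered checksum and must be re-downloaded. A quick way to verify a local copy against the new digest (URL and md5 taken from the entry above; the streaming helper is only illustrative):

import hashlib

def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through md5 so large archives need not fit in memory."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# New checksum from the "/scicode" DATASETS_URL entry.
assert md5_of('scicode.zip') == '9c6c64b8c70edc418f713419ea39989c'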