From faf5260155beb3aac0d4cfe93c8ba198edee8437 Mon Sep 17 00:00:00 2001 From: Hari Seldon <95674173+HariSeldon0@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:59:41 +0800 Subject: [PATCH] [Feature] Optimize Evaluation Speed of SciCode (#1489) * update scicode * update comments * remove redundant variable * Update --------- Co-authored-by: tonysy --- opencompass/datasets/scicode.py | 74 ++++++++++++++++++------------ opencompass/utils/datasets_info.py | 4 +- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/opencompass/datasets/scicode.py b/opencompass/datasets/scicode.py index e6a09e8b..bf8ee96e 100644 --- a/opencompass/datasets/scicode.py +++ b/opencompass/datasets/scicode.py @@ -1,3 +1,4 @@ +import concurrent.futures import json import os import os.path as osp @@ -95,10 +96,10 @@ def process_hdf5_datagroup(group): def process_hdf5_to_tuple(step_id, test_num): - H5PY_FILE_FOLDER = './data/scicode/' + H5PY_FILE_FOLDER = './data/scicode/test_data' H5PY_FILE_FOLDER = get_data_path(H5PY_FILE_FOLDER, local_mode=True) data_lst = [] - H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, 'test_data.h5') + H5PY_FILE = os.path.join(H5PY_FILE_FOLDER, f'{step_id}.h5') assert os.path.exists( H5PY_FILE ), f"Please manually download 'test_data.h5' from https://github.com/open-compass/storage/releases/download/v0.1.0/scicode_test_data.zip and put the file in {H5PY_FILE}" # noqa: E501 @@ -217,7 +218,7 @@ def cmp_tuple_or_list(var1, var2): @ICL_EVALUATORS.register_module() class SciCodeEvaluator(BaseEvaluator): - def __init__(self, dataset_path, with_bg, testcode_path='./tmp/scicode'): + def __init__(self, dataset_path, with_bg): super().__init__() test_data = [] dataset_path = get_data_path(dataset_path, local_mode=True) @@ -229,8 +230,6 @@ class SciCodeEvaluator(BaseEvaluator): with open(file_path, 'r', encoding='utf-8') as file: test_data = json.load(file) self.dataset = Dataset.from_list(test_data) - self.testcode_path = testcode_path - H5PY_FILE = osp.join(dataset_path, 'test_data.h5') # noqa: F841 def extract_python_script(self, response: str): start_marker = '```python' @@ -271,25 +270,20 @@ class SciCodeEvaluator(BaseEvaluator): return 2 def score(self, predictions, references): - correct, sub_correct = 0, 0 - count, sub_count = 0, 0 - details = [] - - # generate all python test codes and than test + # generate all python test codes for idx, prediction_list in enumerate(predictions): # traverse each test sample problem_id = self.dataset[idx]['id'] num_of_subproblems = len(prediction_list) # create dir for each test sample - testdir_path = os.path.join(self.testcode_path, str(problem_id)) + testdir_path = os.path.join(self._out_dir, str(problem_id)) os.makedirs(testdir_path, exist_ok=True) python_code = '' # add import statement python_code += self.dataset[idx]['import'] - is_all_correct = True for sub_idx in range(num_of_subproblems): # extract code response = prediction_list[sub_idx] @@ -319,30 +313,50 @@ from opencompass.datasets.scicode import process_hdf5_to_tuple '\n') for idx2 in range(len(test_lst)): f.write(f'target = targets[{idx2}]\n\n') - for line in test_lst[idx2].split('\n'): - f.write(line + '\n') + for line in test_lst[idx2].split('\n'): + f.write(line + '\n') - # test - ret = self.run_script(testfile_path) - msg = {'problem': f'{problem_id}-{sub_idx + 1}'} - if ret == 0: # correct - sub_correct += 1 - msg['is_correct'] = True - elif ret == 1: # error - is_all_correct = False - msg['is_correct'] = False - else: # time out - is_all_correct = False - msg['is_correct'] = False - sub_count += 1 - details.append(msg) + # find all scripts + python_scripts = [] + for root, dirs, files in os.walk(self._out_dir): + for file in files: + if file.endswith('.py'): + python_scripts.append(os.path.join(root, file)) - correct += is_all_correct + # Use ThreadPoolExecutor to concurrently execute scripts + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit task and obtain Future object + futures = [ + executor.submit(self.run_script, script) + for script in python_scripts + ] + + results = [] + for future in concurrent.futures.as_completed(futures): + result = future.result() + results.append(result) + + all_results = {} + for script_path, result in zip(python_scripts, results): + basename = os.path.basename(script_path) + main_id = basename.split('-')[0] + if all_results.get(main_id): + all_results[main_id].append(result) + else: + all_results[main_id] = [result] + + correct, sub_correct = 0, 0 + count, sub_count = 0, 0 + + for main_id in all_results: + correct += sum(all_results[main_id]) == 0 count += 1 + for sub in all_results[main_id]: + sub_correct += sub == 0 + sub_count += 1 result = { 'accuracy': 100 * correct / count, 'sub_accuracy': 100 * sub_correct / sub_count, - 'details': details } return result diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py index aa75ea4d..8ee208ea 100644 --- a/opencompass/utils/datasets_info.py +++ b/opencompass/utils/datasets_info.py @@ -370,9 +370,9 @@ DATASETS_URL = { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ruler.zip", "md5": "c60bdfff3d02358067104cc1dea7c0f7", }, - "scicode/": { + "/scicode": { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/scicode.zip", - "md5": "06f64edad6680072e5bca3f0ce892d0c", + "md5": "9c6c64b8c70edc418f713419ea39989c", }, "/commonsenseqa": { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/commonsenseqa.zip",