From 2aaab41dc9f9180436600948da6a0a44a11c4ffa Mon Sep 17 00:00:00 2001 From: Myhs-phz Date: Thu, 27 Feb 2025 14:44:24 +0000 Subject: [PATCH] feat save_to_station --- opencompass/cli/main.py | 32 +++++++- opencompass/utils/__init__.py | 1 + opencompass/utils/result_station.py | 112 ++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 opencompass/utils/result_station.py diff --git a/opencompass/cli/main.py b/opencompass/cli/main.py index 63377371..8dbb9f73 100644 --- a/opencompass/cli/main.py +++ b/opencompass/cli/main.py @@ -12,7 +12,7 @@ from mmengine.config import Config, DictAction from opencompass.registry import PARTITIONERS, RUNNERS, build_from_cfg from opencompass.runners import SlurmRunner from opencompass.summarizers import DefaultSummarizer -from opencompass.utils import LarkReporter, get_logger +from opencompass.utils import LarkReporter, get_logger, Save_To_Station from opencompass.utils.run import (fill_eval_cfg, fill_infer_cfg, get_config_from_arg) @@ -127,6 +127,26 @@ def parse_args(): 'correctness of each sample, bpb, etc.', action='store_true', ) + parser.add_argument( + '--save-to-station', + help='Whether to save the evaluation results to the ' + 'data station.', + action='store_true', + ) + parser.add_argument( + '--read-station', + help='Whether to read the evaluation results from the ' + 'data station.', + action='store_true', + ) + parser.add_argument( + '--station-path', + help='Path to your reuslts station.', + type=str, + default=None, + ) + + # set srun args slurm_parser = parser.add_argument_group('slurm_args') parse_slurm_args(slurm_parser) @@ -269,6 +289,7 @@ def main(): content = f'{getpass.getuser()}\'s task has been launched!' 
"""Utilities for archiving OpenCompass evaluation outputs to a "results
station" -- a shared directory laid out as ``<station>/<dataset>/<model>.json``,
where each file bundles a model's predictions and evaluation results for one
dataset."""
import json
import os
import os.path as osp
import re


def Save_To_Station(cfg, args):
    """Copy local evaluation outputs into the results station.

    Args:
        cfg: Loaded run config; reads ``work_dir`` plus the ``abbr`` of every
            entry in ``models`` and ``datasets``.
        args: Parsed CLI namespace; ``args.station_path`` supplies the station
            root when the ``RESULTS_STATION_PATH`` environment variable is
            unset (the environment variable takes precedence, matching the
            original behaviour).

    Returns:
        bool: ``True`` once every (dataset, model) pair has been processed.
        Pairs already present in the station are skipped, never overwritten.

    Raises:
        ValueError: If no station path is configured, or a local result /
            prediction file is missing or malformed.
    """
    # python-dotenv is an optional convenience: it only pre-loads a .env
    # file into os.environ. Fall back silently when it is not installed.
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    station_path = os.getenv('RESULTS_STATION_PATH')
    if station_path is None:
        station_path = args.station_path
    # Explicit raise instead of `assert`: asserts are stripped under -O.
    if station_path is None:
        raise ValueError('no results station configured: pass --station-path '
                         'or set RESULTS_STATION_PATH')

    work_dir = cfg['work_dir']
    model_list = [model['abbr'] for model in cfg['models']]
    dataset_list = [dataset['abbr'] for dataset in cfg['datasets']]

    for dataset in dataset_list:
        result_path = osp.join(station_path, dataset)
        # exist_ok avoids a race between an exists() check and makedirs().
        os.makedirs(result_path, exist_ok=True)

        for model in model_list:
            target_file = osp.join(result_path, model + '.json')
            if osp.exists(target_file):
                print('result of {} with {} already exists'.format(
                    dataset, model))
                continue

            # Per-dataset evaluation scores produced by the eval stage.
            local_result_json = osp.join(work_dir, 'results', model,
                                         dataset + '.json')
            if not osp.exists(local_result_json):
                raise ValueError('invalid file: {}'.format(local_result_json))
            with open(local_result_json, 'r') as f:
                this_result = json.load(f)

            # Predictions are either a single `<dataset>.json` or a numbered
            # series `<dataset>_0.json`, `<dataset>_1.json`, ...
            local_prediction_path = osp.join(work_dir, 'predictions', model)
            local_prediction_regex = rf'^{re.escape(dataset)}(?:_\d+)?\.json$'
            prediction_files = find_files_by_regex(local_prediction_path,
                                                   local_prediction_regex)
            if not check_filenames(dataset, prediction_files):
                raise ValueError(
                    'invalid filelist: {}'.format(prediction_files))

            # Concatenate shards in numeric order: os.listdir() returns an
            # arbitrary order, so sorting keeps predictions aligned with
            # their original sample order.
            numbered = re.compile(rf'^{re.escape(dataset)}_(\d+)\.json$')

            def shard_index(name):
                match = numbered.match(name)
                return int(match.group(1)) if match else 0

            this_prediction = []
            for prediction_json in sorted(prediction_files, key=shard_index):
                with open(osp.join(local_prediction_path,
                                   prediction_json), 'r') as f:
                    shard = json.load(f)
                # Dict insertion order preserves the order inside each file.
                this_prediction.extend(shard[key] for key in shard)

            data_model_results = {
                'predictions': this_prediction,
                'results': this_result,
            }
            with open(target_file, 'w') as f:
                json.dump(data_model_results, f, ensure_ascii=False, indent=4)

    return True


def find_files_by_regex(directory, pattern):
    """Return the filenames in *directory* whose name matches *pattern*."""
    regex = re.compile(pattern)
    return [name for name in os.listdir(directory) if regex.match(name)]


def check_filenames(x, filenames):
    """Validate the prediction file list for dataset *x*.

    Accepts either exactly one ``<x>.json``, or a numbered series
    ``<x>_0.json`` ... ``<x>_{k}.json`` whose indices are contiguous and
    start at 0.

    Args:
        x: Dataset abbreviation the filenames must match.
        filenames: Candidate filenames (basenames, no directory part).

    Returns:
        bool: ``True`` when the list is a well-formed prediction set.
    """
    if not filenames:
        return False

    single_pattern = re.compile(rf'^{re.escape(x)}\.json$')
    numbered_pattern = re.compile(rf'^{re.escape(x)}_(\d+)\.json$')

    if all(single_pattern.match(name) for name in filenames):
        return len(filenames) == 1

    if all(numbered_pattern.match(name) for name in filenames):
        indices = [int(numbered_pattern.match(name).group(1))
                   for name in filenames]
        # Shards must form the complete range 0..k: no gaps, no duplicates.
        return sorted(indices) == list(range(len(indices)))

    return False