diff --git a/opencompass/cli/main.py b/opencompass/cli/main.py index 7e19a517..5c8748ba 100644 --- a/opencompass/cli/main.py +++ b/opencompass/cli/main.py @@ -129,18 +129,19 @@ def parse_args(): 'correctness of each sample, bpb, etc.', action='store_true', ) - parser.add_argument( - '--save-to-station', - help='Whether to save the evaluation results to the ' - 'data station.', - action='store_true', - ) + parser.add_argument('-sp', '--station-path', - help='Path to your reuslts station.', + help='Path to your results station.', type=str, default=None, ) + + parser.add_argument('--station-overwrite', + help='Whether to overwrite the results at station.', + action='store_true', + ) + parser.add_argument( '--read-from-station', help='Whether to save the evaluation results to the ' @@ -290,7 +291,6 @@ def main(): rs_exist_results = [comb['combination'] for comb in existing_results_list] cfg['rs_exist_results'] = rs_exist_results - # report to lark bot if specify --lark if not args.lark: cfg['lark_bot_url'] = None @@ -381,7 +381,7 @@ def main(): runner(tasks) # save to station - if args.save_to_station: + if args.station_path is not None or cfg.get('station_path') is not None: Save_To_Station(cfg, args) # visualize diff --git a/opencompass/utils/result_station.py b/opencompass/utils/result_station.py index 0b14f875..0ca7afbf 100644 --- a/opencompass/utils/result_station.py +++ b/opencompass/utils/result_station.py @@ -8,12 +8,10 @@ from opencompass.utils.abbr import dataset_abbr_from_cfg, model_abbr_from_cfg def Save_To_Station(cfg, args): - assert args.station_path is not None or 'station_path' in cfg.keys( - ) and cfg['station_path'] is not None - if 'station_path' in cfg.keys() and cfg['station_path'] is not None: - station_path = cfg['station_path'] - else: + if args.station_path is not None: station_path = args.station_path + else: + station_path = cfg.get('station_path') work_dict = cfg['work_dir'] model_list = [model_abbr_from_cfg(model) for model in cfg['models']] @@ -34,18 +32,21 @@ def Save_To_Station(cfg, args): os.makedirs(result_path) for model in model_list: - if [model, dataset] in rs_exist_results: + if ([model, dataset] in rs_exist_results + and not args.station_overwrite): continue result_file_name = model + '.json' - if osp.exists(osp.join(result_path, result_file_name)): + if osp.exists( + osp.join(result_path, + result_file_name)) and not args.station_overwrite: print('result of {} with {} already exists'.format( dataset, model)) continue else: - # get result dict - local_result_path = work_dict + '/results/' + model + '/' - local_result_json = local_result_path + dataset + '.json' + local_result_path = osp.join(work_dict, 'results', model) + local_result_json = osp.join(local_result_path, + dataset + '.json') if not osp.exists(local_result_json): raise ValueError( 'invalid file: {}'.format(local_result_json)) @@ -54,8 +55,8 @@ def Save_To_Station(cfg, args): f.close() # get prediction list - local_prediction_path = (work_dict + '/predictions/' + model + - '/') + local_prediction_path = osp.join(work_dict, 'predictions', + model) local_prediction_regex = \ rf'^{re.escape(dataset)}(?:_\d+)?\.json$' local_prediction_json = find_files_by_regex( @@ -66,7 +67,7 @@ def Save_To_Station(cfg, args): this_prediction = [] for prediction_json in local_prediction_json: - with open(local_prediction_path + prediction_json, + with open(osp.join(local_prediction_path, prediction_json), 'r') as f: this_prediction_load_json = json.load(f) f.close() @@ -104,12 +105,11 @@ def Save_To_Station(cfg, args): def Read_From_Station(cfg, args): - assert args.station_path is not None or 'station_path' in cfg.keys( - ) and cfg['station_path'] is not None - if 'station_path' in cfg.keys() and cfg['station_path'] is not None: - station_path = cfg['station_path'] - else: + assert args.station_path is not None or cfg.get('station_path') is not None + if args.station_path is not None: station_path = args.station_path + else: + station_path = cfg.get('station_path') model_list = [model_abbr_from_cfg(model) for model in cfg['models']] dataset_list = [ @@ -150,6 +150,8 @@ def Read_From_Station(cfg, args): os.makedirs(this_result_local_path) this_result_local_file_path = osp.join(this_result_local_path, i['combination'][1] + '.json') + if osp.exists(this_result_local_file_path): + continue with open(this_result_local_file_path, 'w') as f: json.dump(this_result, f, ensure_ascii=False, indent=4) f.close()