fix and lint

This commit is contained in:
Myhs-phz 2025-03-03 10:38:44 +00:00
parent df64ae1997
commit 9e06ab535a
2 changed files with 29 additions and 27 deletions

View File

@ -129,18 +129,19 @@ def parse_args():
'correctness of each sample, bpb, etc.', 'correctness of each sample, bpb, etc.',
action='store_true', action='store_true',
) )
parser.add_argument(
'--save-to-station',
help='Whether to save the evaluation results to the '
'data station.',
action='store_true',
)
parser.add_argument('-sp', parser.add_argument('-sp',
'--station-path', '--station-path',
help='Path to your reuslts station.', help='Path to your results station.',
type=str, type=str,
default=None, default=None,
) )
parser.add_argument('--station-overwrite',
help='Whether to overwrite the results at station.',
action='store_true',
)
parser.add_argument( parser.add_argument(
'--read-from-station', '--read-from-station',
help='Whether to save the evaluation results to the ' help='Whether to save the evaluation results to the '
@ -290,7 +291,6 @@ def main():
rs_exist_results = [comb['combination'] for comb in existing_results_list] rs_exist_results = [comb['combination'] for comb in existing_results_list]
cfg['rs_exist_results'] = rs_exist_results cfg['rs_exist_results'] = rs_exist_results
# report to lark bot if specify --lark # report to lark bot if specify --lark
if not args.lark: if not args.lark:
cfg['lark_bot_url'] = None cfg['lark_bot_url'] = None
@ -381,7 +381,7 @@ def main():
runner(tasks) runner(tasks)
# save to station # save to station
if args.save_to_station: if args.station_path is not None or cfg.get('station_path') is not None:
Save_To_Station(cfg, args) Save_To_Station(cfg, args)
# visualize # visualize

View File

@ -8,12 +8,10 @@ from opencompass.utils.abbr import dataset_abbr_from_cfg, model_abbr_from_cfg
def Save_To_Station(cfg, args): def Save_To_Station(cfg, args):
assert args.station_path is not None or 'station_path' in cfg.keys( if args.station_path is not None:
) and cfg['station_path'] is not None
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
station_path = cfg['station_path']
else:
station_path = args.station_path station_path = args.station_path
else:
station_path = cfg.get('station_path')
work_dict = cfg['work_dir'] work_dict = cfg['work_dir']
model_list = [model_abbr_from_cfg(model) for model in cfg['models']] model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
@ -34,18 +32,21 @@ def Save_To_Station(cfg, args):
os.makedirs(result_path) os.makedirs(result_path)
for model in model_list: for model in model_list:
if [model, dataset] in rs_exist_results: if ([model, dataset] in rs_exist_results
and not args.station_overwrite):
continue continue
result_file_name = model + '.json' result_file_name = model + '.json'
if osp.exists(osp.join(result_path, result_file_name)): if osp.exists(
osp.join(result_path,
result_file_name)) and not args.station_overwrite:
print('result of {} with {} already exists'.format( print('result of {} with {} already exists'.format(
dataset, model)) dataset, model))
continue continue
else: else:
# get result dict # get result dict
local_result_path = work_dict + '/results/' + model + '/' local_result_path = osp.join(work_dict, 'results', model)
local_result_json = local_result_path + dataset + '.json' local_result_json = osp.join(local_result_path,
dataset + '.json')
if not osp.exists(local_result_json): if not osp.exists(local_result_json):
raise ValueError( raise ValueError(
'invalid file: {}'.format(local_result_json)) 'invalid file: {}'.format(local_result_json))
@ -54,8 +55,8 @@ def Save_To_Station(cfg, args):
f.close() f.close()
# get prediction list # get prediction list
local_prediction_path = (work_dict + '/predictions/' + model + local_prediction_path = osp.join(work_dict, 'predictions',
'/') model)
local_prediction_regex = \ local_prediction_regex = \
rf'^{re.escape(dataset)}(?:_\d+)?\.json$' rf'^{re.escape(dataset)}(?:_\d+)?\.json$'
local_prediction_json = find_files_by_regex( local_prediction_json = find_files_by_regex(
@ -66,7 +67,7 @@ def Save_To_Station(cfg, args):
this_prediction = [] this_prediction = []
for prediction_json in local_prediction_json: for prediction_json in local_prediction_json:
with open(local_prediction_path + prediction_json, with open(osp.join(local_prediction_path, prediction_json),
'r') as f: 'r') as f:
this_prediction_load_json = json.load(f) this_prediction_load_json = json.load(f)
f.close() f.close()
@ -104,12 +105,11 @@ def Save_To_Station(cfg, args):
def Read_From_Station(cfg, args): def Read_From_Station(cfg, args):
assert args.station_path is not None or 'station_path' in cfg.keys( assert args.station_path is not None or cfg.get('station_path') is not None
) and cfg['station_path'] is not None if args.station_path is not None:
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
station_path = cfg['station_path']
else:
station_path = args.station_path station_path = args.station_path
else:
station_path = cfg.get('station_path')
model_list = [model_abbr_from_cfg(model) for model in cfg['models']] model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
dataset_list = [ dataset_list = [
@ -150,6 +150,8 @@ def Read_From_Station(cfg, args):
os.makedirs(this_result_local_path) os.makedirs(this_result_local_path)
this_result_local_file_path = osp.join(this_result_local_path, this_result_local_file_path = osp.join(this_result_local_path,
i['combination'][1] + '.json') i['combination'][1] + '.json')
if osp.exists(this_result_local_file_path):
continue
with open(this_result_local_file_path, 'w') as f: with open(this_result_local_file_path, 'w') as f:
json.dump(this_result, f, ensure_ascii=False, indent=4) json.dump(this_result, f, ensure_ascii=False, indent=4)
f.close() f.close()