fix and lint

This commit is contained in:
Myhs-phz 2025-03-03 10:38:44 +00:00
parent df64ae1997
commit 9e06ab535a
2 changed files with 29 additions and 27 deletions

View File

@ -129,18 +129,19 @@ def parse_args():
'correctness of each sample, bpb, etc.',
action='store_true',
)
parser.add_argument(
'--save-to-station',
help='Whether to save the evaluation results to the '
'data station.',
action='store_true',
)
parser.add_argument('-sp',
'--station-path',
help='Path to your reuslts station.',
help='Path to your results station.',
type=str,
default=None,
)
parser.add_argument('--station-overwrite',
help='Whether to overwrite the results at station.',
action='store_true',
)
parser.add_argument(
'--read-from-station',
help='Whether to save the evaluation results to the '
@ -290,7 +291,6 @@ def main():
rs_exist_results = [comb['combination'] for comb in existing_results_list]
cfg['rs_exist_results'] = rs_exist_results
# report to lark bot if specify --lark
if not args.lark:
cfg['lark_bot_url'] = None
@ -381,7 +381,7 @@ def main():
runner(tasks)
# save to station
if args.save_to_station:
if args.station_path is not None or cfg.get('station_path') is not None:
Save_To_Station(cfg, args)
# visualize

View File

@ -8,12 +8,10 @@ from opencompass.utils.abbr import dataset_abbr_from_cfg, model_abbr_from_cfg
def Save_To_Station(cfg, args):
assert args.station_path is not None or 'station_path' in cfg.keys(
) and cfg['station_path'] is not None
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
station_path = cfg['station_path']
else:
if args.station_path is not None:
station_path = args.station_path
else:
station_path = cfg.get('station_path')
work_dict = cfg['work_dir']
model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
@ -34,18 +32,21 @@ def Save_To_Station(cfg, args):
os.makedirs(result_path)
for model in model_list:
if [model, dataset] in rs_exist_results:
if ([model, dataset] in rs_exist_results
and not args.station_overwrite):
continue
result_file_name = model + '.json'
if osp.exists(osp.join(result_path, result_file_name)):
if osp.exists(
osp.join(result_path,
result_file_name)) and not args.station_overwrite:
print('result of {} with {} already exists'.format(
dataset, model))
continue
else:
# get result dict
local_result_path = work_dict + '/results/' + model + '/'
local_result_json = local_result_path + dataset + '.json'
local_result_path = osp.join(work_dict, 'results', model)
local_result_json = osp.join(local_result_path,
dataset + '.json')
if not osp.exists(local_result_json):
raise ValueError(
'invalid file: {}'.format(local_result_json))
@ -54,8 +55,8 @@ def Save_To_Station(cfg, args):
f.close()
# get prediction list
local_prediction_path = (work_dict + '/predictions/' + model +
'/')
local_prediction_path = osp.join(work_dict, 'predictions',
model)
local_prediction_regex = \
rf'^{re.escape(dataset)}(?:_\d+)?\.json$'
local_prediction_json = find_files_by_regex(
@ -66,7 +67,7 @@ def Save_To_Station(cfg, args):
this_prediction = []
for prediction_json in local_prediction_json:
with open(local_prediction_path + prediction_json,
with open(osp.join(local_prediction_path, prediction_json),
'r') as f:
this_prediction_load_json = json.load(f)
f.close()
@ -104,12 +105,11 @@ def Save_To_Station(cfg, args):
def Read_From_Station(cfg, args):
assert args.station_path is not None or 'station_path' in cfg.keys(
) and cfg['station_path'] is not None
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
station_path = cfg['station_path']
else:
assert args.station_path is not None or cfg.get('station_path') is not None
if args.station_path is not None:
station_path = args.station_path
else:
station_path = cfg.get('station_path')
model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
dataset_list = [
@ -150,6 +150,8 @@ def Read_From_Station(cfg, args):
os.makedirs(this_result_local_path)
this_result_local_file_path = osp.join(this_result_local_path,
i['combination'][1] + '.json')
if osp.exists(this_result_local_file_path):
continue
with open(this_result_local_file_path, 'w') as f:
json.dump(this_result, f, ensure_ascii=False, indent=4)
f.close()