mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix and lint
This commit is contained in:
parent
df64ae1997
commit
9e06ab535a
@ -129,18 +129,19 @@ def parse_args():
|
||||
'correctness of each sample, bpb, etc.',
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--save-to-station',
|
||||
help='Whether to save the evaluation results to the '
|
||||
'data station.',
|
||||
action='store_true',
|
||||
)
|
||||
|
||||
parser.add_argument('-sp',
|
||||
'--station-path',
|
||||
help='Path to your reuslts station.',
|
||||
help='Path to your results station.',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument('--station-overwrite',
|
||||
help='Whether to overwrite the results at station.',
|
||||
action='store_true',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--read-from-station',
|
||||
help='Whether to save the evaluation results to the '
|
||||
@ -290,7 +291,6 @@ def main():
|
||||
rs_exist_results = [comb['combination'] for comb in existing_results_list]
|
||||
cfg['rs_exist_results'] = rs_exist_results
|
||||
|
||||
|
||||
# report to lark bot if specify --lark
|
||||
if not args.lark:
|
||||
cfg['lark_bot_url'] = None
|
||||
@ -381,7 +381,7 @@ def main():
|
||||
runner(tasks)
|
||||
|
||||
# save to station
|
||||
if args.save_to_station:
|
||||
if args.station_path is not None or cfg.get('station_path') is not None:
|
||||
Save_To_Station(cfg, args)
|
||||
|
||||
# visualize
|
||||
|
@ -8,12 +8,10 @@ from opencompass.utils.abbr import dataset_abbr_from_cfg, model_abbr_from_cfg
|
||||
|
||||
def Save_To_Station(cfg, args):
|
||||
|
||||
assert args.station_path is not None or 'station_path' in cfg.keys(
|
||||
) and cfg['station_path'] is not None
|
||||
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
|
||||
station_path = cfg['station_path']
|
||||
else:
|
||||
if args.station_path is not None:
|
||||
station_path = args.station_path
|
||||
else:
|
||||
station_path = cfg.get('station_path')
|
||||
|
||||
work_dict = cfg['work_dir']
|
||||
model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
|
||||
@ -34,18 +32,21 @@ def Save_To_Station(cfg, args):
|
||||
os.makedirs(result_path)
|
||||
|
||||
for model in model_list:
|
||||
if [model, dataset] in rs_exist_results:
|
||||
if ([model, dataset] in rs_exist_results
|
||||
and not args.station_overwrite):
|
||||
continue
|
||||
result_file_name = model + '.json'
|
||||
if osp.exists(osp.join(result_path, result_file_name)):
|
||||
if osp.exists(
|
||||
osp.join(result_path,
|
||||
result_file_name)) and not args.station_overwrite:
|
||||
print('result of {} with {} already exists'.format(
|
||||
dataset, model))
|
||||
continue
|
||||
else:
|
||||
|
||||
# get result dict
|
||||
local_result_path = work_dict + '/results/' + model + '/'
|
||||
local_result_json = local_result_path + dataset + '.json'
|
||||
local_result_path = osp.join(work_dict, 'results', model)
|
||||
local_result_json = osp.join(local_result_path,
|
||||
dataset + '.json')
|
||||
if not osp.exists(local_result_json):
|
||||
raise ValueError(
|
||||
'invalid file: {}'.format(local_result_json))
|
||||
@ -54,8 +55,8 @@ def Save_To_Station(cfg, args):
|
||||
f.close()
|
||||
|
||||
# get prediction list
|
||||
local_prediction_path = (work_dict + '/predictions/' + model +
|
||||
'/')
|
||||
local_prediction_path = osp.join(work_dict, 'predictions',
|
||||
model)
|
||||
local_prediction_regex = \
|
||||
rf'^{re.escape(dataset)}(?:_\d+)?\.json$'
|
||||
local_prediction_json = find_files_by_regex(
|
||||
@ -66,7 +67,7 @@ def Save_To_Station(cfg, args):
|
||||
|
||||
this_prediction = []
|
||||
for prediction_json in local_prediction_json:
|
||||
with open(local_prediction_path + prediction_json,
|
||||
with open(osp.join(local_prediction_path, prediction_json),
|
||||
'r') as f:
|
||||
this_prediction_load_json = json.load(f)
|
||||
f.close()
|
||||
@ -104,12 +105,11 @@ def Save_To_Station(cfg, args):
|
||||
|
||||
def Read_From_Station(cfg, args):
|
||||
|
||||
assert args.station_path is not None or 'station_path' in cfg.keys(
|
||||
) and cfg['station_path'] is not None
|
||||
if 'station_path' in cfg.keys() and cfg['station_path'] is not None:
|
||||
station_path = cfg['station_path']
|
||||
else:
|
||||
assert args.station_path is not None or cfg.get('station_path') is not None
|
||||
if args.station_path is not None:
|
||||
station_path = args.station_path
|
||||
else:
|
||||
station_path = cfg.get('station_path')
|
||||
|
||||
model_list = [model_abbr_from_cfg(model) for model in cfg['models']]
|
||||
dataset_list = [
|
||||
@ -150,6 +150,8 @@ def Read_From_Station(cfg, args):
|
||||
os.makedirs(this_result_local_path)
|
||||
this_result_local_file_path = osp.join(this_result_local_path,
|
||||
i['combination'][1] + '.json')
|
||||
if osp.exists(this_result_local_file_path):
|
||||
continue
|
||||
with open(this_result_local_file_path, 'w') as f:
|
||||
json.dump(this_result, f, ensure_ascii=False, indent=4)
|
||||
f.close()
|
||||
|
Loading…
Reference in New Issue
Block a user