mirror of https://github.com/open-compass/opencompass.git

feat save_to_station

parent 114cf1366c
commit 2aaab41dc9
@@ -12,7 +12,7 @@ from mmengine.config import Config, DictAction
 from opencompass.registry import PARTITIONERS, RUNNERS, build_from_cfg
 from opencompass.runners import SlurmRunner
 from opencompass.summarizers import DefaultSummarizer
-from opencompass.utils import LarkReporter, get_logger
+from opencompass.utils import LarkReporter, get_logger, Save_To_Station
 from opencompass.utils.run import (fill_eval_cfg, fill_infer_cfg,
                                    get_config_from_arg)
 
@@ -127,6 +127,26 @@ def parse_args():
         'correctness of each sample, bpb, etc.',
         action='store_true',
     )
+    parser.add_argument(
+        '--save-to-station',
+        help='Whether to save the evaluation results to the '
+        'data station.',
+        action='store_true',
+    )
+    parser.add_argument(
+        '--read-station',
+        help='Whether to read the evaluation results from the '
+        'data station.',
+        action='store_true',
+    )
+    parser.add_argument(
+        '--station-path',
+        help='Path to your results station.',
+        type=str,
+        default=None,
+    )
+
+
     # set srun args
     slurm_parser = parser.add_argument_group('slurm_args')
     parse_slurm_args(slurm_parser)
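The three options are plain argparse flags; argparse maps the dashed names to
underscored attributes on the parsed namespace. A minimal, self-contained
sketch of the intended usage (the station path shown is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save-to-station', action='store_true')
parser.add_argument('--read-station', action='store_true')
parser.add_argument('--station-path', type=str, default=None)

args = parser.parse_args(['--save-to-station', '--station-path', '/data/station'])
assert args.save_to_station and not args.read_station
assert args.station_path == '/data/station'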
@@ -269,6 +289,7 @@ def main():
         content = f'{getpass.getuser()}\'s task has been launched!'
         LarkReporter(cfg['lark_bot_url']).post(content)
 
+    # infer
     if args.mode in ['all', 'infer']:
         # When user have specified --slurm or --dlc, or have not set
         # "infer" in config, we will provide a default configuration
@@ -350,6 +371,15 @@ def main():
         else:
             runner(tasks)
 
+    # save to station
+    if args.save_to_station:
+        if Save_To_Station(cfg, args):
+            logger.info('Successfully saved to station.')
+        else:
+            logger.warning('Failed to save result to station.')
+
+
+
     # visualize
     if args.mode in ['all', 'eval', 'viz']:
         summarizer_cfg = cfg.get('summarizer', {})
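The block above only fires when --save-to-station is passed, and the call site
treats a falsy return from Save_To_Station (added below in result_station.py)
as a soft failure: it logs a warning instead of aborting the run. Per
(dataset, model) pair the function archives one JSON record; a sketch of its
shape, with illustrative field contents:

# record written to <station_path>/<dataset_abbr>/<model_abbr>.json
record = {
    'predictions': [
        # sample-level entries merged from predictions/<model>/<dataset>*.json
    ],
    'results': {
        # metric dict loaded from results/<model>/<dataset>.json,
        # e.g. {'accuracy': 85.2} (illustrative)
    },
}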
@@ -15,3 +15,4 @@ from .network import * # noqa
 from .postprocessors import * # noqa
 from .prompt import * # noqa
 from .text_postprocessors import * # noqa
+from .result_station import * # noqa
opencompass/utils/result_station.py  (new file, 112 lines)
@@ -0,0 +1,112 @@
+import json
+import os
+import os.path as osp
+import re
+
+
+def Save_To_Station(cfg, args):
+    """Archive results and predictions for every (dataset, model) pair
+    under <station_path>/<dataset_abbr>/<model_abbr>.json."""
+    from dotenv import load_dotenv
+    load_dotenv()
+
+    # RESULTS_STATION_PATH takes precedence over --station-path;
+    # at least one of the two must be provided
+    station_path = os.getenv('RESULTS_STATION_PATH')
+    assert station_path is not None or args.station_path is not None
+    station_path = args.station_path if station_path is None else station_path
+
+    work_dict = cfg['work_dir']
+    model_list = [i['abbr'] for i in cfg['models']]
+    dataset_list = [i['abbr'] for i in cfg['datasets']]
+
+    for dataset in dataset_list:
+        result_path = osp.join(station_path, dataset)
+        if not osp.exists(result_path):
+            os.makedirs(result_path)
+
+        for model in model_list:
+            result_file_name = model + '.json'
+            if osp.exists(osp.join(result_path, result_file_name)):
+                print('result of {} with {} already exists'.format(
+                    dataset, model))
+                continue
+            else:
+                # get result dict
+                local_result_path = work_dict + '/results/' + model + '/'
+                local_result_json = local_result_path + dataset + '.json'
+                if not osp.exists(local_result_json):
+                    raise ValueError(
+                        'invalid file: {}'.format(local_result_json))
+                with open(local_result_json, 'r') as f:
+                    this_result = json.load(f)
+
+                # get prediction list, allowing sharded files
+                # <dataset>_0.json, <dataset>_1.json, ...
+                local_prediction_path = work_dict + '/predictions/' + model + '/'
+                local_prediction_regex = rf'^{re.escape(dataset)}(?:_\d+)?\.json$'
+                local_prediction_json = find_files_by_regex(
+                    local_prediction_path, local_prediction_regex)
+                if not check_filenames(dataset, local_prediction_json):
+                    raise ValueError(
+                        'invalid filelist: {}'.format(local_prediction_json))
+
+                this_prediction = []
+                for prediction_json in local_prediction_json:
+                    with open(local_prediction_path + prediction_json,
+                              'r') as f:
+                        this_prediction_load_json = json.load(f)
+                    for prekey in this_prediction_load_json.keys():
+                        this_prediction.append(
+                            this_prediction_load_json[prekey])
+
+                # combine predictions and results into a single record
+                data_model_results = {
+                    'predictions': this_prediction,
+                    'results': this_result
+                }
+                with open(osp.join(result_path, result_file_name), 'w') as f:
+                    json.dump(data_model_results, f,
+                              ensure_ascii=False, indent=4)
+
+    return True
+
+
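Note the resolution order: when both are set, the RESULTS_STATION_PATH
environment variable wins over --station-path, and the assert fails when
neither is provided. A hypothetical driver (work dir, model and dataset
abbreviations are invented for illustration):

import argparse

cfg = {
    'work_dir': 'outputs/default/20250530_120000',  # hypothetical work dir
    'models': [{'abbr': 'llama-7b'}],               # hypothetical model abbr
    'datasets': [{'abbr': 'gsm8k'}],                # hypothetical dataset abbr
}
args = argparse.Namespace(station_path='/data/station')

# With RESULTS_STATION_PATH unset, the fallback is args.station_path, so a
# successful call writes /data/station/gsm8k/llama-7b.json and returns True:
# Save_To_Station(cfg, args)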
+def find_files_by_regex(directory, pattern):
+    """Return all filenames in `directory` that match `pattern`."""
+    regex = re.compile(pattern)
+
+    matched_files = []
+    for filename in os.listdir(directory):
+        if regex.match(filename):
+            matched_files.append(filename)
+
+    return matched_files
+
+
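The regex built in Save_To_Station matches either <dataset>.json or sharded
<dataset>_<n>.json files. A quick self-contained check of the same matching
logic against a temporary directory (filenames hypothetical):

import os
import re
import tempfile

tmp = tempfile.mkdtemp()
for name in ['gsm8k.csv', 'gsm8k_0.json', 'gsm8k_1.json', 'mmlu.json']:
    open(os.path.join(tmp, name), 'w').close()

pattern = rf'^{re.escape("gsm8k")}(?:_\d+)?\.json$'
# same filtering as find_files_by_regex: only the two gsm8k shards match
matched = sorted(n for n in os.listdir(tmp) if re.match(pattern, n))
assert matched == ['gsm8k_0.json', 'gsm8k_1.json']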
+def check_filenames(x, filenames):
+    """Validate that `filenames` is either exactly ['<x>.json'] or a
+    gapless, zero-based shard list ['<x>_0.json', ..., '<x>_{n-1}.json']."""
+    if not filenames:
+        return False
+
+    single_pattern = re.compile(rf'^{re.escape(x)}\.json$')
+    numbered_pattern = re.compile(rf'^{re.escape(x)}_(\d+)\.json$')
+
+    is_single = all(single_pattern.match(name) for name in filenames)
+    is_numbered = all(numbered_pattern.match(name) for name in filenames)
+
+    if not (is_single or is_numbered):
+        return False
+
+    if is_single:
+        return len(filenames) == 1
+
+    if is_numbered:
+        numbers = []
+        for name in filenames:
+            match = numbered_pattern.match(name)
+            if match:
+                numbers.append(int(match.group(1)))
+
+        # shard indices must be exactly 0..n-1, no gaps or duplicates
+        if sorted(numbers) != list(range(len(numbers))):
+            return False
+
+    return True
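Given this definition, shard lists validate as follows (a sketch assuming the
new module is importable; the dataset name is hypothetical):

from opencompass.utils.result_station import check_filenames

assert check_filenames('gsm8k', ['gsm8k.json'])                    # one plain file
assert check_filenames('gsm8k', ['gsm8k_1.json', 'gsm8k_0.json'])  # shards 0..n-1, any order
assert not check_filenames('gsm8k', ['gsm8k_0.json', 'gsm8k_2.json'])  # gap: shard 1 missing
assert not check_filenames('gsm8k', ['gsm8k.json', 'gsm8k_0.json'])    # mixed naming styles
assert not check_filenames('gsm8k', [])                            # nothing found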