[Fix] Fix cli evaluation for multiple models (#1454)

* update

* update
This commit is contained in:
Linchen Xiao 2024-08-23 17:15:36 +08:00 committed by GitHub
parent 2295a33a18
commit 94b6bd65fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,16 +41,18 @@ def match_cfg_file(workdir: Union[str, List[str]],
if len(files) != len(pattern): if len(files) != len(pattern):
nomatched = [] nomatched = []
ambiguous = [] ambiguous = []
ambiguous_return_list = []
err_msg = ('The provided pattern matches 0 or more than one ' err_msg = ('The provided pattern matches 0 or more than one '
'config. Please verify your pattern and try again. ' 'config. Please verify your pattern and try again. '
'You may use tools/list_configs.py to list or ' 'You may use tools/list_configs.py to list or '
'locate the configurations.\n') 'locate the configurations.\n')
for p in pattern: for p in pattern:
files = _mf_with_multi_workdirs(workdir, p, fuzzy=False) files_ = _mf_with_multi_workdirs(workdir, p, fuzzy=False)
if len(files) == 0: if len(files_) == 0:
nomatched.append([p[:-3]]) nomatched.append([p[:-3]])
elif len(files) > 1: elif len(files_) > 1:
ambiguous.append([p[:-3], '\n'.join(f[1] for f in files)]) ambiguous.append([p[:-3], '\n'.join(f[1] for f in files_)])
ambiguous_return_list.append(files_[0])
if nomatched: if nomatched:
table = [['Not matched patterns'], *nomatched] table = [['Not matched patterns'], *nomatched]
err_msg += tabulate.tabulate(table, err_msg += tabulate.tabulate(table,
@ -58,12 +60,12 @@ def match_cfg_file(workdir: Union[str, List[str]],
tablefmt='psql') tablefmt='psql')
if ambiguous: if ambiguous:
table = [['Ambiguous patterns', 'Matched files'], *ambiguous] table = [['Ambiguous patterns', 'Matched files'], *ambiguous]
warning_msg = 'Found ambiguous patterns, using the first matched config.' warning_msg = 'Found ambiguous patterns, using the first matched config.\n'
warning_msg += tabulate.tabulate(table, warning_msg += tabulate.tabulate(table,
headers='firstrow', headers='firstrow',
tablefmt='psql') tablefmt='psql')
logger.warning(warning_msg) logger.warning(warning_msg)
return [files[0]] return ambiguous_return_list
raise ValueError(err_msg) raise ValueError(err_msg)
return files return files
@ -162,8 +164,8 @@ def get_config_from_arg(args) -> Config:
] ]
if args.models: if args.models:
# model_dir = os.path.join(args.config_dir, 'models') for model_arg in args.models:
for model in match_cfg_file(models_dir, args.models): for model in match_cfg_file(models_dir, [model_arg]):
logger.info(f'Loading {model[0]}: {model[1]}') logger.info(f'Loading {model[0]}: {model[1]}')
cfg = Config.fromfile(model[1]) cfg = Config.fromfile(model[1])
if 'models' not in cfg: if 'models' not in cfg: