add support for set prediction path (#984)

This commit is contained in:
bittersweet1999 2024-03-19 14:32:15 +08:00 committed by GitHub
parent 4d2591acb2
commit c78a4df923
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 9 deletions

View File

@ -9,7 +9,7 @@ subjective_reader_cfg = dict(
output_column='judge',
)
data_path ="data/subjective/"
data_path ="data/subjective/compass_arena"
subjective_datasets = []

View File

@ -28,8 +28,8 @@ class BasePartitioner:
self.out_dir = out_dir
if keep_keys is None:
self.keep_keys = [
'eval.runner.task.judge_cfg',
'eval.runner.task.dump_details',
'eval.runner.task.judge_cfg', 'eval.runner.task.dump_details',
'eval.given_pred'
]
else:
self.keep_keys = keep_keys

View File

@ -50,6 +50,7 @@ class SubjectiveEvalTask(BaseTask):
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
self.judge_cfg = copy.deepcopy(judge_cfg)
self.given_pred = cfg.eval.get('given_pred', [])
def get_command(self, cfg_path, template):
"""Get the command template for the task.
@ -89,12 +90,16 @@ class SubjectiveEvalTask(BaseTask):
continue
self._score(model_cfg, dataset_cfg, eval_cfg, output_column)
def _load_model_pred(self, model_cfg: Union[ConfigDict, List[ConfigDict]],
dataset_cfg: ConfigDict,
eval_cfg: ConfigDict) -> Union[None, List[str]]:
def _load_model_pred(
self,
model_cfg: Union[ConfigDict, List[ConfigDict]],
dataset_cfg: ConfigDict,
eval_cfg: ConfigDict,
given_preds: List[dict],
) -> Union[None, List[str]]:
if isinstance(model_cfg, (tuple, list)):
return [
self._load_model_pred(m, dataset_cfg, eval_cfg)
self._load_model_pred(m, dataset_cfg, eval_cfg, given_preds)
for m in model_cfg
]
@ -119,7 +124,11 @@ class SubjectiveEvalTask(BaseTask):
else:
filename = get_infer_output_path(
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
for given_pred in given_preds:
abbr = given_pred['abbr']
path = given_pred['path']
if abbr == model_cfg['abbr']:
filename = osp.join(path, osp.basename(filename))
# Get partition name
root, ext = osp.splitext(filename)
partial_filename = root + '_0' + ext
@ -209,7 +218,7 @@ class SubjectiveEvalTask(BaseTask):
if len(new_model_cfg) == 1:
new_model_cfg = new_model_cfg[0]
model_preds = self._load_model_pred(new_model_cfg, dataset_cfg,
eval_cfg)
eval_cfg, self.given_pred)
if not self.judge_cfg:
raise ValueError('missing "eval.runner.task.judge_cfg"')
eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg