mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add support for set prediction path (#984)
This commit is contained in:
parent
4d2591acb2
commit
c78a4df923
@ -9,7 +9,7 @@ subjective_reader_cfg = dict(
|
||||
output_column='judge',
|
||||
)
|
||||
|
||||
data_path ="data/subjective/"
|
||||
data_path ="data/subjective/compass_arena"
|
||||
|
||||
subjective_datasets = []
|
||||
|
||||
|
@ -28,8 +28,8 @@ class BasePartitioner:
|
||||
self.out_dir = out_dir
|
||||
if keep_keys is None:
|
||||
self.keep_keys = [
|
||||
'eval.runner.task.judge_cfg',
|
||||
'eval.runner.task.dump_details',
|
||||
'eval.runner.task.judge_cfg', 'eval.runner.task.dump_details',
|
||||
'eval.given_pred'
|
||||
]
|
||||
else:
|
||||
self.keep_keys = keep_keys
|
||||
|
@ -50,6 +50,7 @@ class SubjectiveEvalTask(BaseTask):
|
||||
self.num_gpus = run_cfg.get('num_gpus', 0)
|
||||
self.num_procs = run_cfg.get('num_procs', 1)
|
||||
self.judge_cfg = copy.deepcopy(judge_cfg)
|
||||
self.given_pred = cfg.eval.get('given_pred', [])
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
"""Get the command template for the task.
|
||||
@ -89,12 +90,16 @@ class SubjectiveEvalTask(BaseTask):
|
||||
continue
|
||||
self._score(model_cfg, dataset_cfg, eval_cfg, output_column)
|
||||
|
||||
def _load_model_pred(self, model_cfg: Union[ConfigDict, List[ConfigDict]],
|
||||
dataset_cfg: ConfigDict,
|
||||
eval_cfg: ConfigDict) -> Union[None, List[str]]:
|
||||
def _load_model_pred(
|
||||
self,
|
||||
model_cfg: Union[ConfigDict, List[ConfigDict]],
|
||||
dataset_cfg: ConfigDict,
|
||||
eval_cfg: ConfigDict,
|
||||
given_preds: List[dict],
|
||||
) -> Union[None, List[str]]:
|
||||
if isinstance(model_cfg, (tuple, list)):
|
||||
return [
|
||||
self._load_model_pred(m, dataset_cfg, eval_cfg)
|
||||
self._load_model_pred(m, dataset_cfg, eval_cfg, given_preds)
|
||||
for m in model_cfg
|
||||
]
|
||||
|
||||
@ -119,7 +124,11 @@ class SubjectiveEvalTask(BaseTask):
|
||||
else:
|
||||
filename = get_infer_output_path(
|
||||
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
|
||||
|
||||
for given_pred in given_preds:
|
||||
abbr = given_pred['abbr']
|
||||
path = given_pred['path']
|
||||
if abbr == model_cfg['abbr']:
|
||||
filename = osp.join(path, osp.basename(filename))
|
||||
# Get partition name
|
||||
root, ext = osp.splitext(filename)
|
||||
partial_filename = root + '_0' + ext
|
||||
@ -209,7 +218,7 @@ class SubjectiveEvalTask(BaseTask):
|
||||
if len(new_model_cfg) == 1:
|
||||
new_model_cfg = new_model_cfg[0]
|
||||
model_preds = self._load_model_pred(new_model_cfg, dataset_cfg,
|
||||
eval_cfg)
|
||||
eval_cfg, self.given_pred)
|
||||
if not self.judge_cfg:
|
||||
raise ValueError('missing "eval.runner.task.judge_cfg"')
|
||||
eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg
|
||||
|
Loading…
Reference in New Issue
Block a user