mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
quick fix for postprocess pred extraction (#771)
This commit is contained in:
parent
0c75f0f95a
commit
3c606cb712
@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
task_abbr_from_cfg)
|
||||
|
||||
|
||||
def extract_role_pred(s: str, begin_str: Optional[str],
                      end_str: Optional[str]) -> str:
    """Cut the role's own prediction out of a full prediction string.

    The result is the text located after the first occurrence of
    ``begin_str`` and before the first subsequent occurrence of ``end_str``.
    A marker that is ``None``, empty, or not present in ``s`` leaves the
    corresponding side of the string untouched.

    Args:
        s (str): Full prediction string.
        begin_str (Optional[str]): The beginning string of the role.
        end_str (Optional[str]): The ending string of the role.

    Returns:
        str: The extracted role prediction.
    """
    remainder = s

    if begin_str:
        # ``partition`` yields an empty separator when the marker is absent,
        # in which case the text is kept from its very beginning.
        _, found_marker, tail = s.partition(begin_str)
        if found_marker:
            remainder = tail

    if end_str:
        # TODO: Support calling tokenizer for the accurate eos token
        # and avoid such hardcode
        # Only text after the begin marker is searched; when the end marker
        # is absent ``partition`` returns the whole remainder as the head.
        remainder = remainder.partition(end_str)[0]

    return remainder
|
||||
|
||||
|
||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||
class OpenICLEvalTask(BaseTask):
|
||||
"""OpenICL Evaluation Task.
|
||||
@ -137,13 +168,13 @@ class OpenICLEvalTask(BaseTask):
|
||||
'must be list.')
|
||||
if pred_list_flag:
|
||||
pred_strs = [[
|
||||
self._extract_role_pred(_pred, role.get('begin', None),
|
||||
extract_role_pred(_pred, role.get('begin', None),
|
||||
role.get('end', None))
|
||||
for _pred in pred
|
||||
] for pred in pred_strs]
|
||||
else:
|
||||
pred_strs = [
|
||||
self._extract_role_pred(pred, role.get('begin', None),
|
||||
extract_role_pred(pred, role.get('begin', None),
|
||||
role.get('end', None))
|
||||
for pred in pred_strs
|
||||
]
|
||||
@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask):
|
||||
mkdir_or_exist(osp.split(out_path)[0])
|
||||
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
|
||||
|
||||
def _extract_role_pred(self, s: str, begin_str: Optional[str],
|
||||
end_str: Optional[str]) -> str:
|
||||
"""Extract the role prediction from the full prediction string. The
|
||||
role prediction may be the substring between the begin and end string.
|
||||
|
||||
Args:
|
||||
s (str): Full prediction string.
|
||||
begin_str (str): The beginning string of the role
|
||||
end_str (str): The ending string of the role.
|
||||
|
||||
Returns:
|
||||
str: The extracted role prediction.
|
||||
"""
|
||||
start = 0
|
||||
end = len(s)
|
||||
|
||||
if begin_str:
|
||||
begin_idx = s.find(begin_str)
|
||||
if begin_idx != -1:
|
||||
start = begin_idx + len(begin_str)
|
||||
|
||||
if end_str:
|
||||
# TODO: Support calling tokenizer for the accurate eos token
|
||||
# and avoid such hardcode
|
||||
end_idx = s.find(end_str, start)
|
||||
if end_idx != -1:
|
||||
end = end_idx
|
||||
|
||||
return s[start:end]
|
||||
|
||||
def format_details(self, predictions, references, details, pred_dicts):
|
||||
"""This function is responsible for formatting prediction details.
|
||||
|
||||
|
@ -4,7 +4,7 @@ import fnmatch
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Union
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
|
||||
from opencompass.tasks.base import BaseTask
|
||||
from opencompass.tasks.openicl_eval import extract_role_pred
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
get_infer_output_path, get_logger,
|
||||
model_abbr_from_cfg, task_abbr_from_cfg)
|
||||
@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask):
|
||||
filename = get_infer_output_path(
|
||||
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
|
||||
root, ext = osp.splitext(filename)
|
||||
filename = root[:-2] + ext
|
||||
last_underscore_index = root.rfind('_')
|
||||
root = root[:last_underscore_index]
|
||||
filename = root + ext
|
||||
# If take SubjectNaivePartition, get filename
|
||||
else:
|
||||
filename = get_infer_output_path(
|
||||
@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask):
|
||||
parser = LMTemplateParser(model_cfg['meta_template'])
|
||||
role = parser.roles[eval_cfg['pred_role']]
|
||||
pred_strs = [
|
||||
self._extract_role_pred(pred, role.get('begin', None),
|
||||
role.get('end', None))
|
||||
for pred in pred_strs
|
||||
extract_role_pred(pred, role.get('begin', None),
|
||||
role.get('end', None)) for pred in pred_strs
|
||||
]
|
||||
|
||||
# Postprocess predictions if necessary
|
||||
@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask):
|
||||
ensure_ascii=False,
|
||||
indent=4)
|
||||
|
||||
def _extract_role_pred(self, s: str, begin_str: Optional[str],
|
||||
end_str: Optional[str]) -> str:
|
||||
"""Extract the role prediction from the full prediction string. The
|
||||
role prediction may be the substring between the begin and end string.
|
||||
|
||||
Args:
|
||||
s (str): Full prediction string.
|
||||
begin_str (str): The beginning string of the role
|
||||
end_str (str): The ending string of the role.
|
||||
|
||||
Returns:
|
||||
str: The extracted role prediction.
|
||||
"""
|
||||
start = 0
|
||||
end = len(s)
|
||||
|
||||
if begin_str:
|
||||
begin_idx = s.find(begin_str)
|
||||
if begin_idx != -1:
|
||||
start = begin_idx + len(begin_str)
|
||||
|
||||
if end_str:
|
||||
# TODO: Support calling tokenizer for the accurate eos token
|
||||
# and avoid such hardcode
|
||||
end_idx = s.find(end_str[:1], start)
|
||||
if end_idx != -1:
|
||||
end = end_idx
|
||||
|
||||
return s[start:end]
|
||||
|
||||
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
||||
"""Get the paths to the output files. Every file should exist if the
|
||||
task succeeds.
|
||||
|
Loading…
Reference in New Issue
Block a user