mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
quick fix for postprocess pred extraction (#771)
This commit is contained in:
parent
0c75f0f95a
commit
3c606cb712
@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
|||||||
task_abbr_from_cfg)
|
task_abbr_from_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_role_pred(s: str, begin_str: Optional[str],
|
||||||
|
end_str: Optional[str]) -> str:
|
||||||
|
"""Extract the role prediction from the full prediction string. The role
|
||||||
|
prediction may be the substring between the begin and end string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (str): Full prediction string.
|
||||||
|
begin_str (str): The beginning string of the role
|
||||||
|
end_str (str): The ending string of the role.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The extracted role prediction.
|
||||||
|
"""
|
||||||
|
start = 0
|
||||||
|
end = len(s)
|
||||||
|
|
||||||
|
if begin_str:
|
||||||
|
begin_idx = s.find(begin_str)
|
||||||
|
if begin_idx != -1:
|
||||||
|
start = begin_idx + len(begin_str)
|
||||||
|
|
||||||
|
if end_str:
|
||||||
|
# TODO: Support calling tokenizer for the accurate eos token
|
||||||
|
# and avoid such hardcode
|
||||||
|
end_idx = s.find(end_str, start)
|
||||||
|
if end_idx != -1:
|
||||||
|
end = end_idx
|
||||||
|
|
||||||
|
return s[start:end]
|
||||||
|
|
||||||
|
|
||||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||||
class OpenICLEvalTask(BaseTask):
|
class OpenICLEvalTask(BaseTask):
|
||||||
"""OpenICL Evaluation Task.
|
"""OpenICL Evaluation Task.
|
||||||
@ -137,13 +168,13 @@ class OpenICLEvalTask(BaseTask):
|
|||||||
'must be list.')
|
'must be list.')
|
||||||
if pred_list_flag:
|
if pred_list_flag:
|
||||||
pred_strs = [[
|
pred_strs = [[
|
||||||
self._extract_role_pred(_pred, role.get('begin', None),
|
extract_role_pred(_pred, role.get('begin', None),
|
||||||
role.get('end', None))
|
role.get('end', None))
|
||||||
for _pred in pred
|
for _pred in pred
|
||||||
] for pred in pred_strs]
|
] for pred in pred_strs]
|
||||||
else:
|
else:
|
||||||
pred_strs = [
|
pred_strs = [
|
||||||
self._extract_role_pred(pred, role.get('begin', None),
|
extract_role_pred(pred, role.get('begin', None),
|
||||||
role.get('end', None))
|
role.get('end', None))
|
||||||
for pred in pred_strs
|
for pred in pred_strs
|
||||||
]
|
]
|
||||||
@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask):
|
|||||||
mkdir_or_exist(osp.split(out_path)[0])
|
mkdir_or_exist(osp.split(out_path)[0])
|
||||||
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
|
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
def _extract_role_pred(self, s: str, begin_str: Optional[str],
|
|
||||||
end_str: Optional[str]) -> str:
|
|
||||||
"""Extract the role prediction from the full prediction string. The
|
|
||||||
role prediction may be the substring between the begin and end string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
s (str): Full prediction string.
|
|
||||||
begin_str (str): The beginning string of the role
|
|
||||||
end_str (str): The ending string of the role.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The extracted role prediction.
|
|
||||||
"""
|
|
||||||
start = 0
|
|
||||||
end = len(s)
|
|
||||||
|
|
||||||
if begin_str:
|
|
||||||
begin_idx = s.find(begin_str)
|
|
||||||
if begin_idx != -1:
|
|
||||||
start = begin_idx + len(begin_str)
|
|
||||||
|
|
||||||
if end_str:
|
|
||||||
# TODO: Support calling tokenizer for the accurate eos token
|
|
||||||
# and avoid such hardcode
|
|
||||||
end_idx = s.find(end_str, start)
|
|
||||||
if end_idx != -1:
|
|
||||||
end = end_idx
|
|
||||||
|
|
||||||
return s[start:end]
|
|
||||||
|
|
||||||
def format_details(self, predictions, references, details, pred_dicts):
|
def format_details(self, predictions, references, details, pred_dicts):
|
||||||
"""This function is responsible for formatting prediction details.
|
"""This function is responsible for formatting prediction details.
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import fnmatch
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
from mmengine.config import Config, ConfigDict
|
from mmengine.config import Config, ConfigDict
|
||||||
@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist
|
|||||||
|
|
||||||
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
|
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
|
||||||
from opencompass.tasks.base import BaseTask
|
from opencompass.tasks.base import BaseTask
|
||||||
|
from opencompass.tasks.openicl_eval import extract_role_pred
|
||||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||||
get_infer_output_path, get_logger,
|
get_infer_output_path, get_logger,
|
||||||
model_abbr_from_cfg, task_abbr_from_cfg)
|
model_abbr_from_cfg, task_abbr_from_cfg)
|
||||||
@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask):
|
|||||||
filename = get_infer_output_path(
|
filename = get_infer_output_path(
|
||||||
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
|
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
|
||||||
root, ext = osp.splitext(filename)
|
root, ext = osp.splitext(filename)
|
||||||
filename = root[:-2] + ext
|
last_underscore_index = root.rfind('_')
|
||||||
|
root = root[:last_underscore_index]
|
||||||
|
filename = root + ext
|
||||||
# If take SubjectNaivePartition, get filename
|
# If take SubjectNaivePartition, get filename
|
||||||
else:
|
else:
|
||||||
filename = get_infer_output_path(
|
filename = get_infer_output_path(
|
||||||
@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask):
|
|||||||
parser = LMTemplateParser(model_cfg['meta_template'])
|
parser = LMTemplateParser(model_cfg['meta_template'])
|
||||||
role = parser.roles[eval_cfg['pred_role']]
|
role = parser.roles[eval_cfg['pred_role']]
|
||||||
pred_strs = [
|
pred_strs = [
|
||||||
self._extract_role_pred(pred, role.get('begin', None),
|
extract_role_pred(pred, role.get('begin', None),
|
||||||
role.get('end', None))
|
role.get('end', None)) for pred in pred_strs
|
||||||
for pred in pred_strs
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Postprocess predictions if necessary
|
# Postprocess predictions if necessary
|
||||||
@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask):
|
|||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
indent=4)
|
indent=4)
|
||||||
|
|
||||||
def _extract_role_pred(self, s: str, begin_str: Optional[str],
|
|
||||||
end_str: Optional[str]) -> str:
|
|
||||||
"""Extract the role prediction from the full prediction string. The
|
|
||||||
role prediction may be the substring between the begin and end string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
s (str): Full prediction string.
|
|
||||||
begin_str (str): The beginning string of the role
|
|
||||||
end_str (str): The ending string of the role.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The extracted role prediction.
|
|
||||||
"""
|
|
||||||
start = 0
|
|
||||||
end = len(s)
|
|
||||||
|
|
||||||
if begin_str:
|
|
||||||
begin_idx = s.find(begin_str)
|
|
||||||
if begin_idx != -1:
|
|
||||||
start = begin_idx + len(begin_str)
|
|
||||||
|
|
||||||
if end_str:
|
|
||||||
# TODO: Support calling tokenizer for the accurate eos token
|
|
||||||
# and avoid such hardcode
|
|
||||||
end_idx = s.find(end_str[:1], start)
|
|
||||||
if end_idx != -1:
|
|
||||||
end = end_idx
|
|
||||||
|
|
||||||
return s[start:end]
|
|
||||||
|
|
||||||
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
||||||
"""Get the paths to the output files. Every file should exist if the
|
"""Get the paths to the output files. Every file should exist if the
|
||||||
task succeeds.
|
task succeeds.
|
||||||
|
Loading…
Reference in New Issue
Block a user