quick fix for postprocess pred extraction (#771)

This commit is contained in:
bittersweet1999 2024-01-05 21:10:18 +08:00 committed by GitHub
parent 0c75f0f95a
commit 3c606cb712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 69 deletions

View File

@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
task_abbr_from_cfg) task_abbr_from_cfg)
def extract_role_pred(s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The role
prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str, start)
if end_idx != -1:
end = end_idx
return s[start:end]
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
class OpenICLEvalTask(BaseTask): class OpenICLEvalTask(BaseTask):
"""OpenICL Evaluation Task. """OpenICL Evaluation Task.
@ -137,13 +168,13 @@ class OpenICLEvalTask(BaseTask):
'must be list.') 'must be list.')
if pred_list_flag: if pred_list_flag:
pred_strs = [[ pred_strs = [[
self._extract_role_pred(_pred, role.get('begin', None), extract_role_pred(_pred, role.get('begin', None),
role.get('end', None)) role.get('end', None))
for _pred in pred for _pred in pred
] for pred in pred_strs] ] for pred in pred_strs]
else: else:
pred_strs = [ pred_strs = [
self._extract_role_pred(pred, role.get('begin', None), extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) role.get('end', None))
for pred in pred_strs for pred in pred_strs
] ]
@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask):
mkdir_or_exist(osp.split(out_path)[0]) mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path, ensure_ascii=False, indent=4) mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str, start)
if end_idx != -1:
end = end_idx
return s[start:end]
def format_details(self, predictions, references, details, pred_dicts): def format_details(self, predictions, references, details, pred_dicts):
"""This function is responsible for formatting prediction details. """This function is responsible for formatting prediction details.

View File

@ -4,7 +4,7 @@ import fnmatch
import os.path as osp import os.path as osp
import random import random
import time import time
from typing import List, Optional, Union from typing import List, Union
import mmengine import mmengine
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
from opencompass.tasks.base import BaseTask from opencompass.tasks.base import BaseTask
from opencompass.tasks.openicl_eval import extract_role_pred
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path, get_logger, get_infer_output_path, get_logger,
model_abbr_from_cfg, task_abbr_from_cfg) model_abbr_from_cfg, task_abbr_from_cfg)
@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask):
filename = get_infer_output_path( filename = get_infer_output_path(
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions')) model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
root, ext = osp.splitext(filename) root, ext = osp.splitext(filename)
filename = root[:-2] + ext last_underscore_index = root.rfind('_')
root = root[:last_underscore_index]
filename = root + ext
# If take SubjectNaivePartition, get filename # If take SubjectNaivePartition, get filename
else: else:
filename = get_infer_output_path( filename = get_infer_output_path(
@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask):
parser = LMTemplateParser(model_cfg['meta_template']) parser = LMTemplateParser(model_cfg['meta_template'])
role = parser.roles[eval_cfg['pred_role']] role = parser.roles[eval_cfg['pred_role']]
pred_strs = [ pred_strs = [
self._extract_role_pred(pred, role.get('begin', None), extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) role.get('end', None)) for pred in pred_strs
for pred in pred_strs
] ]
# Postprocess predictions if necessary # Postprocess predictions if necessary
@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask):
ensure_ascii=False, ensure_ascii=False,
indent=4) indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str[:1], start)
if end_idx != -1:
end = end_idx
return s[start:end]
def get_output_paths(self, file_extension: str = 'json') -> List[str]: def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the paths to the output files. Every file should exist if the """Get the paths to the output files. Every file should exist if the
task succeeds. task succeeds.