quick fix for postprocess pred extraction (#771)

This commit is contained in:
bittersweet1999 2024-01-05 21:10:18 +08:00 committed by GitHub
parent 0c75f0f95a
commit 3c606cb712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 69 deletions

View File

@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
task_abbr_from_cfg)
def extract_role_pred(s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The role
prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str, start)
if end_idx != -1:
end = end_idx
return s[start:end]
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
class OpenICLEvalTask(BaseTask):
"""OpenICL Evaluation Task.
@ -137,13 +168,13 @@ class OpenICLEvalTask(BaseTask):
'must be list.')
if pred_list_flag:
pred_strs = [[
self._extract_role_pred(_pred, role.get('begin', None),
extract_role_pred(_pred, role.get('begin', None),
role.get('end', None))
for _pred in pred
] for pred in pred_strs]
else:
pred_strs = [
self._extract_role_pred(pred, role.get('begin', None),
extract_role_pred(pred, role.get('begin', None),
role.get('end', None))
for pred in pred_strs
]
@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask):
mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str, start)
if end_idx != -1:
end = end_idx
return s[start:end]
def format_details(self, predictions, references, details, pred_dicts):
"""This function is responsible for formatting prediction details.

View File

@ -4,7 +4,7 @@ import fnmatch
import os.path as osp
import random
import time
from typing import List, Optional, Union
from typing import List, Union
import mmengine
from mmengine.config import Config, ConfigDict
@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
from opencompass.tasks.base import BaseTask
from opencompass.tasks.openicl_eval import extract_role_pred
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path, get_logger,
model_abbr_from_cfg, task_abbr_from_cfg)
@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask):
filename = get_infer_output_path(
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
root, ext = osp.splitext(filename)
filename = root[:-2] + ext
last_underscore_index = root.rfind('_')
root = root[:last_underscore_index]
filename = root + ext
# If take SubjectNaivePartition, get filename
else:
filename = get_infer_output_path(
@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask):
parser = LMTemplateParser(model_cfg['meta_template'])
role = parser.roles[eval_cfg['pred_role']]
pred_strs = [
self._extract_role_pred(pred, role.get('begin', None),
role.get('end', None))
for pred in pred_strs
extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) for pred in pred_strs
]
# Postprocess predictions if necessary
@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask):
ensure_ascii=False,
indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str[:1], start)
if end_idx != -1:
end = end_idx
return s[start:end]
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the paths to the output files. Every file should exist if the
task succeeds.