From 3c606cb7126d866f499cac0689c6aafb47f819ef Mon Sep 17 00:00:00 2001 From: bittersweet1999 <148421775+bittersweet1999@users.noreply.github.com> Date: Fri, 5 Jan 2024 21:10:18 +0800 Subject: [PATCH] quick fix for postprocess pred extraction (#771) --- opencompass/tasks/openicl_eval.py | 69 ++++++++++++++-------------- opencompass/tasks/subjective_eval.py | 42 +++-------------- 2 files changed, 42 insertions(+), 69 deletions(-) diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 59495917..92b50e53 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, task_abbr_from_cfg) +def extract_role_pred(s: str, begin_str: Optional[str], + end_str: Optional[str]) -> str: + """Extract the role prediction from the full prediction string. The role + prediction may be the substring between the begin and end string. + + Args: + s (str): Full prediction string. + begin_str (str): The beginning string of the role + end_str (str): The ending string of the role. + + Returns: + str: The extracted role prediction. + """ + start = 0 + end = len(s) + + if begin_str: + begin_idx = s.find(begin_str) + if begin_idx != -1: + start = begin_idx + len(begin_str) + + if end_str: + # TODO: Support calling tokenizer for the accurate eos token + # and avoid such hardcode + end_idx = s.find(end_str, start) + if end_idx != -1: + end = end_idx + + return s[start:end] + + @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run class OpenICLEvalTask(BaseTask): """OpenICL Evaluation Task. @@ -137,14 +168,14 @@ class OpenICLEvalTask(BaseTask): 'must be list.') if pred_list_flag: pred_strs = [[ - self._extract_role_pred(_pred, role.get('begin', None), - role.get('end', None)) + extract_role_pred(_pred, role.get('begin', None), + role.get('end', None)) for _pred in pred ] for pred in pred_strs] else: pred_strs = [ - self._extract_role_pred(pred, role.get('begin', None), - role.get('end', None)) + extract_role_pred(pred, role.get('begin', None), + role.get('end', None)) for pred in pred_strs ] @@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask): mkdir_or_exist(osp.split(out_path)[0]) mmengine.dump(result, out_path, ensure_ascii=False, indent=4) - def _extract_role_pred(self, s: str, begin_str: Optional[str], - end_str: Optional[str]) -> str: - """Extract the role prediction from the full prediction string. The - role prediction may be the substring between the begin and end string. - - Args: - s (str): Full prediction string. - begin_str (str): The beginning string of the role - end_str (str): The ending string of the role. - - Returns: - str: The extracted role prediction. - """ - start = 0 - end = len(s) - - if begin_str: - begin_idx = s.find(begin_str) - if begin_idx != -1: - start = begin_idx + len(begin_str) - - if end_str: - # TODO: Support calling tokenizer for the accurate eos token - # and avoid such hardcode - end_idx = s.find(end_str, start) - if end_idx != -1: - end = end_idx - - return s[start:end] - def format_details(self, predictions, references, details, pred_dicts): """This function is responsible for formatting prediction details. diff --git a/opencompass/tasks/subjective_eval.py b/opencompass/tasks/subjective_eval.py index 61fa8fba..f6086273 100644 --- a/opencompass/tasks/subjective_eval.py +++ b/opencompass/tasks/subjective_eval.py @@ -4,7 +4,7 @@ import fnmatch import os.path as osp import random import time -from typing import List, Optional, Union +from typing import List, Union import mmengine from mmengine.config import Config, ConfigDict @@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS from opencompass.tasks.base import BaseTask +from opencompass.tasks.openicl_eval import extract_role_pred from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, get_infer_output_path, get_logger, model_abbr_from_cfg, task_abbr_from_cfg) @@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask): filename = get_infer_output_path( model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions')) root, ext = osp.splitext(filename) - filename = root[:-2] + ext + last_underscore_index = root.rfind('_') + root = root[:last_underscore_index] + filename = root + ext # If take SubjectNaivePartition, get filename else: filename = get_infer_output_path( @@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask): parser = LMTemplateParser(model_cfg['meta_template']) role = parser.roles[eval_cfg['pred_role']] pred_strs = [ - self._extract_role_pred(pred, role.get('begin', None), - role.get('end', None)) - for pred in pred_strs + extract_role_pred(pred, role.get('begin', None), + role.get('end', None)) for pred in pred_strs ] # Postprocess predictions if necessary @@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask): ensure_ascii=False, indent=4) - def _extract_role_pred(self, s: str, begin_str: Optional[str], - end_str: Optional[str]) -> str: - """Extract the role prediction from the full prediction string. The - role prediction may be the substring between the begin and end string. - - Args: - s (str): Full prediction string. - begin_str (str): The beginning string of the role - end_str (str): The ending string of the role. - - Returns: - str: The extracted role prediction. - """ - start = 0 - end = len(s) - - if begin_str: - begin_idx = s.find(begin_str) - if begin_idx != -1: - start = begin_idx + len(begin_str) - - if end_str: - # TODO: Support calling tokenizer for the accurate eos token - # and avoid such hardcode - end_idx = s.find(end_str[:1], start) - if end_idx != -1: - end = end_idx - - return s[start:end] - def get_output_paths(self, file_extension: str = 'json') -> List[str]: """Get the paths to the output files. Every file should exist if the task succeeds.