From d4af31bab4668138b77d379f30877bea17d88c2c Mon Sep 17 00:00:00 2001 From: Hubert <42952108+yingfhu@users.noreply.github.com> Date: Mon, 27 Nov 2023 19:57:36 +0800 Subject: [PATCH] [Feat] support zhipu post process (#642) * [Feat] support zhipu post * [Feat] support zhipu post * [Feat] support zhipu post --- configs/api_examples/eval_api_zhipu.py | 11 ++++++++++ opencompass/utils/text_postprocessors.py | 27 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/configs/api_examples/eval_api_zhipu.py b/configs/api_examples/eval_api_zhipu.py index b3bf1ff1..9955f660 100644 --- a/configs/api_examples/eval_api_zhipu.py +++ b/configs/api_examples/eval_api_zhipu.py @@ -13,6 +13,17 @@ datasets = [ *ceval_datasets, ] +# needs a special postprocessor for all +# except 'gsm8k' and 'strategyqa' +from opencompass.utils import general_eval_wrapper_postprocess +for _dataset in datasets: + if _dataset['abbr'] not in ['gsm8k', 'strategyqa']: + if hasattr(_dataset['eval_cfg'], 'pred_postprocessor'): + _dataset['eval_cfg']['pred_postprocessor']['postprocess'] = _dataset['eval_cfg']['pred_postprocessor']['type'] + _dataset['eval_cfg']['pred_postprocessor']['type'] = general_eval_wrapper_postprocess + else: + _dataset['eval_cfg']['pred_postprocessor'] = {'type': general_eval_wrapper_postprocess} + models = [ dict( abbr='chatglm_pro', diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index 60e59a65..996ff413 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -1,4 +1,5 @@ import re +from typing import Callable, Optional, Union from opencompass.registry import TEXT_POSTPROCESSORS @@ -141,3 +142,29 @@ def first_number_postprocess(text: str) -> float: def multiple_select_postprocess(text: str) -> str: ret = set([t for t in text if t.isupper()]) return ''.join(sorted(ret)) + + +def general_eval_wrapper_postprocess(text: str, + postprocess: Optional[Union[ + str, Callable]] = None, + **kwargs) -> str: + """Wrapper for eval text repr. Especially for chatglmpro. + + Args: + text(str): Text to be postprocessed. + postprocess(Callable, optional): Original post processing function. + Defaults to None. + **kwargs: Other necessary kwargs for post processing function. + """ + try: + text = eval(text) + except Exception: + # in case empty input or other error, skip eval + pass + + if postprocess: + if isinstance(postprocess, str): + postprocess = TEXT_POSTPROCESSORS.get(postprocess) + return postprocess(text, **kwargs) + else: + return text