diff --git a/configs/eval_claude2.py b/configs/eval_claude.py similarity index 69% rename from configs/eval_claude2.py rename to configs/eval_claude.py index ef33174b..d3dcabf8 100644 --- a/configs/eval_claude2.py +++ b/configs/eval_claude.py @@ -1,5 +1,4 @@ from mmengine.config import read_base -from opencompass.models.claude_api import Claude from opencompass.partitioners import NaivePartitioner from opencompass.runners import LocalRunner from opencompass.tasks import OpenICLInferTask @@ -9,15 +8,7 @@ with read_base(): from .datasets.collections.chat_medium import datasets # and output the results in a choosen format from .summarizers.medium import summarizer - -models = [ - dict(abbr='Claude2', - type=Claude, - path='claude-2', - key='YOUR_CLAUDE_KEY', - query_per_second=1, - max_out_len=2048, max_seq_len=2048, batch_size=2), -] + from .models.claude import models infer = dict( partitioner=dict(type=NaivePartitioner), diff --git a/configs/models/claude.py b/configs/models/claude.py new file mode 100644 index 00000000..7b52c637 --- /dev/null +++ b/configs/models/claude.py @@ -0,0 +1,64 @@ +from opencompass.models.claude_api.claude_api import Claude +from opencompass.utils.text_postprocessors import last_option_postprocess +from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess + +agieval_single_choice_sets = [ + 'gaokao-chinese', + 'gaokao-english', + 'gaokao-geography', + 'gaokao-history', + 'gaokao-biology', + 'gaokao-chemistry', + 'gaokao-mathqa', + 'logiqa-zh', + 'lsat-ar', + 'lsat-lr', + 'lsat-rc', + 'logiqa-en', + 'sat-math', + 'sat-en', + 'sat-en-without-passage', + 'aqua-rat', +] +agieval_multiple_choices_sets = [ + 'gaokao-physics', + 'jec-qa-kd', + 'jec-qa-ca', +] + +claude_postprocessors = { + 'ceval-*': dict(type=last_option_postprocess, options='ABCD'), + 'bustm-*': dict(type=last_option_postprocess, options='AB'), + 'hellaswag': dict(type=last_option_postprocess, options='ABCD'), + 'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'), + 'openbookqa*': dict(type=last_option_postprocess, options='ABCD'), + 'piqa': dict(type=last_option_postprocess, options='AB'), + 'race-*': dict(type=last_option_postprocess, options='ABCD'), + 'summedits': dict(type=last_option_postprocess, options='AB'), + 'BoolQ': dict(type=last_option_postprocess, options='AB'), + 'CB': dict(type=last_option_postprocess, options='ABC'), + 'MultiRC': dict(type=last_option_postprocess, options='AB'), + 'RTE': dict(type=last_option_postprocess, options='AB'), + 'WiC': dict(type=last_option_postprocess, options='AB'), + 'WSC': dict(type=last_option_postprocess, options='AB'), + 'winogrande': dict(type=last_option_postprocess, options='AB'), + 'gsm8k': dict(type=gsm8k_postprocess), + 'openai_humaneval': dict(type=humaneval_postprocess), + 'lcsts': dict(type=lcsts_postprocess), + 'mbpp': dict(type=mbpp_postprocess), + 'strategyqa': dict(type=strategyqa_pred_postprocess), +} + +for _name in agieval_multiple_choices_sets + agieval_single_choice_sets: + claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE') + +models = [ + dict(abbr='Claude', + type=Claude, + path='claude-1', + key='YOUR_CLAUDE_KEY', + query_per_second=1, + max_out_len=2048, max_seq_len=2048, batch_size=2, + pred_postprocessor=claude_postprocessors, + ), +] diff --git a/configs/models/claude2.py b/configs/models/claude2.py new file mode 100644 index 00000000..4249496a --- /dev/null +++ b/configs/models/claude2.py @@ -0,0 +1,64 @@ +from opencompass.models.claude_api.claude_api import Claude +from opencompass.utils.text_postprocessors import last_option_postprocess +from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess + +agieval_single_choice_sets = [ + 'gaokao-chinese', + 'gaokao-english', + 'gaokao-geography', + 'gaokao-history', + 'gaokao-biology', + 'gaokao-chemistry', + 'gaokao-mathqa', + 'logiqa-zh', + 'lsat-ar', + 'lsat-lr', + 'lsat-rc', + 'logiqa-en', + 'sat-math', + 'sat-en', + 'sat-en-without-passage', + 'aqua-rat', +] +agieval_multiple_choices_sets = [ + 'gaokao-physics', + 'jec-qa-kd', + 'jec-qa-ca', +] + +claude_postprocessors = { + 'ceval-*': dict(type=last_option_postprocess, options='ABCD'), + 'bustm-*': dict(type=last_option_postprocess, options='AB'), + 'hellaswag': dict(type=last_option_postprocess, options='ABCD'), + 'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'), + 'openbookqa*': dict(type=last_option_postprocess, options='ABCD'), + 'piqa': dict(type=last_option_postprocess, options='AB'), + 'race-*': dict(type=last_option_postprocess, options='ABCD'), + 'summedits': dict(type=last_option_postprocess, options='AB'), + 'BoolQ': dict(type=last_option_postprocess, options='AB'), + 'CB': dict(type=last_option_postprocess, options='ABC'), + 'MultiRC': dict(type=last_option_postprocess, options='AB'), + 'RTE': dict(type=last_option_postprocess, options='AB'), + 'WiC': dict(type=last_option_postprocess, options='AB'), + 'WSC': dict(type=last_option_postprocess, options='AB'), + 'winogrande': dict(type=last_option_postprocess, options='AB'), + 'gsm8k': dict(type=gsm8k_postprocess), + 'openai_humaneval': dict(type=humaneval_postprocess), + 'lcsts': dict(type=lcsts_postprocess), + 'mbpp': dict(type=mbpp_postprocess), + 'strategyqa': dict(type=strategyqa_pred_postprocess), +} + +for _name in agieval_multiple_choices_sets + agieval_single_choice_sets: + claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE') + +models = [ + dict(abbr='Claude2', + type=Claude, + path='claude-2', + key='YOUR_CLAUDE_KEY', + query_per_second=1, + max_out_len=2048, max_seq_len=2048, batch_size=2, + pred_postprocessor=claude_postprocessors, + ), +] diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index e0e1586c..2a9455b4 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -1,5 +1,6 @@ from .base import BaseModel, LMTemplateParser # noqa from .base_api import APITemplateParser, BaseAPIModel # noqa +from .claude_api import Claude # noqa: F401 from .glm import GLM130B # noqa: F401, F403 from .huggingface import HuggingFace # noqa: F401, F403 from .huggingface import HuggingFaceCausalLM # noqa: F401, F403 diff --git a/opencompass/models/claude_api/__init__.py b/opencompass/models/claude_api/__init__.py new file mode 100644 index 00000000..48ab226e --- /dev/null +++ b/opencompass/models/claude_api/__init__.py @@ -0,0 +1,3 @@ +from .claude_api import Claude + +__all__ = ['Claude'] diff --git a/opencompass/models/claude_api.py b/opencompass/models/claude_api/claude_api.py similarity index 99% rename from opencompass/models/claude_api.py rename to opencompass/models/claude_api/claude_api.py index 08df7e24..542afab5 100644 --- a/opencompass/models/claude_api.py +++ b/opencompass/models/claude_api/claude_api.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union from opencompass.registry import MODELS from opencompass.utils import PromptList -from .base_api import BaseAPIModel +from ..base_api import BaseAPIModel PromptType = Union[PromptList, str] diff --git a/opencompass/models/claude_api/postprocessors.py b/opencompass/models/claude_api/postprocessors.py new file mode 100644 index 00000000..c42358c7 --- /dev/null +++ b/opencompass/models/claude_api/postprocessors.py @@ -0,0 +1,77 @@ +import re + + +def gsm8k_postprocess(text: str) -> str: + text = text.split(' ')[::-1] + flag = False + ret = '' + for i in range(len(text)): + s = text[i] + for i in range(len(s)): + if s[i].isdigit(): + flag = True + ret = s + break + if flag: + break + ret1 = '' + for i in range(len(ret)): + if ret[i].isdigit(): + ret1 += ret[i] + return ret1 + + +def humaneval_postprocess(text: str) -> str: + text = '\n'.join(text.split('\n')[1:]).strip() + if '```' in text: + blocks = re.findall(r'```(.*?)```', text, re.DOTALL) + if len(blocks) == 0: + text = text.split('```')[1] # fall back to default strategy + else: + text = blocks[0] # fetch the first code block + if not text.startswith('\n'): # in case starting with ```python + text = text[max(text.find('\n') + 1, 0):] + if text.strip().startswith('from') or text.strip().startswith('import'): + def_idx = text.find('def') + if def_idx != -1: + text = text[max(text.find('\n', def_idx) + 1, 0):] + if text.strip().startswith('def'): + text = '\n'.join(text.split('\n')[1:]) + if not text.startswith(' '): + if text.startswith(' '): + text = ' ' + text.lstrip() + else: + text = '\n'.join([' ' + line for line in text.split('\n')]) + return text + + +def lcsts_postprocess(text: str) -> str: + text = text.strip() + text = text.replace('1. ', '') if text.startswith('1. ') else text + text = text.replace('- ', '') if text.startswith('- ') else text + text = text.strip('“,。!”') + return text + + +def mbpp_postprocess(text: str) -> str: + if text.startswith('Here'): + text = '\n'.join(text.split('\n')[1:]).strip() + if '```' in text: + blocks = re.findall(r'```(.*?)```', text, re.DOTALL) + if len(blocks) == 0: + text = text.split('```')[1] # fall back to default strategy + else: + text = blocks[0] # fetch the first code block + if not text.startswith('\n'): # in case starting with ```python + text = text[max(text.find('\n') + 1, 0):] + return text + + +def strategyqa_pred_postprocess(text: str) -> str: + if text.startswith('Here'): + text = '\n'.join(text.split('\n')[1:]).strip() + text = text.split('answer is ')[-1] + match = re.search(r'(yes|no)', text.lower()) + if match: + return match.group(1) + return '' diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 978a4a6b..1db2d84b 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -1,4 +1,5 @@ import argparse +import fnmatch import os.path as osp import time from collections import Counter @@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS, TEXT_POSTPROCESSORS) from opencompass.tasks.base import BaseTask -from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path, - get_logger, task_abbr_from_cfg) +from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, + get_infer_output_path, get_logger, + task_abbr_from_cfg) @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run @@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask): self.eval_cfg = self.dataset_cfg.get('eval_cfg') self.output_column = dataset_cfg['reader_cfg']['output_column'] + # overwrite postprocessor if the model has specified one + ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg) + model_postprocessors = self.model_cfg.get( + 'pred_postprocessor', {}) + for pattern in model_postprocessors.keys(): + if fnmatch.fnmatch(ds_abbr, pattern): + self.eval_cfg[ + 'pred_postprocessor'] = model_postprocessors[ + pattern] # noqa + break + out_path = get_infer_output_path( self.model_cfg, self.dataset_cfg, osp.join(self.work_dir, 'results')) diff --git a/opencompass/utils/build.py b/opencompass/utils/build.py index a4e50a36..7d6fe132 100644 --- a/opencompass/utils/build.py +++ b/opencompass/utils/build.py @@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict: model_cfg.pop('max_out_len', None) model_cfg.pop('batch_size', None) model_cfg.pop('abbr', None) + model_cfg.pop('pred_postprocessor', None) return MODELS.build(model_cfg) diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index 8c504218..c6db0ba9 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str: if match: return match.group(1) return '' + + +def last_option_postprocess(text: str, options: str) -> str: + match = re.findall(rf'([{options}])', text) + if match: + return match[-1] + return ''