Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support model-bound prediction postprocessor, use it in Claude (#268)
* [Feature] Support model-bound text postprocessor, add claude as an example
* update
* update
* minor fix

Co-authored-by: zhoufengzhe <zhoufengzhe@pjlab.org.cn>
Parent: 6df124d40b
Commit: f480b72703
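In outline, the change lets a model config carry its own pred_postprocessor: a mapping from dataset-abbreviation glob patterns to postprocessor configs that, at evaluation time, overrides the dataset's default pred_postprocessor. A condensed sketch of the shape introduced by the configs below (the dataset keys are trimmed for brevity; the full versions follow in the diff):

from opencompass.models.claude_api.claude_api import Claude
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess
from opencompass.utils.text_postprocessors import last_option_postprocess

# Keys are fnmatch-style patterns tested against each dataset's abbreviation.
claude_postprocessors = {
    'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
    'gsm8k': dict(type=gsm8k_postprocess),
}

models = [
    dict(abbr='Claude', type=Claude, path='claude-1', key='YOUR_CLAUDE_KEY',
         query_per_second=1, max_out_len=2048, max_seq_len=2048, batch_size=2,
         pred_postprocessor=claude_postprocessors),
]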
@@ -1,5 +1,4 @@
 from mmengine.config import read_base
-from opencompass.models.claude_api import Claude
 from opencompass.partitioners import NaivePartitioner
 from opencompass.runners import LocalRunner
 from opencompass.tasks import OpenICLInferTask
@@ -9,15 +8,7 @@ with read_base():
     from .datasets.collections.chat_medium import datasets
     # and output the results in a choosen format
     from .summarizers.medium import summarizer
+    from .models.claude import models
 
-models = [
-    dict(abbr='Claude2',
-         type=Claude,
-         path='claude-2',
-         key='YOUR_CLAUDE_KEY',
-         query_per_second=1,
-         max_out_len=2048, max_seq_len=2048, batch_size=2),
-]
-
 infer = dict(
     partitioner=dict(type=NaivePartitioner),
configs/models/claude.py (new file, 64 lines):

from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess

agieval_single_choice_sets = [
    'gaokao-chinese',
    'gaokao-english',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-mathqa',
    'logiqa-zh',
    'lsat-ar',
    'lsat-lr',
    'lsat-rc',
    'logiqa-en',
    'sat-math',
    'sat-en',
    'sat-en-without-passage',
    'aqua-rat',
]
agieval_multiple_choices_sets = [
    'gaokao-physics',
    'jec-qa-kd',
    'jec-qa-ca',
]

claude_postprocessors = {
    'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
    'bustm-*': dict(type=last_option_postprocess, options='AB'),
    'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
    'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
    'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
    'piqa': dict(type=last_option_postprocess, options='AB'),
    'race-*': dict(type=last_option_postprocess, options='ABCD'),
    'summedits': dict(type=last_option_postprocess, options='AB'),
    'BoolQ': dict(type=last_option_postprocess, options='AB'),
    'CB': dict(type=last_option_postprocess, options='ABC'),
    'MultiRC': dict(type=last_option_postprocess, options='AB'),
    'RTE': dict(type=last_option_postprocess, options='AB'),
    'WiC': dict(type=last_option_postprocess, options='AB'),
    'WSC': dict(type=last_option_postprocess, options='AB'),
    'winogrande': dict(type=last_option_postprocess, options='AB'),
    'gsm8k': dict(type=gsm8k_postprocess),
    'openai_humaneval': dict(type=humaneval_postprocess),
    'lcsts': dict(type=lcsts_postprocess),
    'mbpp': dict(type=mbpp_postprocess),
    'strategyqa': dict(type=strategyqa_pred_postprocess),
}

for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
    claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')

models = [
    dict(abbr='Claude',
         type=Claude,
         path='claude-1',
         key='YOUR_CLAUDE_KEY',
         query_per_second=1,
         max_out_len=2048, max_seq_len=2048, batch_size=2,
         pred_postprocessor=claude_postprocessors,
         ),
]
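The glob-style keys above are resolved with Python's fnmatch against each dataset's abbreviation (see the eval-task change further down), so a single entry can cover a whole dataset family. A small illustration of that matching, using made-up abbreviations:

import fnmatch

patterns = ['lukaemon_mmlu_*', 'race-*', 'gsm8k']
for ds_abbr in ['lukaemon_mmlu_college_biology', 'race-high', 'gsm8k', 'some_other_dataset']:
    hit = next((p for p in patterns if fnmatch.fnmatch(ds_abbr, p)), None)
    print(f'{ds_abbr} -> {hit}')
# lukaemon_mmlu_college_biology -> lukaemon_mmlu_*
# race-high -> race-*
# gsm8k -> gsm8k
# some_other_dataset -> None  (keeps its dataset-level pred_postprocessor)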
configs/models/claude2.py (new file, 64 lines):

from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess

agieval_single_choice_sets = [
    'gaokao-chinese',
    'gaokao-english',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-mathqa',
    'logiqa-zh',
    'lsat-ar',
    'lsat-lr',
    'lsat-rc',
    'logiqa-en',
    'sat-math',
    'sat-en',
    'sat-en-without-passage',
    'aqua-rat',
]
agieval_multiple_choices_sets = [
    'gaokao-physics',
    'jec-qa-kd',
    'jec-qa-ca',
]

claude_postprocessors = {
    'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
    'bustm-*': dict(type=last_option_postprocess, options='AB'),
    'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
    'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
    'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
    'piqa': dict(type=last_option_postprocess, options='AB'),
    'race-*': dict(type=last_option_postprocess, options='ABCD'),
    'summedits': dict(type=last_option_postprocess, options='AB'),
    'BoolQ': dict(type=last_option_postprocess, options='AB'),
    'CB': dict(type=last_option_postprocess, options='ABC'),
    'MultiRC': dict(type=last_option_postprocess, options='AB'),
    'RTE': dict(type=last_option_postprocess, options='AB'),
    'WiC': dict(type=last_option_postprocess, options='AB'),
    'WSC': dict(type=last_option_postprocess, options='AB'),
    'winogrande': dict(type=last_option_postprocess, options='AB'),
    'gsm8k': dict(type=gsm8k_postprocess),
    'openai_humaneval': dict(type=humaneval_postprocess),
    'lcsts': dict(type=lcsts_postprocess),
    'mbpp': dict(type=mbpp_postprocess),
    'strategyqa': dict(type=strategyqa_pred_postprocess),
}

for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
    claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')

models = [
    dict(abbr='Claude2',
         type=Claude,
         path='claude-2',
         key='YOUR_CLAUDE_KEY',
         query_per_second=1,
         max_out_len=2048, max_seq_len=2048, batch_size=2,
         pred_postprocessor=claude_postprocessors,
         ),
]
@@ -1,5 +1,6 @@
 from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
+from .claude_api import Claude  # noqa: F401
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
opencompass/models/claude_api/__init__.py (new file, 3 lines):

from .claude_api import Claude

__all__ = ['Claude']
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
 from opencompass.registry import MODELS
 from opencompass.utils import PromptList
 
-from .base_api import BaseAPIModel
+from ..base_api import BaseAPIModel
 
 PromptType = Union[PromptList, str]
 
opencompass/models/claude_api/postprocessors.py (new file, 77 lines):

import re


def gsm8k_postprocess(text: str) -> str:
    text = text.split(' ')[::-1]
    flag = False
    ret = ''
    for i in range(len(text)):
        s = text[i]
        for i in range(len(s)):
            if s[i].isdigit():
                flag = True
                ret = s
                break
        if flag:
            break
    ret1 = ''
    for i in range(len(ret)):
        if ret[i].isdigit():
            ret1 += ret[i]
    return ret1


def humaneval_postprocess(text: str) -> str:
    text = '\n'.join(text.split('\n')[1:]).strip()
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```python
                text = text[max(text.find('\n') + 1, 0):]
    if text.strip().startswith('from') or text.strip().startswith('import'):
        def_idx = text.find('def')
        if def_idx != -1:
            text = text[max(text.find('\n', def_idx) + 1, 0):]
    if text.strip().startswith('def'):
        text = '\n'.join(text.split('\n')[1:])
    if not text.startswith('    '):
        if text.startswith(' '):
            text = '    ' + text.lstrip()
        else:
            text = '\n'.join(['    ' + line for line in text.split('\n')])
    return text


def lcsts_postprocess(text: str) -> str:
    text = text.strip()
    text = text.replace('1. ', '') if text.startswith('1. ') else text
    text = text.replace('- ', '') if text.startswith('- ') else text
    text = text.strip('“,。!”')
    return text


def mbpp_postprocess(text: str) -> str:
    if text.startswith('Here'):
        text = '\n'.join(text.split('\n')[1:]).strip()
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```python
                text = text[max(text.find('\n') + 1, 0):]
    return text


def strategyqa_pred_postprocess(text: str) -> str:
    if text.startswith('Here'):
        text = '\n'.join(text.split('\n')[1:]).strip()
    text = text.split('answer is ')[-1]
    match = re.search(r'(yes|no)', text.lower())
    if match:
        return match.group(1)
    return ''
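A few illustrative calls, using made-up Claude-style responses; the expected outputs follow from the definitions above:

from opencompass.models.claude_api.postprocessors import (
    gsm8k_postprocess, lcsts_postprocess, strategyqa_pred_postprocess)

print(gsm8k_postprocess('So the total cost is 42 dollars.'))
# '42'  (digits of the last whitespace-separated token that contains a digit)

print(strategyqa_pred_postprocess('Here is my reasoning:\nThe answer is Yes.'))
# 'yes'

print(lcsts_postprocess('- 北京今日发布暴雨预警。'))
# '北京今日发布暴雨预警'  (leading list marker and trailing punctuation stripped)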
@@ -1,4 +1,5 @@
 import argparse
+import fnmatch
 import os.path as osp
 import time
 from collections import Counter
@@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist
 from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
                                   TEXT_POSTPROCESSORS)
 from opencompass.tasks.base import BaseTask
-from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
-                               get_logger, task_abbr_from_cfg)
+from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
+                               get_infer_output_path, get_logger,
+                               task_abbr_from_cfg)
 
 
 @TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
@@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask):
         self.eval_cfg = self.dataset_cfg.get('eval_cfg')
         self.output_column = dataset_cfg['reader_cfg']['output_column']
+
+        # overwrite postprocessor if the model has specified one
+        ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
+        model_postprocessors = self.model_cfg.get(
+            'pred_postprocessor', {})
+        for pattern in model_postprocessors.keys():
+            if fnmatch.fnmatch(ds_abbr, pattern):
+                self.eval_cfg[
+                    'pred_postprocessor'] = model_postprocessors[
+                        pattern]  # noqa
+                break
 
         out_path = get_infer_output_path(
             self.model_cfg, self.dataset_cfg,
             osp.join(self.work_dir, 'results'))
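The override is first-match-wins: patterns are tried in the dict's insertion order and the loop breaks on the first hit. A self-contained sketch of that step, with hypothetical config values (real configs map to callables, not strings):

import fnmatch

# Hypothetical stand-ins for self.dataset_cfg / self.model_cfg.
dataset_abbr = 'gsm8k'
eval_cfg = dict(pred_postprocessor=dict(type='dataset_default_postprocess'))
model_postprocessors = {'gsm8k': dict(type='claude_gsm8k_postprocess')}

for pattern in model_postprocessors.keys():
    if fnmatch.fnmatch(dataset_abbr, pattern):
        eval_cfg['pred_postprocessor'] = model_postprocessors[pattern]  # model-bound one wins
        break

print(eval_cfg['pred_postprocessor'])  # {'type': 'claude_gsm8k_postprocess'}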
@@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
     model_cfg.pop('max_out_len', None)
     model_cfg.pop('batch_size', None)
     model_cfg.pop('abbr', None)
+    model_cfg.pop('pred_postprocessor', None)
     return MODELS.build(model_cfg)
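As with abbr, max_out_len and batch_size, pred_postprocessor is consumed by the evaluation task rather than by the model class, so it is stripped before MODELS.build instantiates the model. A hedged illustration (keys and values are examples only, not taken from the commit):

# Example model_cfg; only kwargs the model constructor accepts should survive.
model_cfg = dict(type='Claude', path='claude-1', key='YOUR_CLAUDE_KEY',
                 abbr='Claude', max_out_len=2048, batch_size=2,
                 pred_postprocessor={'gsm8k': dict(type='claude_gsm8k_postprocess')})
for framework_key in ('abbr', 'max_out_len', 'batch_size', 'pred_postprocessor'):
    model_cfg.pop(framework_key, None)  # same pops as build_model_from_cfg above
# model_cfg now holds only type/path/key, ready to be built into a model instance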
@@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str:
     if match:
         return match.group(1)
     return ''
+
+
+def last_option_postprocess(text: str, options: str) -> str:
+    match = re.findall(rf'([{options}])', text)
+    if match:
+        return match[-1]
+    return ''
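For illustration: last_option_postprocess keeps the last occurrence of any allowed option letter, which suits verbose answers that restate the final choice at the end (the example string is made up):

from opencompass.utils.text_postprocessors import last_option_postprocess

print(last_option_postprocess('It could be A, but on balance the answer is B.', options='ABCD'))
# 'B'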