[Feature] Support model-bound prediction postprocessor, use it in Claude (#268)

* [Feature] Support model-bound text postprocessor, add claude as an example

* update

* update

* minor fix

---------

Co-authored-by: zhoufengzhe <zhoufengzhe@pjlab.org.cn>
This commit is contained in:
Tong Gao 2023-08-25 16:12:21 +08:00 committed by GitHub
parent 6df124d40b
commit f480b72703
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 234 additions and 13 deletions

View File

@ -1,5 +1,4 @@
from mmengine.config import read_base
from opencompass.models.claude_api import Claude
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
@ -9,15 +8,7 @@ with read_base():
from .datasets.collections.chat_medium import datasets
# and output the results in a choosen format
from .summarizers.medium import summarizer
models = [
dict(abbr='Claude2',
type=Claude,
path='claude-2',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2),
]
from .models.claude import models
infer = dict(
partitioner=dict(type=NaivePartitioner),

64
configs/models/claude.py Normal file
View File

@ -0,0 +1,64 @@
from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
agieval_single_choice_sets = [
'gaokao-chinese',
'gaokao-english',
'gaokao-geography',
'gaokao-history',
'gaokao-biology',
'gaokao-chemistry',
'gaokao-mathqa',
'logiqa-zh',
'lsat-ar',
'lsat-lr',
'lsat-rc',
'logiqa-en',
'sat-math',
'sat-en',
'sat-en-without-passage',
'aqua-rat',
]
agieval_multiple_choices_sets = [
'gaokao-physics',
'jec-qa-kd',
'jec-qa-ca',
]
claude_postprocessors = {
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
'bustm-*': dict(type=last_option_postprocess, options='AB'),
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
'piqa': dict(type=last_option_postprocess, options='AB'),
'race-*': dict(type=last_option_postprocess, options='ABCD'),
'summedits': dict(type=last_option_postprocess, options='AB'),
'BoolQ': dict(type=last_option_postprocess, options='AB'),
'CB': dict(type=last_option_postprocess, options='ABC'),
'MultiRC': dict(type=last_option_postprocess, options='AB'),
'RTE': dict(type=last_option_postprocess, options='AB'),
'WiC': dict(type=last_option_postprocess, options='AB'),
'WSC': dict(type=last_option_postprocess, options='AB'),
'winogrande': dict(type=last_option_postprocess, options='AB'),
'gsm8k': dict(type=gsm8k_postprocess),
'openai_humaneval': dict(type=humaneval_postprocess),
'lcsts': dict(type=lcsts_postprocess),
'mbpp': dict(type=mbpp_postprocess),
'strategyqa': dict(type=strategyqa_pred_postprocess),
}
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
models = [
dict(abbr='Claude',
type=Claude,
path='claude-1',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2,
pred_postprocessor=claude_postprocessors,
),
]

64
configs/models/claude2.py Normal file
View File

@ -0,0 +1,64 @@
from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
agieval_single_choice_sets = [
'gaokao-chinese',
'gaokao-english',
'gaokao-geography',
'gaokao-history',
'gaokao-biology',
'gaokao-chemistry',
'gaokao-mathqa',
'logiqa-zh',
'lsat-ar',
'lsat-lr',
'lsat-rc',
'logiqa-en',
'sat-math',
'sat-en',
'sat-en-without-passage',
'aqua-rat',
]
agieval_multiple_choices_sets = [
'gaokao-physics',
'jec-qa-kd',
'jec-qa-ca',
]
claude_postprocessors = {
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
'bustm-*': dict(type=last_option_postprocess, options='AB'),
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
'piqa': dict(type=last_option_postprocess, options='AB'),
'race-*': dict(type=last_option_postprocess, options='ABCD'),
'summedits': dict(type=last_option_postprocess, options='AB'),
'BoolQ': dict(type=last_option_postprocess, options='AB'),
'CB': dict(type=last_option_postprocess, options='ABC'),
'MultiRC': dict(type=last_option_postprocess, options='AB'),
'RTE': dict(type=last_option_postprocess, options='AB'),
'WiC': dict(type=last_option_postprocess, options='AB'),
'WSC': dict(type=last_option_postprocess, options='AB'),
'winogrande': dict(type=last_option_postprocess, options='AB'),
'gsm8k': dict(type=gsm8k_postprocess),
'openai_humaneval': dict(type=humaneval_postprocess),
'lcsts': dict(type=lcsts_postprocess),
'mbpp': dict(type=mbpp_postprocess),
'strategyqa': dict(type=strategyqa_pred_postprocess),
}
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
models = [
dict(abbr='Claude2',
type=Claude,
path='claude-2',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2,
pred_postprocessor=claude_postprocessors,
),
]

View File

@ -1,5 +1,6 @@
from .base import BaseModel, LMTemplateParser # noqa
from .base_api import APITemplateParser, BaseAPIModel # noqa
from .claude_api import Claude # noqa: F401
from .glm import GLM130B # noqa: F401, F403
from .huggingface import HuggingFace # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403

View File

@ -0,0 +1,3 @@
from .claude_api import Claude
__all__ = ['Claude']

View File

@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
from opencompass.registry import MODELS
from opencompass.utils import PromptList
from .base_api import BaseAPIModel
from ..base_api import BaseAPIModel
PromptType = Union[PromptList, str]

View File

@ -0,0 +1,77 @@
import re
def gsm8k_postprocess(text: str) -> str:
text = text.split(' ')[::-1]
flag = False
ret = ''
for i in range(len(text)):
s = text[i]
for i in range(len(s)):
if s[i].isdigit():
flag = True
ret = s
break
if flag:
break
ret1 = ''
for i in range(len(ret)):
if ret[i].isdigit():
ret1 += ret[i]
return ret1
def humaneval_postprocess(text: str) -> str:
text = '\n'.join(text.split('\n')[1:]).strip()
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
if text.strip().startswith('from') or text.strip().startswith('import'):
def_idx = text.find('def')
if def_idx != -1:
text = text[max(text.find('\n', def_idx) + 1, 0):]
if text.strip().startswith('def'):
text = '\n'.join(text.split('\n')[1:])
if not text.startswith(' '):
if text.startswith(' '):
text = ' ' + text.lstrip()
else:
text = '\n'.join([' ' + line for line in text.split('\n')])
return text
def lcsts_postprocess(text: str) -> str:
text = text.strip()
text = text.replace('1. ', '') if text.startswith('1. ') else text
text = text.replace('- ', '') if text.startswith('- ') else text
text = text.strip('“,。!”')
return text
def mbpp_postprocess(text: str) -> str:
if text.startswith('Here'):
text = '\n'.join(text.split('\n')[1:]).strip()
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
return text
def strategyqa_pred_postprocess(text: str) -> str:
if text.startswith('Here'):
text = '\n'.join(text.split('\n')[1:]).strip()
text = text.split('answer is ')[-1]
match = re.search(r'(yes|no)', text.lower())
if match:
return match.group(1)
return ''

View File

@ -1,4 +1,5 @@
import argparse
import fnmatch
import os.path as osp
import time
from collections import Counter
@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist
from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
TEXT_POSTPROCESSORS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
get_logger, task_abbr_from_cfg)
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path, get_logger,
task_abbr_from_cfg)
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask):
self.eval_cfg = self.dataset_cfg.get('eval_cfg')
self.output_column = dataset_cfg['reader_cfg']['output_column']
# overwrite postprocessor if the model has specified one
ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
model_postprocessors = self.model_cfg.get(
'pred_postprocessor', {})
for pattern in model_postprocessors.keys():
if fnmatch.fnmatch(ds_abbr, pattern):
self.eval_cfg[
'pred_postprocessor'] = model_postprocessors[
pattern] # noqa
break
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results'))

View File

@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
model_cfg.pop('max_out_len', None)
model_cfg.pop('batch_size', None)
model_cfg.pop('abbr', None)
model_cfg.pop('pred_postprocessor', None)
return MODELS.build(model_cfg)

View File

@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str:
if match:
return match.group(1)
return ''
def last_option_postprocess(text: str, options: str) -> str:
match = re.findall(rf'([{options}])', text)
if match:
return match[-1]
return ''