mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support model-bound prediction postprocessor, use it in Claude (#268)
* [Feature] Support model-bound text postprocessor, add claude as an example * update * update * minor fix --------- Co-authored-by: zhoufengzhe <zhoufengzhe@pjlab.org.cn>
This commit is contained in:
parent
6df124d40b
commit
f480b72703
@ -1,5 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
from opencompass.models.claude_api import Claude
|
||||
from opencompass.partitioners import NaivePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
@ -9,15 +8,7 @@ with read_base():
|
||||
from .datasets.collections.chat_medium import datasets
|
||||
# and output the results in a choosen format
|
||||
from .summarizers.medium import summarizer
|
||||
|
||||
models = [
|
||||
dict(abbr='Claude2',
|
||||
type=Claude,
|
||||
path='claude-2',
|
||||
key='YOUR_CLAUDE_KEY',
|
||||
query_per_second=1,
|
||||
max_out_len=2048, max_seq_len=2048, batch_size=2),
|
||||
]
|
||||
from .models.claude import models
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=NaivePartitioner),
|
64
configs/models/claude.py
Normal file
64
configs/models/claude.py
Normal file
@ -0,0 +1,64 @@
|
||||
from opencompass.models.claude_api.claude_api import Claude
|
||||
from opencompass.utils.text_postprocessors import last_option_postprocess
|
||||
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
|
||||
|
||||
agieval_single_choice_sets = [
|
||||
'gaokao-chinese',
|
||||
'gaokao-english',
|
||||
'gaokao-geography',
|
||||
'gaokao-history',
|
||||
'gaokao-biology',
|
||||
'gaokao-chemistry',
|
||||
'gaokao-mathqa',
|
||||
'logiqa-zh',
|
||||
'lsat-ar',
|
||||
'lsat-lr',
|
||||
'lsat-rc',
|
||||
'logiqa-en',
|
||||
'sat-math',
|
||||
'sat-en',
|
||||
'sat-en-without-passage',
|
||||
'aqua-rat',
|
||||
]
|
||||
agieval_multiple_choices_sets = [
|
||||
'gaokao-physics',
|
||||
'jec-qa-kd',
|
||||
'jec-qa-ca',
|
||||
]
|
||||
|
||||
claude_postprocessors = {
|
||||
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'bustm-*': dict(type=last_option_postprocess, options='AB'),
|
||||
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'piqa': dict(type=last_option_postprocess, options='AB'),
|
||||
'race-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'summedits': dict(type=last_option_postprocess, options='AB'),
|
||||
'BoolQ': dict(type=last_option_postprocess, options='AB'),
|
||||
'CB': dict(type=last_option_postprocess, options='ABC'),
|
||||
'MultiRC': dict(type=last_option_postprocess, options='AB'),
|
||||
'RTE': dict(type=last_option_postprocess, options='AB'),
|
||||
'WiC': dict(type=last_option_postprocess, options='AB'),
|
||||
'WSC': dict(type=last_option_postprocess, options='AB'),
|
||||
'winogrande': dict(type=last_option_postprocess, options='AB'),
|
||||
'gsm8k': dict(type=gsm8k_postprocess),
|
||||
'openai_humaneval': dict(type=humaneval_postprocess),
|
||||
'lcsts': dict(type=lcsts_postprocess),
|
||||
'mbpp': dict(type=mbpp_postprocess),
|
||||
'strategyqa': dict(type=strategyqa_pred_postprocess),
|
||||
}
|
||||
|
||||
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
|
||||
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
|
||||
|
||||
models = [
|
||||
dict(abbr='Claude',
|
||||
type=Claude,
|
||||
path='claude-1',
|
||||
key='YOUR_CLAUDE_KEY',
|
||||
query_per_second=1,
|
||||
max_out_len=2048, max_seq_len=2048, batch_size=2,
|
||||
pred_postprocessor=claude_postprocessors,
|
||||
),
|
||||
]
|
64
configs/models/claude2.py
Normal file
64
configs/models/claude2.py
Normal file
@ -0,0 +1,64 @@
|
||||
from opencompass.models.claude_api.claude_api import Claude
|
||||
from opencompass.utils.text_postprocessors import last_option_postprocess
|
||||
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
|
||||
|
||||
agieval_single_choice_sets = [
|
||||
'gaokao-chinese',
|
||||
'gaokao-english',
|
||||
'gaokao-geography',
|
||||
'gaokao-history',
|
||||
'gaokao-biology',
|
||||
'gaokao-chemistry',
|
||||
'gaokao-mathqa',
|
||||
'logiqa-zh',
|
||||
'lsat-ar',
|
||||
'lsat-lr',
|
||||
'lsat-rc',
|
||||
'logiqa-en',
|
||||
'sat-math',
|
||||
'sat-en',
|
||||
'sat-en-without-passage',
|
||||
'aqua-rat',
|
||||
]
|
||||
agieval_multiple_choices_sets = [
|
||||
'gaokao-physics',
|
||||
'jec-qa-kd',
|
||||
'jec-qa-ca',
|
||||
]
|
||||
|
||||
claude_postprocessors = {
|
||||
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'bustm-*': dict(type=last_option_postprocess, options='AB'),
|
||||
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'piqa': dict(type=last_option_postprocess, options='AB'),
|
||||
'race-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'summedits': dict(type=last_option_postprocess, options='AB'),
|
||||
'BoolQ': dict(type=last_option_postprocess, options='AB'),
|
||||
'CB': dict(type=last_option_postprocess, options='ABC'),
|
||||
'MultiRC': dict(type=last_option_postprocess, options='AB'),
|
||||
'RTE': dict(type=last_option_postprocess, options='AB'),
|
||||
'WiC': dict(type=last_option_postprocess, options='AB'),
|
||||
'WSC': dict(type=last_option_postprocess, options='AB'),
|
||||
'winogrande': dict(type=last_option_postprocess, options='AB'),
|
||||
'gsm8k': dict(type=gsm8k_postprocess),
|
||||
'openai_humaneval': dict(type=humaneval_postprocess),
|
||||
'lcsts': dict(type=lcsts_postprocess),
|
||||
'mbpp': dict(type=mbpp_postprocess),
|
||||
'strategyqa': dict(type=strategyqa_pred_postprocess),
|
||||
}
|
||||
|
||||
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
|
||||
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
|
||||
|
||||
models = [
|
||||
dict(abbr='Claude2',
|
||||
type=Claude,
|
||||
path='claude-2',
|
||||
key='YOUR_CLAUDE_KEY',
|
||||
query_per_second=1,
|
||||
max_out_len=2048, max_seq_len=2048, batch_size=2,
|
||||
pred_postprocessor=claude_postprocessors,
|
||||
),
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from .base import BaseModel, LMTemplateParser # noqa
|
||||
from .base_api import APITemplateParser, BaseAPIModel # noqa
|
||||
from .claude_api import Claude # noqa: F401
|
||||
from .glm import GLM130B # noqa: F401, F403
|
||||
from .huggingface import HuggingFace # noqa: F401, F403
|
||||
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
|
||||
|
3
opencompass/models/claude_api/__init__.py
Normal file
3
opencompass/models/claude_api/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .claude_api import Claude
|
||||
|
||||
__all__ = ['Claude']
|
@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
|
||||
from opencompass.registry import MODELS
|
||||
from opencompass.utils import PromptList
|
||||
|
||||
from .base_api import BaseAPIModel
|
||||
from ..base_api import BaseAPIModel
|
||||
|
||||
PromptType = Union[PromptList, str]
|
||||
|
77
opencompass/models/claude_api/postprocessors.py
Normal file
77
opencompass/models/claude_api/postprocessors.py
Normal file
@ -0,0 +1,77 @@
|
||||
import re
|
||||
|
||||
|
||||
def gsm8k_postprocess(text: str) -> str:
|
||||
text = text.split(' ')[::-1]
|
||||
flag = False
|
||||
ret = ''
|
||||
for i in range(len(text)):
|
||||
s = text[i]
|
||||
for i in range(len(s)):
|
||||
if s[i].isdigit():
|
||||
flag = True
|
||||
ret = s
|
||||
break
|
||||
if flag:
|
||||
break
|
||||
ret1 = ''
|
||||
for i in range(len(ret)):
|
||||
if ret[i].isdigit():
|
||||
ret1 += ret[i]
|
||||
return ret1
|
||||
|
||||
|
||||
def humaneval_postprocess(text: str) -> str:
|
||||
text = '\n'.join(text.split('\n')[1:]).strip()
|
||||
if '```' in text:
|
||||
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
|
||||
if len(blocks) == 0:
|
||||
text = text.split('```')[1] # fall back to default strategy
|
||||
else:
|
||||
text = blocks[0] # fetch the first code block
|
||||
if not text.startswith('\n'): # in case starting with ```python
|
||||
text = text[max(text.find('\n') + 1, 0):]
|
||||
if text.strip().startswith('from') or text.strip().startswith('import'):
|
||||
def_idx = text.find('def')
|
||||
if def_idx != -1:
|
||||
text = text[max(text.find('\n', def_idx) + 1, 0):]
|
||||
if text.strip().startswith('def'):
|
||||
text = '\n'.join(text.split('\n')[1:])
|
||||
if not text.startswith(' '):
|
||||
if text.startswith(' '):
|
||||
text = ' ' + text.lstrip()
|
||||
else:
|
||||
text = '\n'.join([' ' + line for line in text.split('\n')])
|
||||
return text
|
||||
|
||||
|
||||
def lcsts_postprocess(text: str) -> str:
|
||||
text = text.strip()
|
||||
text = text.replace('1. ', '') if text.startswith('1. ') else text
|
||||
text = text.replace('- ', '') if text.startswith('- ') else text
|
||||
text = text.strip('“,。!”')
|
||||
return text
|
||||
|
||||
|
||||
def mbpp_postprocess(text: str) -> str:
|
||||
if text.startswith('Here'):
|
||||
text = '\n'.join(text.split('\n')[1:]).strip()
|
||||
if '```' in text:
|
||||
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
|
||||
if len(blocks) == 0:
|
||||
text = text.split('```')[1] # fall back to default strategy
|
||||
else:
|
||||
text = blocks[0] # fetch the first code block
|
||||
if not text.startswith('\n'): # in case starting with ```python
|
||||
text = text[max(text.find('\n') + 1, 0):]
|
||||
return text
|
||||
|
||||
|
||||
def strategyqa_pred_postprocess(text: str) -> str:
|
||||
if text.startswith('Here'):
|
||||
text = '\n'.join(text.split('\n')[1:]).strip()
|
||||
text = text.split('answer is ')[-1]
|
||||
match = re.search(r'(yes|no)', text.lower())
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ''
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import fnmatch
|
||||
import os.path as osp
|
||||
import time
|
||||
from collections import Counter
|
||||
@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist
|
||||
from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.tasks.base import BaseTask
|
||||
from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
|
||||
get_logger, task_abbr_from_cfg)
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
get_infer_output_path, get_logger,
|
||||
task_abbr_from_cfg)
|
||||
|
||||
|
||||
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
|
||||
@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask):
|
||||
self.eval_cfg = self.dataset_cfg.get('eval_cfg')
|
||||
self.output_column = dataset_cfg['reader_cfg']['output_column']
|
||||
|
||||
# overwrite postprocessor if the model has specified one
|
||||
ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
|
||||
model_postprocessors = self.model_cfg.get(
|
||||
'pred_postprocessor', {})
|
||||
for pattern in model_postprocessors.keys():
|
||||
if fnmatch.fnmatch(ds_abbr, pattern):
|
||||
self.eval_cfg[
|
||||
'pred_postprocessor'] = model_postprocessors[
|
||||
pattern] # noqa
|
||||
break
|
||||
|
||||
out_path = get_infer_output_path(
|
||||
self.model_cfg, self.dataset_cfg,
|
||||
osp.join(self.work_dir, 'results'))
|
||||
|
@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
|
||||
model_cfg.pop('max_out_len', None)
|
||||
model_cfg.pop('batch_size', None)
|
||||
model_cfg.pop('abbr', None)
|
||||
model_cfg.pop('pred_postprocessor', None)
|
||||
return MODELS.build(model_cfg)
|
||||
|
@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str:
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ''
|
||||
|
||||
|
||||
def last_option_postprocess(text: str, options: str) -> str:
|
||||
match = re.findall(rf'([{options}])', text)
|
||||
if match:
|
||||
return match[-1]
|
||||
return ''
|
||||
|
Loading…
Reference in New Issue
Block a user