[Feature] Update claude2 postprocessor (#365)

* [Feature] Update claude2 config

* [Feature] Update claude2 postprocessor
This commit is contained in:
Tong Gao 2023-09-07 11:26:26 +08:00 committed by GitHub
parent b885ec84df
commit b11838f80a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 14 deletions

View File

@ -1,6 +1,9 @@
from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
from opencompass.utils.text_postprocessors import last_option_postprocess, first_option_postprocess
from opencompass.models.claude_api.postprocessors import (yes_no_postprocess, humaneval_claude2_postprocess, record_postprocess,
gsm8k_postprocess, strategyqa_pred_postprocess, mbpp_postprocess,
lcsts_postprocess)
agieval_single_choice_sets = [
'gaokao-chinese',
@ -29,24 +32,20 @@ agieval_multiple_choices_sets = [
# Postprocessor configuration for Claude predictions. Each value is a
# dict(type=<postprocess function>, ...) with an optional `options` string of
# candidate answer letters.
# NOTE(review): the keys look like dataset names / glob-style patterns
# (e.g. 'ceval-*') -- confirm how the consumer matches them.
# NOTE(review): 'openai_humaneval', 'WSC', 'BoolQ' and 'MultiRC' each appear
# twice in this literal. In a Python dict literal the last occurrence of a key
# wins, so the earlier entries for those keys have no effect. This text looks
# like a rendered diff with removed and added lines interleaved -- confirm
# against the real file which entries are meant to remain.
claude_postprocessors = {
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
'bustm-*': dict(type=last_option_postprocess, options='AB'),
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
'piqa': dict(type=last_option_postprocess, options='AB'),
'race-*': dict(type=last_option_postprocess, options='ABCD'),
'summedits': dict(type=last_option_postprocess, options='AB'),
# Shadowed below by the yes_no_postprocess entry for 'BoolQ'.
'BoolQ': dict(type=last_option_postprocess, options='AB'),
'CB': dict(type=last_option_postprocess, options='ABC'),
# Shadowed below by the yes_no_postprocess entry for 'MultiRC'.
'MultiRC': dict(type=last_option_postprocess, options='AB'),
'RTE': dict(type=last_option_postprocess, options='AB'),
'WiC': dict(type=last_option_postprocess, options='AB'),
# Shadowed below by the yes_no_postprocess entry for 'WSC'.
'WSC': dict(type=last_option_postprocess, options='AB'),
'winogrande': dict(type=last_option_postprocess, options='AB'),
'gsm8k': dict(type=gsm8k_postprocess),
# Shadowed by the humaneval_claude2_postprocess entry on the next line.
'openai_humaneval': dict(type=humaneval_postprocess),
'openai_humaneval': dict(type=humaneval_claude2_postprocess),
'lcsts': dict(type=lcsts_postprocess),
'mbpp': dict(type=mbpp_postprocess),
'strategyqa': dict(type=strategyqa_pred_postprocess),
# Effective entries for the duplicated keys (these override the ones above).
'WSC': dict(type=yes_no_postprocess),
'BoolQ': dict(type=yes_no_postprocess),
'cmnli': dict(type=first_option_postprocess, options='ABC'),
'ocnli_fc-*': dict(type=first_option_postprocess, options='ABC'),
'MultiRC': dict(type=yes_no_postprocess),
'ReCoRD': dict(type=record_postprocess),
'commonsense_qa': dict(type=last_option_postprocess, options='ABCDE'),
}
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:

View File

@ -1,5 +1,10 @@
import re
from opencompass.datasets.humaneval import humaneval_gpt_postprocess
from opencompass.datasets.record import ReCoRD_postprocess
from opencompass.datasets.xsum import Xsum_postprocess
from opencompass.utils.text_postprocessors import first_option_postprocess
def gsm8k_postprocess(text: str) -> str:
text = text.split(' ')[::-1]
@ -75,3 +80,32 @@ def strategyqa_pred_postprocess(text: str) -> str:
if match:
return match.group(1)
return ''
def record_postprocess(text: str) -> str:
    """Extract the answer phrase from a prediction.

    If the text contains "refers to ", the run of characters that follows it
    (up to, but not including, the next period) is returned with surrounding
    whitespace stripped. Otherwise the text is handed to
    ``ReCoRD_postprocess`` and its result is returned.

    Args:
        text: Raw model output.

    Returns:
        The extracted phrase, or whatever ``ReCoRD_postprocess`` returns.
    """
    found = re.search(r'(?<=refers to )[^.]+', text)
    if found is None:
        # No "refers to ..." phrase: defer to the dataset's own postprocessor.
        return ReCoRD_postprocess(text)
    phrase = found.group()
    return phrase.strip()
def humaneval_claude2_postprocess(text: str) -> str:
    """Strip a leading "Here ..." preamble paragraph, then postprocess.

    When the text starts with 'Here' and contains a blank line, everything up
    to and including the first blank line is removed. The remainder is passed
    to ``humaneval_gpt_postprocess`` and its result is returned.

    Args:
        text: Raw model output.

    Returns:
        The result of ``humaneval_gpt_postprocess`` on the trimmed text.
    """
    if text.startswith('Here'):
        _preamble, separator, remainder = text.partition('\n\n')
        # Only drop the first paragraph when another one follows it. Without
        # this guard a response with no blank line would be reduced to '' and
        # the entire completion would be lost.
        if separator:
            text = remainder
    return humaneval_gpt_postprocess(text)
def xsum_postprocess(text: str) -> str:
    """Strip a leading "Here ..." preamble paragraph, then postprocess.

    When the text starts with 'Here' and contains a blank line, everything up
    to and including the first blank line is removed. The remainder is passed
    to ``Xsum_postprocess`` and its result is returned.

    Args:
        text: Raw model output.

    Returns:
        The result of ``Xsum_postprocess`` on the trimmed text.
    """
    if text.startswith('Here'):
        _preamble, separator, remainder = text.partition('\n\n')
        # Only drop the first paragraph when another one follows it. A
        # single-paragraph summary that merely begins with "Here" must not be
        # reduced to an empty string.
        if separator:
            text = remainder
    return Xsum_postprocess(text)
def yes_no_postprocess(text: str) -> str:
    """Map a free-form yes/no answer to an option letter.

    The check is case-insensitive and matches whole words only, so words that
    merely contain the letters (e.g. "eyes", "know", "another") do not count.
    "yes" takes precedence over "no" when both occur.

    Args:
        text: Raw model output.

    Returns:
        'A' if the text contains the word "yes", 'B' if it contains the word
        "no", otherwise the result of ``first_option_postprocess(text, 'AB')``.
    """
    lowered = text.lower()
    # Whole-word matching: a plain substring test would treat "eyes" as "yes"
    # and "know" / "cannot" / "another" as "no".
    if re.search(r'\byes\b', lowered):
        return 'A'
    if re.search(r'\bno\b', lowered):
        return 'B'
    # Neither word present: fall back to looking for an explicit option letter.
    return first_option_postprocess(text, 'AB')