mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Update claude2 postprocessor (#365)
* [Feature] Update claude2 config * [Feature] Update claude2 postprocessor
This commit is contained in:
parent
b885ec84df
commit
b11838f80a
@ -1,6 +1,9 @@
|
||||
from opencompass.models.claude_api.claude_api import Claude
|
||||
from opencompass.utils.text_postprocessors import last_option_postprocess
|
||||
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
|
||||
from opencompass.utils.text_postprocessors import last_option_postprocess, first_option_postprocess
|
||||
from opencompass.models.claude_api.postprocessors import (yes_no_postprocess, humaneval_claude2_postprocess, record_postprocess,
|
||||
gsm8k_postprocess, strategyqa_pred_postprocess, mbpp_postprocess,
|
||||
lcsts_postprocess)
|
||||
|
||||
|
||||
agieval_single_choice_sets = [
|
||||
'gaokao-chinese',
|
||||
@ -29,24 +32,20 @@ agieval_multiple_choices_sets = [
|
||||
claude_postprocessors = {
|
||||
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'bustm-*': dict(type=last_option_postprocess, options='AB'),
|
||||
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'piqa': dict(type=last_option_postprocess, options='AB'),
|
||||
'race-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||
'summedits': dict(type=last_option_postprocess, options='AB'),
|
||||
'BoolQ': dict(type=last_option_postprocess, options='AB'),
|
||||
'CB': dict(type=last_option_postprocess, options='ABC'),
|
||||
'MultiRC': dict(type=last_option_postprocess, options='AB'),
|
||||
'RTE': dict(type=last_option_postprocess, options='AB'),
|
||||
'WiC': dict(type=last_option_postprocess, options='AB'),
|
||||
'WSC': dict(type=last_option_postprocess, options='AB'),
|
||||
'winogrande': dict(type=last_option_postprocess, options='AB'),
|
||||
'gsm8k': dict(type=gsm8k_postprocess),
|
||||
'openai_humaneval': dict(type=humaneval_postprocess),
|
||||
'openai_humaneval': dict(type=humaneval_claude2_postprocess),
|
||||
'lcsts': dict(type=lcsts_postprocess),
|
||||
'mbpp': dict(type=mbpp_postprocess),
|
||||
'strategyqa': dict(type=strategyqa_pred_postprocess),
|
||||
'WSC': dict(type=yes_no_postprocess),
|
||||
'BoolQ': dict(type=yes_no_postprocess),
|
||||
'cmnli': dict(type=first_option_postprocess, options='ABC'),
|
||||
'ocnli_fc-*': dict(type=first_option_postprocess, options='ABC'),
|
||||
'MultiRC': dict(type=yes_no_postprocess),
|
||||
'ReCoRD': dict(type=record_postprocess),
|
||||
'commonsense_qa': dict(type=last_option_postprocess, options='ABCDE'),
|
||||
}
|
||||
|
||||
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
|
||||
|
@ -1,5 +1,10 @@
|
||||
import re
|
||||
|
||||
from opencompass.datasets.humaneval import humaneval_gpt_postprocess
|
||||
from opencompass.datasets.record import ReCoRD_postprocess
|
||||
from opencompass.datasets.xsum import Xsum_postprocess
|
||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||
|
||||
|
||||
def gsm8k_postprocess(text: str) -> str:
|
||||
text = text.split(' ')[::-1]
|
||||
@ -75,3 +80,32 @@ def strategyqa_pred_postprocess(text: str) -> str:
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ''
|
||||
|
||||
|
||||
def record_postprocess(text: str) -> str:
    """Pull the phrase that follows ``refers to `` out of a model answer.

    The leftmost occurrence of ``refers to `` that is followed by at least
    one non-period character is located, and the run of non-period
    characters after it is returned with surrounding whitespace removed.
    When no such occurrence exists, the text is passed to
    ``ReCoRD_postprocess`` and its result is returned instead.

    Args:
        text: Raw model output.

    Returns:
        The extracted phrase, or the result of ``ReCoRD_postprocess(text)``.
    """
    found = re.search(r'(?<=refers to )[^.]+', text)
    if found is None:
        return ReCoRD_postprocess(text)
    # e.g. 'It refers to abc def. ...' -> 'abc def'
    return found.group(0).strip()
|
||||
|
||||
|
||||
def humaneval_claude2_postprocess(text: str) -> str:
    """Drop a leading ``Here...`` paragraph, then apply the GPT postprocessor.

    If the text starts with ``Here``, everything up to and including the
    first blank line (``'\\n\\n'``) is discarded; when there is no blank
    line at all, nothing remains and an empty string is forwarded. The
    resulting text is handed to ``humaneval_gpt_postprocess``.

    Args:
        text: Raw model output.

    Returns:
        Whatever ``humaneval_gpt_postprocess`` returns for the trimmed text.
    """
    # partition()[2] is the part after the first '\n\n' ('' if none exists).
    body = text.partition('\n\n')[2] if text.startswith('Here') else text
    return humaneval_gpt_postprocess(body)
|
||||
|
||||
|
||||
def xsum_postprocess(text: str) -> str:
    """Drop a leading ``Here...`` paragraph, then apply ``Xsum_postprocess``.

    If the text starts with ``Here``, everything up to and including the
    first blank line (``'\\n\\n'``) is discarded; when there is no blank
    line at all, nothing remains and an empty string is forwarded. The
    resulting text is handed to ``Xsum_postprocess``.

    Args:
        text: Raw model output.

    Returns:
        Whatever ``Xsum_postprocess`` returns for the trimmed text.
    """
    # partition()[2] is the part after the first '\n\n' ('' if none exists).
    summary = text.partition('\n\n')[2] if text.startswith('Here') else text
    return Xsum_postprocess(summary)
|
||||
|
||||
|
||||
def yes_no_postprocess(text: str) -> str:
    """Map a free-form yes/no answer onto option ``'A'`` (yes) or ``'B'`` (no).

    The check is case-insensitive and matches whole words only, so words
    that merely contain the letters (``eyes``, ``yesterday``, ``know``,
    ``not``, ``cannot``) no longer trigger a label. ``yes`` takes priority
    over ``no`` when both words occur. If neither word is present the text
    is handed to ``first_option_postprocess`` with the options ``'AB'``.

    Args:
        text: Raw model output.

    Returns:
        ``'A'`` for yes, ``'B'`` for no, otherwise the result of
        ``first_option_postprocess(text, 'AB')``.
    """
    lowered = text.lower()
    # Word boundaries keep substrings such as "eyes" or "know" from matching.
    if re.search(r'\byes\b', lowered):
        return 'A'
    if re.search(r'\bno\b', lowered):
        return 'B'
    return first_option_postprocess(text, 'AB')
|
||||
|
Loading…
Reference in New Issue
Block a user