mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Update claude2 postprocessor (#365)
* [Feature] Update claude2 config * [Feature] Update claude2 postprocessor
This commit is contained in:
parent
b885ec84df
commit
b11838f80a
@ -1,6 +1,9 @@
|
|||||||
from opencompass.models.claude_api.claude_api import Claude
|
from opencompass.models.claude_api.claude_api import Claude
|
||||||
from opencompass.utils.text_postprocessors import last_option_postprocess
|
from opencompass.utils.text_postprocessors import last_option_postprocess, first_option_postprocess
|
||||||
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
|
from opencompass.models.claude_api.postprocessors import (yes_no_postprocess, humaneval_claude2_postprocess, record_postprocess,
|
||||||
|
gsm8k_postprocess, strategyqa_pred_postprocess, mbpp_postprocess,
|
||||||
|
lcsts_postprocess)
|
||||||
|
|
||||||
|
|
||||||
agieval_single_choice_sets = [
|
agieval_single_choice_sets = [
|
||||||
'gaokao-chinese',
|
'gaokao-chinese',
|
||||||
@ -29,24 +32,20 @@ agieval_multiple_choices_sets = [
|
|||||||
claude_postprocessors = {
|
claude_postprocessors = {
|
||||||
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
|
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
|
||||||
'bustm-*': dict(type=last_option_postprocess, options='AB'),
|
'bustm-*': dict(type=last_option_postprocess, options='AB'),
|
||||||
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
|
|
||||||
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
|
|
||||||
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
|
|
||||||
'piqa': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'race-*': dict(type=last_option_postprocess, options='ABCD'),
|
|
||||||
'summedits': dict(type=last_option_postprocess, options='AB'),
|
'summedits': dict(type=last_option_postprocess, options='AB'),
|
||||||
'BoolQ': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'CB': dict(type=last_option_postprocess, options='ABC'),
|
|
||||||
'MultiRC': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'RTE': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'WiC': dict(type=last_option_postprocess, options='AB'),
|
'WiC': dict(type=last_option_postprocess, options='AB'),
|
||||||
'WSC': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'winogrande': dict(type=last_option_postprocess, options='AB'),
|
|
||||||
'gsm8k': dict(type=gsm8k_postprocess),
|
'gsm8k': dict(type=gsm8k_postprocess),
|
||||||
'openai_humaneval': dict(type=humaneval_postprocess),
|
'openai_humaneval': dict(type=humaneval_claude2_postprocess),
|
||||||
'lcsts': dict(type=lcsts_postprocess),
|
'lcsts': dict(type=lcsts_postprocess),
|
||||||
'mbpp': dict(type=mbpp_postprocess),
|
'mbpp': dict(type=mbpp_postprocess),
|
||||||
'strategyqa': dict(type=strategyqa_pred_postprocess),
|
'strategyqa': dict(type=strategyqa_pred_postprocess),
|
||||||
|
'WSC': dict(type=yes_no_postprocess),
|
||||||
|
'BoolQ': dict(type=yes_no_postprocess),
|
||||||
|
'cmnli': dict(type=first_option_postprocess, options='ABC'),
|
||||||
|
'ocnli_fc-*': dict(type=first_option_postprocess, options='ABC'),
|
||||||
|
'MultiRC': dict(type=yes_no_postprocess),
|
||||||
|
'ReCoRD': dict(type=record_postprocess),
|
||||||
|
'commonsense_qa': dict(type=last_option_postprocess, options='ABCDE'),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
|
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
|
from opencompass.datasets.humaneval import humaneval_gpt_postprocess
|
||||||
|
from opencompass.datasets.record import ReCoRD_postprocess
|
||||||
|
from opencompass.datasets.xsum import Xsum_postprocess
|
||||||
|
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||||
|
|
||||||
|
|
||||||
def gsm8k_postprocess(text: str) -> str:
|
def gsm8k_postprocess(text: str) -> str:
|
||||||
text = text.split(' ')[::-1]
|
text = text.split(' ')[::-1]
|
||||||
@ -75,3 +80,32 @@ def strategyqa_pred_postprocess(text: str) -> str:
|
|||||||
if match:
|
if match:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
def record_postprocess(text: str) -> str:
|
||||||
|
match = re.search(r'(?<=refers to )[^.]+', text)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
return match.group().strip() # Outputs: abc def
|
||||||
|
|
||||||
|
return ReCoRD_postprocess(text)
|
||||||
|
|
||||||
|
|
||||||
|
def humaneval_claude2_postprocess(text: str) -> str:
|
||||||
|
if text.startswith('Here'):
|
||||||
|
text = '\n\n'.join(text.split('\n\n')[1:])
|
||||||
|
return humaneval_gpt_postprocess(text)
|
||||||
|
|
||||||
|
|
||||||
|
def xsum_postprocess(text: str) -> str:
|
||||||
|
if text.startswith('Here'):
|
||||||
|
text = '\n\n'.join(text.split('\n\n')[1:])
|
||||||
|
return Xsum_postprocess(text)
|
||||||
|
|
||||||
|
|
||||||
|
def yes_no_postprocess(text: str) -> str:
|
||||||
|
if 'yes' in text.lower():
|
||||||
|
return 'A'
|
||||||
|
elif 'no' in text.lower():
|
||||||
|
return 'B'
|
||||||
|
return first_option_postprocess(text, 'AB')
|
||||||
|
Loading…
Reference in New Issue
Block a user