diff --git a/configs/models/claude2.py b/configs/models/claude2.py index 4249496a..9c1eaf83 100644 --- a/configs/models/claude2.py +++ b/configs/models/claude2.py @@ -1,6 +1,9 @@ from opencompass.models.claude_api.claude_api import Claude -from opencompass.utils.text_postprocessors import last_option_postprocess -from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess +from opencompass.utils.text_postprocessors import last_option_postprocess, first_option_postprocess +from opencompass.models.claude_api.postprocessors import (yes_no_postprocess, humaneval_claude2_postprocess, record_postprocess, + gsm8k_postprocess, strategyqa_pred_postprocess, mbpp_postprocess, + lcsts_postprocess) + agieval_single_choice_sets = [ 'gaokao-chinese', @@ -29,24 +32,20 @@ agieval_multiple_choices_sets = [ claude_postprocessors = { 'ceval-*': dict(type=last_option_postprocess, options='ABCD'), 'bustm-*': dict(type=last_option_postprocess, options='AB'), - 'hellaswag': dict(type=last_option_postprocess, options='ABCD'), - 'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'), - 'openbookqa*': dict(type=last_option_postprocess, options='ABCD'), - 'piqa': dict(type=last_option_postprocess, options='AB'), - 'race-*': dict(type=last_option_postprocess, options='ABCD'), 'summedits': dict(type=last_option_postprocess, options='AB'), - 'BoolQ': dict(type=last_option_postprocess, options='AB'), - 'CB': dict(type=last_option_postprocess, options='ABC'), - 'MultiRC': dict(type=last_option_postprocess, options='AB'), - 'RTE': dict(type=last_option_postprocess, options='AB'), 'WiC': dict(type=last_option_postprocess, options='AB'), - 'WSC': dict(type=last_option_postprocess, options='AB'), - 'winogrande': dict(type=last_option_postprocess, options='AB'), 'gsm8k': dict(type=gsm8k_postprocess), - 'openai_humaneval': dict(type=humaneval_postprocess), + 'openai_humaneval': dict(type=humaneval_claude2_postprocess), 'lcsts': dict(type=lcsts_postprocess), 'mbpp': dict(type=mbpp_postprocess), 'strategyqa': dict(type=strategyqa_pred_postprocess), + 'WSC': dict(type=yes_no_postprocess), + 'BoolQ': dict(type=yes_no_postprocess), + 'cmnli': dict(type=first_option_postprocess, options='ABC'), + 'ocnli_fc-*': dict(type=first_option_postprocess, options='ABC'), + 'MultiRC': dict(type=yes_no_postprocess), + 'ReCoRD': dict(type=record_postprocess), + 'commonsense_qa': dict(type=last_option_postprocess, options='ABCDE'), } for _name in agieval_multiple_choices_sets + agieval_single_choice_sets: diff --git a/opencompass/models/claude_api/postprocessors.py b/opencompass/models/claude_api/postprocessors.py index c42358c7..878f1669 100644 --- a/opencompass/models/claude_api/postprocessors.py +++ b/opencompass/models/claude_api/postprocessors.py @@ -1,5 +1,10 @@ import re +from opencompass.datasets.humaneval import humaneval_gpt_postprocess +from opencompass.datasets.record import ReCoRD_postprocess +from opencompass.datasets.xsum import Xsum_postprocess +from opencompass.utils.text_postprocessors import first_option_postprocess + def gsm8k_postprocess(text: str) -> str: text = text.split(' ')[::-1] @@ -75,3 +80,32 @@ def strategyqa_pred_postprocess(text: str) -> str: if match: return match.group(1) return '' + + +def record_postprocess(text: str) -> str: + match = re.search(r'(?<=refers to )[^.]+', text) + + if match: + return match.group().strip() # Outputs: abc def + + return ReCoRD_postprocess(text) + + +def humaneval_claude2_postprocess(text: str) -> str: + if text.startswith('Here'): + text = '\n\n'.join(text.split('\n\n')[1:]) + return humaneval_gpt_postprocess(text) + + +def xsum_postprocess(text: str) -> str: + if text.startswith('Here'): + text = '\n\n'.join(text.split('\n\n')[1:]) + return Xsum_postprocess(text) + + +def yes_no_postprocess(text: str) -> str: + if 'yes' in text.lower(): + return 'A' + elif 'no' in text.lower(): + return 'B' + return first_option_postprocess(text, 'AB')