mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Add model postprocess function (#1484)
* Add model postprocess function * Add model postprocess function * Add model postprocess function * Add model postprocess function * Add model postprocess function * Add model postprocess function * Add model postprocess function * Add model postprocess function --------- Co-authored-by: liushz <liuhongwei@pjlab.rog.cn>
This commit is contained in:
parent
45efdc994d
commit
00fc8da5be
@ -70,6 +70,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
|
||||
|
||||
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
|
||||
|
||||
- **\[2024.09.05\]** We now support answer extraction through model post-processing to provide a more accurate representation of the model's capabilities. As part of this update, we have integrated [XFinder](https://github.com/IAAR-Shanghai/xFinder) as our first post-processing model. For more detailed information, please refer to the [documentation](opencompass/utils/postprocessors/xfinder/README.md), and give it a try! 🔥🔥🔥
|
||||
- **\[2024.08.20\]** OpenCompass now supports the [SciCode](https://github.com/scicode-bench/SciCode): A Research Coding Benchmark Curated by Scientists. 🔥🔥🔥
|
||||
- **\[2024.08.16\]** OpenCompass now supports the brand new long-context language model evaluation benchmark — [RULER](https://arxiv.org/pdf/2404.06654). RULER provides an evaluation of long-context including retrieval, multi-hop tracing, aggregation, and question answering through flexible configurations. Check out the [RULER](configs/datasets/ruler/README.md) evaluation config now! 🔥🔥🔥
|
||||
- **\[2024.08.09\]** We have released the example data and configuration for the CompassBench-202408, welcome to [CompassBench](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/compassbench_intro.html) for more details. 🔥🔥🔥
|
||||
|
@ -69,6 +69,7 @@
|
||||
|
||||
## 🚀 最新进展 <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
|
||||
|
||||
- **\[2024.09.05\]** OpenCompass 现在支持通过模型后处理来进行答案提取,以更准确地展示模型的能力。作为此次更新的一部分,我们集成了 [XFinder](https://github.com/IAAR-Shanghai/xFinder) 作为首个后处理模型。具体信息请参阅 [文档](opencompass/utils/postprocessors/xfinder/README.md),欢迎尝试! 🔥🔥🔥
|
||||
- **\[2024.08.20\]** OpenCompass 现已支持 [SciCode](https://github.com/scicode-bench/SciCode): A Research Coding Benchmark Curated by Scientists。 🔥🔥🔥
|
||||
- **\[2024.08.16\]** OpenCompass 现已支持全新的长上下文语言模型评估基准——[RULER](https://arxiv.org/pdf/2404.06654)。RULER 通过灵活的配置,提供了对长上下文包括检索、多跳追踪、聚合和问答等多种任务类型的评测,欢迎访问[RULER](configs/datasets/ruler/README.md)。🔥🔥🔥
|
||||
- **\[2024.07.23\]** 我们支持了[Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)模型,欢迎试用!🔥🔥🔥
|
||||
|
43
configs/datasets/gsm8k/gsm8k_xfinder_gen_a58960.py
Normal file
43
configs/datasets/gsm8k/gsm8k_xfinder_gen_a58960.py
Normal file
@ -0,0 +1,43 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import GSM8KDataset, gsm8k_postprocess, gsm8k_dataset_postprocess, Gsm8kEvaluator
|
||||
from opencompass.datasets import MATHEvaluator, math_postprocess_v2
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')
|
||||
|
||||
gsm8k_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'),
|
||||
],
|
||||
),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512),
|
||||
)
|
||||
|
||||
gsm8k_eval_cfg = dict(
|
||||
evaluator=dict(type=MATHEvaluator, version='v2'),
|
||||
pred_postprocessor=dict(type=math_postprocess_v2),
|
||||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='math',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
gsm8k_datasets = [
|
||||
dict(
|
||||
abbr='gsm8k',
|
||||
type=GSM8KDataset,
|
||||
path='opencompass/gsm8k',
|
||||
reader_cfg=gsm8k_reader_cfg,
|
||||
infer_cfg=gsm8k_infer_cfg,
|
||||
eval_cfg=gsm8k_eval_cfg,
|
||||
)
|
||||
]
|
130
configs/datasets/mmlu/mmlu_xfinder_gen_4d595a.py
Normal file
130
configs/datasets/mmlu/mmlu_xfinder_gen_4d595a.py
Normal file
@ -0,0 +1,130 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
|
||||
from opencompass.datasets import MMLUDataset
|
||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
|
||||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
||||
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=['input', 'A', 'B', 'C', 'D'],
|
||||
output_column='target',
|
||||
train_split='dev')
|
||||
|
||||
mmlu_all_sets = [
|
||||
'college_biology',
|
||||
'college_chemistry',
|
||||
'college_computer_science',
|
||||
'college_mathematics',
|
||||
'college_physics',
|
||||
'electrical_engineering',
|
||||
'astronomy',
|
||||
'anatomy',
|
||||
'abstract_algebra',
|
||||
'machine_learning',
|
||||
'clinical_knowledge',
|
||||
'global_facts',
|
||||
'management',
|
||||
'nutrition',
|
||||
'marketing',
|
||||
'professional_accounting',
|
||||
'high_school_geography',
|
||||
'international_law',
|
||||
'moral_scenarios',
|
||||
'computer_security',
|
||||
'high_school_microeconomics',
|
||||
'professional_law',
|
||||
'medical_genetics',
|
||||
'professional_psychology',
|
||||
'jurisprudence',
|
||||
'world_religions',
|
||||
'philosophy',
|
||||
'virology',
|
||||
'high_school_chemistry',
|
||||
'public_relations',
|
||||
'high_school_macroeconomics',
|
||||
'human_sexuality',
|
||||
'elementary_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_computer_science',
|
||||
'high_school_european_history',
|
||||
'business_ethics',
|
||||
'moral_disputes',
|
||||
'high_school_statistics',
|
||||
'miscellaneous',
|
||||
'formal_logic',
|
||||
'high_school_government_and_politics',
|
||||
'prehistory',
|
||||
'security_studies',
|
||||
'high_school_biology',
|
||||
'logical_fallacies',
|
||||
'high_school_world_history',
|
||||
'professional_medicine',
|
||||
'high_school_mathematics',
|
||||
'college_medicine',
|
||||
'high_school_us_history',
|
||||
'sociology',
|
||||
'econometrics',
|
||||
'high_school_psychology',
|
||||
'human_aging',
|
||||
'us_foreign_policy',
|
||||
'conceptual_physics',
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for _name in mmlu_all_sets:
|
||||
_hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=
|
||||
f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
|
||||
),
|
||||
dict(role='BOT', prompt='{target}\n')
|
||||
]),
|
||||
),
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin='</E>',
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
|
||||
),
|
||||
],
|
||||
),
|
||||
ice_token='</E>',
|
||||
),
|
||||
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
mmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccwithDetailsEvaluator),
|
||||
pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='alphabet_option',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
abbr=f'lukaemon_mmlu_{_name}',
|
||||
type=MMLUDataset,
|
||||
path='opencompass/mmlu',
|
||||
name=_name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg,
|
||||
eval_cfg=mmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _hint
|
37
configs/datasets/nq/nq_xfinder_gen_3dcea1.py
Normal file
37
configs/datasets/nq/nq_xfinder_gen_3dcea1.py
Normal file
@ -0,0 +1,37 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import NaturalQuestionDataset, NQEvaluator
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
nq_reader_cfg = dict(
|
||||
input_columns=['question'], output_column='answer', train_split='test')
|
||||
|
||||
nq_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='Question: {question}?\nAnswer: '),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
nq_eval_cfg = dict(
|
||||
evaluator=dict(type=NQEvaluator), pred_role='BOT',
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='short_text',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
nq_datasets = [
|
||||
dict(
|
||||
type=NaturalQuestionDataset,
|
||||
abbr='nq',
|
||||
path='opencompass/natural_question',
|
||||
reader_cfg=nq_reader_cfg,
|
||||
infer_cfg=nq_infer_cfg,
|
||||
eval_cfg=nq_eval_cfg)
|
||||
]
|
@ -0,0 +1,43 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import GSM8KDataset, gsm8k_postprocess, gsm8k_dataset_postprocess, Gsm8kEvaluator
|
||||
from opencompass.datasets import MATHEvaluator, math_postprocess_v2
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')
|
||||
|
||||
gsm8k_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'),
|
||||
],
|
||||
),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512),
|
||||
)
|
||||
|
||||
gsm8k_eval_cfg = dict(
|
||||
evaluator=dict(type=MATHEvaluator, version='v2'),
|
||||
pred_postprocessor=dict(type=math_postprocess_v2),
|
||||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='math',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
gsm8k_datasets = [
|
||||
dict(
|
||||
abbr='gsm8k',
|
||||
type=GSM8KDataset,
|
||||
path='opencompass/gsm8k',
|
||||
reader_cfg=gsm8k_reader_cfg,
|
||||
infer_cfg=gsm8k_infer_cfg,
|
||||
eval_cfg=gsm8k_eval_cfg,
|
||||
)
|
||||
]
|
130
opencompass/configs/datasets/mmlu/mmlu_xfinder_gen_4d595a.py
Normal file
130
opencompass/configs/datasets/mmlu/mmlu_xfinder_gen_4d595a.py
Normal file
@ -0,0 +1,130 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
|
||||
from opencompass.datasets import MMLUDataset
|
||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
|
||||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
||||
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=['input', 'A', 'B', 'C', 'D'],
|
||||
output_column='target',
|
||||
train_split='dev')
|
||||
|
||||
mmlu_all_sets = [
|
||||
'college_biology',
|
||||
'college_chemistry',
|
||||
'college_computer_science',
|
||||
'college_mathematics',
|
||||
'college_physics',
|
||||
'electrical_engineering',
|
||||
'astronomy',
|
||||
'anatomy',
|
||||
'abstract_algebra',
|
||||
'machine_learning',
|
||||
'clinical_knowledge',
|
||||
'global_facts',
|
||||
'management',
|
||||
'nutrition',
|
||||
'marketing',
|
||||
'professional_accounting',
|
||||
'high_school_geography',
|
||||
'international_law',
|
||||
'moral_scenarios',
|
||||
'computer_security',
|
||||
'high_school_microeconomics',
|
||||
'professional_law',
|
||||
'medical_genetics',
|
||||
'professional_psychology',
|
||||
'jurisprudence',
|
||||
'world_religions',
|
||||
'philosophy',
|
||||
'virology',
|
||||
'high_school_chemistry',
|
||||
'public_relations',
|
||||
'high_school_macroeconomics',
|
||||
'human_sexuality',
|
||||
'elementary_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_computer_science',
|
||||
'high_school_european_history',
|
||||
'business_ethics',
|
||||
'moral_disputes',
|
||||
'high_school_statistics',
|
||||
'miscellaneous',
|
||||
'formal_logic',
|
||||
'high_school_government_and_politics',
|
||||
'prehistory',
|
||||
'security_studies',
|
||||
'high_school_biology',
|
||||
'logical_fallacies',
|
||||
'high_school_world_history',
|
||||
'professional_medicine',
|
||||
'high_school_mathematics',
|
||||
'college_medicine',
|
||||
'high_school_us_history',
|
||||
'sociology',
|
||||
'econometrics',
|
||||
'high_school_psychology',
|
||||
'human_aging',
|
||||
'us_foreign_policy',
|
||||
'conceptual_physics',
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for _name in mmlu_all_sets:
|
||||
_hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=
|
||||
f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
|
||||
),
|
||||
dict(role='BOT', prompt='{target}\n')
|
||||
]),
|
||||
),
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin='</E>',
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
|
||||
),
|
||||
],
|
||||
),
|
||||
ice_token='</E>',
|
||||
),
|
||||
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
mmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccwithDetailsEvaluator),
|
||||
pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='alphabet_option',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
abbr=f'lukaemon_mmlu_{_name}',
|
||||
type=MMLUDataset,
|
||||
path='opencompass/mmlu',
|
||||
name=_name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg,
|
||||
eval_cfg=mmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _hint
|
37
opencompass/configs/datasets/nq/nq_xfinder_gen_3dcea1.py
Normal file
37
opencompass/configs/datasets/nq/nq_xfinder_gen_3dcea1.py
Normal file
@ -0,0 +1,37 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import NaturalQuestionDataset, NQEvaluator
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
nq_reader_cfg = dict(
|
||||
input_columns=['question'], output_column='answer', train_split='test')
|
||||
|
||||
nq_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='Question: {question}?\nAnswer: '),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
nq_eval_cfg = dict(
|
||||
evaluator=dict(type=NQEvaluator), pred_role='BOT',
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='short_text',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
nq_datasets = [
|
||||
dict(
|
||||
type=NaturalQuestionDataset,
|
||||
abbr='nq',
|
||||
path='opencompass/natural_question',
|
||||
reader_cfg=nq_reader_cfg,
|
||||
infer_cfg=nq_infer_cfg,
|
||||
eval_cfg=nq_eval_cfg)
|
||||
]
|
@ -198,6 +198,26 @@ class OpenICLEvalTask(BaseTask):
|
||||
else:
|
||||
pred_strs = [proc(s, **kwargs) for s in pred_strs]
|
||||
|
||||
model_pred_strs = []
|
||||
if 'model_postprocessor' in self.eval_cfg:
|
||||
references = (test_set[self.output_column]
|
||||
if self.output_column else None)
|
||||
model_pred_dicts = copy.deepcopy(pred_dicts)
|
||||
for i, pred_dict in enumerate(model_pred_dicts):
|
||||
pred_dict['reference'] = [references[i]]
|
||||
self.logger.info('Postprocessing model predictions...')
|
||||
kwargs = self.eval_cfg['model_postprocessor']
|
||||
proc = kwargs.pop('type')
|
||||
if isinstance(proc, str):
|
||||
proc = TEXT_POSTPROCESSORS.get(proc)
|
||||
if pred_list_flag:
|
||||
model_pred_strs = [[
|
||||
proc(model_pred_dict, **kwargs)
|
||||
for model_pred_dict in model_pred_dicts
|
||||
]]
|
||||
else:
|
||||
model_pred_strs = proc(model_pred_dicts, **kwargs)
|
||||
|
||||
# Get majority voting predictions if use self-consistency
|
||||
if sc_size is not None:
|
||||
pred_strs = [
|
||||
@ -229,12 +249,29 @@ class OpenICLEvalTask(BaseTask):
|
||||
}
|
||||
result = icl_evaluator.score(**preds)
|
||||
|
||||
# Get model postprocess result
|
||||
model_details = None
|
||||
model_result = None
|
||||
if 'model_postprocessor' in self.eval_cfg:
|
||||
model_preds = copy.deepcopy(preds)
|
||||
model_preds['predictions'] = model_pred_strs
|
||||
model_result = icl_evaluator.score(**model_preds)
|
||||
for key in model_result:
|
||||
if key == 'details':
|
||||
model_details = model_result[key]
|
||||
continue
|
||||
new_key = 'model_postprocess_' + key
|
||||
result[new_key] = model_result[key]
|
||||
|
||||
if self.dump_details:
|
||||
details = result.get('details', None)
|
||||
try:
|
||||
result['details'] = self.format_details(
|
||||
pred_strs, test_set[self.output_column], details,
|
||||
pred_strs, model_pred_strs,
|
||||
test_set[self.output_column], details, model_details,
|
||||
pred_dicts)
|
||||
self.logger.warning(
|
||||
f"result['details'] : {result['details']}"),
|
||||
result['type'] = result['details'].pop('type', None)
|
||||
if self.cal_extract_rate:
|
||||
# Calculate the extraction success rate for prediction
|
||||
@ -253,13 +290,27 @@ class OpenICLEvalTask(BaseTask):
|
||||
self.logger.error(
|
||||
f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
|
||||
return
|
||||
else:
|
||||
elif model_result is None:
|
||||
result_wo_details = {
|
||||
i: result[i]
|
||||
for i in result if i != 'details'
|
||||
}
|
||||
self.logger.info(
|
||||
f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')
|
||||
else:
|
||||
result_wo_details = {
|
||||
i: result[i]
|
||||
for i in result if i != 'details'
|
||||
}
|
||||
model_result_wo_details = {
|
||||
i: model_result[i]
|
||||
for i in model_result if i != 'details'
|
||||
}
|
||||
self.logger.info(
|
||||
f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')
|
||||
self.logger.info(
|
||||
'Model Postprocess Task: ' +
|
||||
f'{task_abbr_from_cfg(self.cfg)}:{model_result_wo_details}')
|
||||
|
||||
# Save result
|
||||
out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
|
||||
@ -286,7 +337,8 @@ class OpenICLEvalTask(BaseTask):
|
||||
success_rate = 100 - len(invalid_extractions) / len(details) * 100
|
||||
return success_rate
|
||||
|
||||
def format_details(self, predictions, references, details, pred_dicts):
|
||||
def format_details(self, predictions, model_pred_strs, references, details,
|
||||
model_details, pred_dicts):
|
||||
"""This function is responsible for formatting prediction details.
|
||||
|
||||
Args:
|
||||
@ -323,6 +375,19 @@ class OpenICLEvalTask(BaseTask):
|
||||
result['predictions'] = str(predictions[i])
|
||||
result['references'] = str(references[i])
|
||||
result['correct'] = str(predictions[i]) == str(references[i])
|
||||
elif details is not None and model_details is not None:
|
||||
assert model_pred_strs != [], \
|
||||
'Model details is not None, but model_pred_strs is empty'
|
||||
self.logger.info(
|
||||
f"model_details[i]['pred']: {model_details[i]['pred']}")
|
||||
results['type'] = 'GEN'
|
||||
result['prompt'] = origin_prediction['origin_prompt']
|
||||
result['origin_prediction'] = pred_dicts[i]['prediction']
|
||||
result['predictions'] = details[i]['pred']
|
||||
result['model_extract_predictions'] = model_details[i]['pred']
|
||||
result['references'] = details[i]['answer']
|
||||
result['correct'] = details[i]['correct']
|
||||
result['model_extract_correct'] = model_details[i]['correct']
|
||||
elif details is not None:
|
||||
results['type'] = 'GEN'
|
||||
result['prompt'] = origin_prediction['origin_prompt']
|
||||
|
@ -9,5 +9,6 @@ from .fileio import * # noqa
|
||||
from .lark import * # noqa
|
||||
from .logging import * # noqa
|
||||
from .menu import * # noqa
|
||||
from .model_postprocessors import * # noqa
|
||||
from .prompt import * # noqa
|
||||
from .text_postprocessors import * # noqa
|
||||
|
77
opencompass/utils/model_postprocessors.py
Normal file
77
opencompass/utils/model_postprocessors.py
Normal file
@ -0,0 +1,77 @@
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from typing import Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from opencompass.registry import TEXT_POSTPROCESSORS
|
||||
|
||||
from .postprocessors.xfinder.extractor import Extractor
|
||||
from .postprocessors.xfinder.xfinder_utils import (DataProcessor,
|
||||
convert_to_xfinder_format)
|
||||
|
||||
|
||||
def gen_output(ori_data, extractor):
|
||||
ext_cor_pairs = []
|
||||
extracted_data = []
|
||||
extracted_answers = []
|
||||
for item in tqdm(ori_data):
|
||||
user_input = extractor.prepare_input(item)
|
||||
extracted_answer = extractor.gen_output(user_input)
|
||||
ext_cor_pairs.append([
|
||||
item['key_answer_type'], item['standard_answer_range'],
|
||||
extracted_answer, item['correct_answer']
|
||||
])
|
||||
item['xfinder_extracted_answer'] = extracted_answer
|
||||
extracted_answers.append(extracted_answer)
|
||||
extracted_data.append(item)
|
||||
|
||||
return extracted_answers, ext_cor_pairs, extracted_data
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('xfinder')
|
||||
def xfinder_postprocess(preds: list, question_type: str,
|
||||
xfinder_model_name: str,
|
||||
xfiner_api_url: Union[str, list], **kwargs) -> list:
|
||||
"""Postprocess the text extracted by xFinder model.
|
||||
Args:
|
||||
preds (list): The question, reference answer and model prediction.
|
||||
question_type (str): The type of the question.
|
||||
url (Union[str, list]): The api url of the xFinder model.
|
||||
|
||||
|
||||
Returns:
|
||||
list: The postprocessed texts.
|
||||
"""
|
||||
|
||||
def _eval_pred(texts, data_processor, extractor, num_processes=8):
|
||||
ori_data = data_processor.read_data(texts)
|
||||
extracted_correct_pairs = []
|
||||
extracted_data = []
|
||||
extracted_answers = []
|
||||
batched_ori_data = []
|
||||
# Split data into batches
|
||||
num_processes = min(num_processes, len(ori_data))
|
||||
batch_size = len(ori_data) // num_processes
|
||||
for i in range(0, len(ori_data), batch_size):
|
||||
batched_ori_data.append(ori_data[i:i + batch_size])
|
||||
with Pool(num_processes) as p:
|
||||
results = p.map(partial(gen_output, extractor=extractor),
|
||||
batched_ori_data)
|
||||
for result in results:
|
||||
extracted_answers += result[0]
|
||||
extracted_correct_pairs += result[1]
|
||||
extracted_data += result[2]
|
||||
return extracted_answers
|
||||
|
||||
format_data = convert_to_xfinder_format(question_type, preds)
|
||||
assert xfiner_api_url is not None, 'Please provide the api url.'
|
||||
data_processor = DataProcessor()
|
||||
extractor = Extractor(model_name=xfinder_model_name,
|
||||
url=xfiner_api_url.split(',')
|
||||
if ',' in xfiner_api_url else xfiner_api_url)
|
||||
calc_acc_func = partial(_eval_pred,
|
||||
data_processor=data_processor,
|
||||
extractor=extractor)
|
||||
extracted_answers = calc_acc_func(format_data)
|
||||
return extracted_answers
|
194
opencompass/utils/postprocessors/xfinder/README.md
Normal file
194
opencompass/utils/postprocessors/xfinder/README.md
Normal file
@ -0,0 +1,194 @@
|
||||
## Extract Final Answers with Postprocess Models
|
||||
|
||||
OpenCompass now support postprocess (extract) prediction answers with postprocess models, to get the true ability level of models. Now, we use [XFinder](https://github.com/IAAR-Shanghai/xFinder) as our first postprocess model to extract the final answers from the model outputs.
|
||||
|
||||
We support four types of task types now:
|
||||
|
||||
1. **math**: for math questions with numerical pr formula answers, like GSM8k, Math, etc.
|
||||
2. **alphabet_option**: for alphabet option questions with alphabet answers, like CommonsenseQA, MMLU, etc.
|
||||
3. **short_text**: for questions answer type is a short text with selected short text answers.
|
||||
|
||||
Here are some examples of the question type supported from the official xFinder repo:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"model_name": "Llama3_8B_Instruct",
|
||||
"dataset": "MetaMathQA",
|
||||
"key_answer_type": "math",
|
||||
"question": "Find the smallest, positive four-digit multiple of $15.$",
|
||||
"llm_output": "The smallest four-digit multiple of $15$ is $\\boxed{1005}.$",
|
||||
"correct_answer": "1005",
|
||||
"standard_answer_range": "a(n) number / set / vector / matrix / interval / expression / function / equation / inequality"
|
||||
},
|
||||
{
|
||||
"model_name": "LLaMA2_7B_Chat",
|
||||
"dataset": "SIQA_text",
|
||||
"key_answer_type": "short_text",
|
||||
"question": "Jordan kept Carson up all night talking because Jordan's girlfriend left them and they were quite upset. How would Carson feel as a result? Answer Choices: drowsy / upset their girlfriend left / energetic",
|
||||
"llm_output": "Based on the information provided in the question, Carson would likely feel upset because Jordan's girlfriend left them and they were quite upset. Therefore, the correct answer is (B) upset their girlfriend left.",
|
||||
"correct_answer": "drowsy",
|
||||
"standard_answer_range": [
|
||||
"drowsy",
|
||||
"upset their girlfriend left",
|
||||
"energetic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"model_name": "PHI2",
|
||||
"dataset": "CommonsenseQA_alpha",
|
||||
"key_answer_type": "alphabet_option",
|
||||
"question": "Why would you go to an office? Answer Choices: (A) work (B) school building (C) paper (D) city (E) habit",
|
||||
"llm_output": "A) work (B) school building (C) paper (D) city (E) habit\n\nExplanation:\n We are looking for a reason why someone would go to an office. The only answer choice that matches this is option A) work. Therefore, the correct answer is A) work.",
|
||||
"correct_answer": "A",
|
||||
"standard_answer_range": [
|
||||
[
|
||||
"A",
|
||||
"work"
|
||||
],
|
||||
[
|
||||
"B",
|
||||
"school building"
|
||||
],
|
||||
[
|
||||
"C",
|
||||
"paper"
|
||||
],
|
||||
[
|
||||
"D",
|
||||
"city"
|
||||
],
|
||||
[
|
||||
"E",
|
||||
"habit"
|
||||
]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## How to Use Model Postprocess in OpenCompass
|
||||
|
||||
### Step 1: Deploy the Postprocess Model Server
|
||||
|
||||
For now, there are two xFinder models can use, you can download them from Huggingface model hub:
|
||||
|
||||
1. **IAAR-Shanghai/xFinder-qwen1505**
|
||||
2. **IAAR-Shanghai/xFinder-llama38it**
|
||||
|
||||
You can use LMDeploy or vLLM to deploy the xFinder model server, for example, you can use the following command to deploy the xFinder model server with LMDeploy:
|
||||
|
||||
```bash
|
||||
lmdeploy serve api_server IAAR-Shanghai/xFinder-qwen1505 --model-name xFinder-qwen1505 --server-port 23333 --backend turbomind --tp 1
|
||||
```
|
||||
|
||||
### Step 2: Set the Postprocess Model Config in the Dataset Configuration
|
||||
|
||||
We make the postprocess as a common postprocess function in OpenCompass, so you can use it by setting the `postprocess` parameter in the `predict` function of OpenCompass. It can be used with the default postprocess regularization extract function at the same time. The only thing you need to do is to deploy the postprocess model server and set the `model_postprocessor` to the original `eval_cfg` in the dataset configuration, like the following example:
|
||||
|
||||
```python
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
...
|
||||
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='math',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
```
|
||||
|
||||
Explanation of the parameters:
|
||||
|
||||
- `question_type`: the type of the question, which can be one of the three types mentioned above.
|
||||
- `xfinder_model_name`: the name of the model you deploying the model server.
|
||||
- `xfiner_api_url`: the URL of the model server, you can set multiple URLs with `,` to use multiple model servers, which can accelerate the postprocess speed.
|
||||
|
||||
📢:**Please attention following points**:
|
||||
|
||||
1. Now only support extract questions with Zero-shot setting.
|
||||
2. For alphabet_option problems, the option should be like '\\nA. xxx\\nB. xxx\\nC. xxx\\nD. xxx\\nE. xxx\\n ...' or '\\n(A) xxx\\n(B) xxx\\n(C) xxx\\n(D) xxx\\n(E) xxx\\n ...' format, and the correct answer should be the alphabet of the correct answer, like 'A', 'B', 'C', 'D', 'E'.
|
||||
|
||||
For more details about the xFinder model, you can refer to the [xFinder](https://github.com/IAAR-Shanghai/xFinder), and for a complete example, you can refer to the following example, which is the configuration of the GSM8K dataset with the xFinder postprocess model:
|
||||
|
||||
```python
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import GSM8KDataset, gsm8k_dataset_postprocess, Gsm8kEvaluator
|
||||
from opencompass.datasets import MATHEvaluator, math_postprocess_v2
|
||||
from opencompass.utils.model_postprocessors import xfinder_postprocess
|
||||
|
||||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')
|
||||
|
||||
gsm8k_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'),
|
||||
],
|
||||
),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512),
|
||||
)
|
||||
|
||||
gsm8k_eval_cfg = dict(
|
||||
evaluator=dict(type=MATHEvaluator, version='v2'),
|
||||
pred_postprocessor=dict(type=math_postprocess_v2),
|
||||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
|
||||
model_postprocessor=dict(
|
||||
type=xfinder_postprocess,
|
||||
question_type='math',
|
||||
xfinder_model_name='xFinder-qwen1505',
|
||||
xfiner_api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
|
||||
)
|
||||
|
||||
gsm8k_datasets = [
|
||||
dict(
|
||||
abbr='gsm8k',
|
||||
type=GSM8KDataset,
|
||||
path='opencompass/gsm8k',
|
||||
reader_cfg=gsm8k_reader_cfg,
|
||||
infer_cfg=gsm8k_infer_cfg,
|
||||
eval_cfg=gsm8k_eval_cfg,
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
For evaluation results, `accuracy` is the result using default postprocess, and `model_postprocess_accuracy` is the result using xFinder postprocess, the gap can be wider when the model is not good answering the questions properly.
|
||||
|
||||
You can also use the `--dump-eval-details` command to dump the detailed evaluation details to see the model postprocess results from the `results` folder.
|
||||
|
||||
## Results Comparison with Different Question Types
|
||||
|
||||
We have tested the model postprocess method with XFinder model on the GSM8K, MMLU, Natural Questions (NQ) datasets for `Meta-Llama-3-8B-Instruct` with above settings, and the results are as follows:
|
||||
|
||||
| Dataset | Type | Config Name | Regex Postprocess Score | Model Postprocess Score |
|
||||
| ------- | --------------- | ------------------------ | ----------------------- | ----------------------- |
|
||||
| gsm8k | math | gsm8k_xfinder_gen_a58960 | 73.46 | 78.09 |
|
||||
| nq | short_text | nq_xfinder_gen_3dcea1 | 22.33 | 37.53 |
|
||||
| mmlu | alphabet_option | mmlu_xfinder_gen_4d595a | 67.89 | 67.93 |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{2023opencompass,
|
||||
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
|
||||
author={OpenCompass Contributors},
|
||||
howpublished = {\url{https://github.com/open-compass/opencompass}},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@misc{yu2024xfinderrobustpinpointanswer,
|
||||
title={xFinder: Robust and Pinpoint Answer Extraction for Large Language Models},
|
||||
author={Qingchen Yu and Zifan Zheng and Shichao Song and Zhiyu Li and Feiyu Xiong and Bo Tang and Ding Chen},
|
||||
year={2024},
|
||||
eprint={2405.11874},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2405.11874},
|
||||
}
|
||||
|
||||
```
|
175
opencompass/utils/postprocessors/xfinder/extractor.py
Normal file
175
opencompass/utils/postprocessors/xfinder/extractor.py
Normal file
@ -0,0 +1,175 @@
|
||||
import json
|
||||
import time
|
||||
from logging import getLogger
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from .xfinder_utils import PROMPT_TEMPLATE
|
||||
|
||||
Instruction = """I will provide you with a question, output sentences along with an answer range. The output sentences are the response of the question provided. The answer range could either describe the type of answer expected or list all possible valid answers. Using the information provided, you must accurately and precisely determine and extract the intended key answer from the output sentences. Please don't have your subjective thoughts about the question.
|
||||
First, you need to determine whether the content of the output sentences is relevant to the given question. If the entire output sentences are unrelated to the question (meaning the output sentences are not addressing the question), then output [No valid answer].
|
||||
Otherwise, ignore the parts of the output sentences that have no relevance to the question and then extract the key answer that matches the answer range.
|
||||
Below are some special cases you need to be aware of:
|
||||
(1) If the output sentences present multiple different answers, carefully determine if the later provided answer is a correction or modification of a previous one. If so, extract this corrected or modified answer as the final response. Conversely, if the output sentences fluctuate between multiple answers without a clear final answer, you should output [No valid answer].
|
||||
(2) If the answer range is a list and the key answer in the output sentences is not explicitly listed among the candidate options in the answer range, also output [No valid answer].
|
||||
|
||||
""" # noqa
|
||||
|
||||
|
||||
class Extractor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
model_path=None,
|
||||
url=None,
|
||||
temperature=0,
|
||||
max_tokens=3000,
|
||||
api_key='EMPTY',
|
||||
SYSTEM='You are a help assistant tasked with extracting the precise key answer from given output sentences. You must only provide the extracted key answer without including any additional text.' # noqa
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.PROMPT_TEMPLATE = PROMPT_TEMPLATE[model_name]
|
||||
self.SYSTEM = SYSTEM
|
||||
self.model_path = model_path
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.mode = 'API' if self.url is not None else 'Local'
|
||||
self.logger = getLogger(__name__)
|
||||
|
||||
if self.mode == 'Local':
|
||||
from vllm import LLM, SamplingParams
|
||||
self.sampling_params = SamplingParams(temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=[
|
||||
'<|endoftext|>',
|
||||
'<|im_end|>', '<eoa>',
|
||||
'<||>', '<end_of_turn>',
|
||||
'<|eot_id|>'
|
||||
])
|
||||
self.llm = LLM(model=self.model_path, gpu_memory_utilization=0.5)
|
||||
|
||||
@staticmethod
|
||||
def prepare_input(item):
|
||||
user_input = Instruction + \
|
||||
"Question: \"\"\"" + item['question'] + "\"\"\"\n\n" + \
|
||||
"Output sentences: \"\"\"" + item['llm_output'] + "\"\"\"\n\n" + \
|
||||
'Answer range: ' + item['standard_answer_range'] + '\n\n' + \
|
||||
'Key extracted answer: '
|
||||
|
||||
return user_input
|
||||
|
||||
def gen_output(self, query):
|
||||
if self.mode == 'API':
|
||||
# return self.send_request(query)
|
||||
return self.openai_infer(query)
|
||||
else:
|
||||
return self.offline_infer(query)
|
||||
|
||||
def send_request(self, query: str) -> str:
|
||||
"""Send a request to the model's API and return the response.
|
||||
|
||||
Args:
|
||||
query (str): The input query.
|
||||
|
||||
Returns:
|
||||
str: The extracted answer (xFinder's output).
|
||||
"""
|
||||
prompt = self.PROMPT_TEMPLATE.format(system=self.SYSTEM, input=query)
|
||||
payload = json.dumps({
|
||||
'prompt':
|
||||
prompt,
|
||||
'temperature':
|
||||
self.temperature,
|
||||
'max_tokens':
|
||||
self.max_tokens,
|
||||
'stop': [
|
||||
'<|endoftext|>', '<|im_end|>', '<eoa>', '<||>',
|
||||
'<end_of_turn>', '<|eot_id|>'
|
||||
],
|
||||
})
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
res = requests.request('POST', self.url, headers=headers, data=payload)
|
||||
res = res.json()['text'][0]
|
||||
res = res.replace(prompt, '')
|
||||
# res = requests.post(self.url, json=payload)
|
||||
# res = res.json()['text']
|
||||
res = res.strip()
|
||||
return res
|
||||
|
||||
def openai_infer(self, query: str, retry=9) -> str:
|
||||
"""Perform inference on the OpenAI model.
|
||||
|
||||
Args:
|
||||
query (str): The input query.
|
||||
|
||||
Returns:
|
||||
str: The extracted answer (xFinder's output).
|
||||
"""
|
||||
if isinstance(self.url, list):
|
||||
# Randomly api for better load balancing
|
||||
import random
|
||||
self.url = random.choice(self.url)
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.url,
|
||||
)
|
||||
self.retry = retry
|
||||
|
||||
t = time.time()
|
||||
retry = self.retry
|
||||
response = ''
|
||||
while retry > 0:
|
||||
try:
|
||||
chat_response = self.client.chat.completions.create(
|
||||
model=self.client.models.list().data[0].id
|
||||
if self.model_name == '' else self.model_name,
|
||||
messages=[
|
||||
{
|
||||
'role': 'system',
|
||||
'content': self.SYSTEM
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': query
|
||||
},
|
||||
],
|
||||
stop=[
|
||||
'<|endoftext|>', '<|im_end|>', '<eoa>', '<||>',
|
||||
'<end_of_turn>', '<|eot_id|>'
|
||||
],
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
js_response = json.loads(chat_response.model_dump_json())
|
||||
response = js_response['choices'][0]['message']['content']
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.info(f'Error: {e}')
|
||||
self.logger.info(f'{self.url} is down. Retrying...')
|
||||
self.logger.info(f'Time elapsed: {time.time() - t} seconds')
|
||||
time.sleep(6)
|
||||
retry -= 1
|
||||
if retry == 0:
|
||||
response = 'Error: Failed to get response.'
|
||||
self.logger.info(f'{response} after {self.retry} tries.')
|
||||
raise ValueError('The api is down')
|
||||
return response.strip()
|
||||
|
||||
def offline_infer(self, query: str) -> str:
|
||||
"""Perform inference on the local xFinder model.
|
||||
|
||||
Args:
|
||||
query (str): The input query.
|
||||
|
||||
Returns:
|
||||
str: The extracted answer (xFinder's output).
|
||||
"""
|
||||
prompt = self.PROMPT_TEMPLATE.format(system=self.SYSTEM, input=query)
|
||||
res = self.llm.generate(prompt, self.sampling_params)
|
||||
res = res[0]
|
||||
res = res.outputs[0].text.strip()
|
||||
return res
|
@ -0,0 +1,14 @@
|
||||
PROMPT_TEMPLATE = {
|
||||
'xFinder-qwen1505':
|
||||
"""<|System|>:{system}
|
||||
<|User|>:{input}
|
||||
<|Bot|>:""",
|
||||
'xFinder-llama38it':
|
||||
"""<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{system}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
""",
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
from .convert_data import * # noqa
|
||||
from .data_process import * # noqa
|
||||
from .PROMPT_TEMPLATE import * # noqa
|
@ -0,0 +1,123 @@
|
||||
# Convert OpenCompass prediction data to XFinder format
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
|
||||
xfinder_template = {
|
||||
'math': {
|
||||
'model_name':
|
||||
'',
|
||||
'dataset':
|
||||
'',
|
||||
'key_answer_type':
|
||||
'math',
|
||||
'question':
|
||||
'',
|
||||
'llm_output':
|
||||
'',
|
||||
'correct_answer':
|
||||
'',
|
||||
'standard_answer_range':
|
||||
'a(n) number / set / vector / matrix / interval / expression / function / equation / inequality' # noqa
|
||||
},
|
||||
'alphabet_option': {
|
||||
'model_name': '',
|
||||
'dataset': '',
|
||||
'key_answer_type': 'alphabet_option',
|
||||
'question': '',
|
||||
'llm_output': '.',
|
||||
'correct_answer': '',
|
||||
'standard_answer_range': []
|
||||
},
|
||||
'categorical_label': {
|
||||
'model_name': '',
|
||||
'dataset': '',
|
||||
'key_answer_type': '',
|
||||
'question': '',
|
||||
'llm_output': '',
|
||||
'correct_answer': '',
|
||||
'standard_answer_range': []
|
||||
},
|
||||
'short_text': {
|
||||
'model_name': '',
|
||||
'dataset': '',
|
||||
'key_answer_type': 'short_text',
|
||||
'question': '',
|
||||
'llm_output': '',
|
||||
'correct_answer': '',
|
||||
'standard_answer_range': []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def parse_options(text: str):
|
||||
lines = text.split('\n')
|
||||
parsed_options = []
|
||||
option_pattern = r'^[A-Z]\)|[A-Z]\.|[A-Z]\)|[A-Z]:|\([A-Z]\)'
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
match = re.match(option_pattern, line)
|
||||
if match:
|
||||
option = ''
|
||||
# 等于第一个属于选项的字符
|
||||
for c in line:
|
||||
if c.isalpha():
|
||||
option = c
|
||||
break
|
||||
content_start = match.end() + 1
|
||||
content = line[content_start:].strip()
|
||||
parsed_options.append([option, content])
|
||||
|
||||
return parsed_options
|
||||
|
||||
|
||||
def convert_to_xfinder_format(typ, data, model_name='', dataset_name=''):
|
||||
assert typ in xfinder_template.keys(), f'Invalid type {typ}'
|
||||
format_data = []
|
||||
for item in data:
|
||||
template = copy.deepcopy(xfinder_template[typ])
|
||||
question = item['origin_prompt'][-1]['prompt']
|
||||
llm_output = item['prediction']
|
||||
correct_answer = item['reference'] if item['reference'] else item[
|
||||
'gold']
|
||||
template['correct_answer'] = correct_answer
|
||||
template['model_name'] = model_name
|
||||
template['dataset'] = dataset_name
|
||||
template['question'] = question
|
||||
template['llm_output'] = llm_output
|
||||
try:
|
||||
assert typ in list(xfinder_template.keys())
|
||||
if typ == 'alphabet_option':
|
||||
options = parse_options(question)
|
||||
template['standard_answer_range'] = options
|
||||
elif typ == 'short_text':
|
||||
template['standard_answer_range'] = item['gold']
|
||||
elif typ == 'categorical_label':
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f'Error when parsing question options: {e}, skipping...')
|
||||
continue
|
||||
|
||||
format_data.append(template)
|
||||
return format_data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Test
|
||||
example_data = {
|
||||
'origin_prompt': [{
|
||||
'role':
|
||||
'HUMAN',
|
||||
'prompt':
|
||||
'Alice, Bob, Claire, Dave, and Eve are dancers at a square dance. At the start of a song, they each have a partner: Alice is dancing with Ophelia, Bob is dancing with Jamie, Claire is dancing with Melissa, Dave is dancing with Rodrigo, and Eve is dancing with Patrick.\nThroughout the song, the dancers often trade partners. First, Claire and Bob switch partners. Then, Claire and Eve switch partners. Then, Claire and Bob switch partners. Then, Eve and Dave switch partners. Finally, Claire and Alice switch partners. At the end of the dance, Alice is dancing with\nOptions:\n(A) Ophelia\n(B) Jamie\n(C) Melissa\n(D) Rodrigo\n(E) Patrick' # noqa
|
||||
}],
|
||||
'origin_prediction':
|
||||
'\n 答案: B) 前者小于后者',
|
||||
'prediction':
|
||||
'B',
|
||||
'reference':
|
||||
'A'
|
||||
}
|
||||
example_data = convert_to_xfinder_format('alphabet_option', [example_data],
|
||||
'GPT-3', 'OpenAI')
|
||||
print(json.dumps(example_data, indent=4, ensure_ascii=False))
|
@ -0,0 +1,24 @@
|
||||
import ast
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def read_data(self, data):
|
||||
for item in data:
|
||||
if isinstance(item['standard_answer_range'],
|
||||
str) and item['key_answer_type'] != 'math':
|
||||
try:
|
||||
item['standard_answer_range'] = ast.literal_eval(
|
||||
item['standard_answer_range'])
|
||||
except Exception as e:
|
||||
print(f'Error: {e}')
|
||||
print('Please check the form of standard_answer_range')
|
||||
exit(0)
|
||||
|
||||
item['standard_answer_range'] = str(item['standard_answer_range'])
|
||||
item['key_answer_type'] = str(item['key_answer_type'])
|
||||
|
||||
return data
|
Loading…
Reference in New Issue
Block a user