# Dataset configuration for the `musr` subsets.
#
# `DATASET_CONFIGS` holds one entry per subset (murder_mysteries,
# object_placements, team_allocation). Each entry carries the reader,
# inference and evaluation settings for that subset. The entries are then
# flattened into `musr_datasets`, a list of dataset dicts typed as
# `MusrDataset`.
#
# The three entries are structurally identical; they differ only in their
# `abbr`/`name` values and in the inferencer's `max_out_len`
# (2048 for murder_mysteries, 512 for the other two).
from opencompass.datasets import MusrDataset, MusrEvaluator
from opencompass.openicl import PromptTemplate, ZeroRetriever, GenInferencer

DATASET_CONFIGS = {
    'murder_mysteries': {
        'abbr': 'musr_murder_mysteries',
        'name': 'murder_mysteries',
        'path': 'opencompass/musr',
        'reader_cfg': {
            # Note: 'gold_answer' is listed as an input column and is also
            # the output column.
            'input_columns': [
                'context',
                'question_text',
                'question',
                'answer',
                'choices',
                'choices_str',
                'intermediate_trees',
                'intermediate_data',
                'prompt',
                'system_prompt',
                'gold_answer',
                'scidx',
                'self_consistency_n',
                'ablation_name',
            ],
            'output_column': 'gold_answer',
        },
        'infer_cfg': {
            'prompt_template': {
                'type': PromptTemplate,
                'template': {
                    # NOTE(review): fallback_role presumably lets models
                    # without a SYSTEM role receive this text as a HUMAN
                    # turn -- confirm against the OpenCompass prompt docs.
                    'begin': [
                        {
                            'role': 'SYSTEM',
                            'fallback_role': 'HUMAN',
                            'prompt': '{system_prompt}',
                        },
                    ],
                    'round': [
                        {
                            'role': 'HUMAN',
                            'prompt': '{prompt}',
                        },
                    ],
                },
            },
            'retriever': {'type': ZeroRetriever},
            # Largest generation budget of the three subsets.
            'inferencer': {'type': GenInferencer, 'max_out_len': 2048},
        },
        'eval_cfg': {
            # NOTE(review): the meaning of answer_index_modifier is defined
            # by MusrEvaluator, which is not visible here -- confirm before
            # changing it.
            'evaluator': {
                'type': MusrEvaluator,
                'answer_index_modifier': 1,
                'self_consistency_n': 1,
            },
        },
    },
    'object_placements': {
        'abbr': 'musr_object_placements',
        'name': 'object_placements',
        'path': 'opencompass/musr',
        'reader_cfg': {
            'input_columns': [
                'context',
                'question_text',
                'question',
                'answer',
                'choices',
                'choices_str',
                'intermediate_trees',
                'intermediate_data',
                'prompt',
                'system_prompt',
                'gold_answer',
                'scidx',
                'self_consistency_n',
                'ablation_name',
            ],
            'output_column': 'gold_answer',
        },
        'infer_cfg': {
            'prompt_template': {
                'type': PromptTemplate,
                'template': {
                    'begin': [
                        {
                            'role': 'SYSTEM',
                            'fallback_role': 'HUMAN',
                            'prompt': '{system_prompt}',
                        },
                    ],
                    'round': [
                        {
                            'role': 'HUMAN',
                            'prompt': '{prompt}',
                        },
                    ],
                },
            },
            'retriever': {'type': ZeroRetriever},
            'inferencer': {'type': GenInferencer, 'max_out_len': 512},
        },
        'eval_cfg': {
            'evaluator': {
                'type': MusrEvaluator,
                'answer_index_modifier': 1,
                'self_consistency_n': 1,
            },
        },
    },
    'team_allocation': {
        'abbr': 'musr_team_allocation',
        'name': 'team_allocation',
        'path': 'opencompass/musr',
        'reader_cfg': {
            'input_columns': [
                'context',
                'question_text',
                'question',
                'answer',
                'choices',
                'choices_str',
                'intermediate_trees',
                'intermediate_data',
                'prompt',
                'system_prompt',
                'gold_answer',
                'scidx',
                'self_consistency_n',
                'ablation_name',
            ],
            'output_column': 'gold_answer',
        },
        'infer_cfg': {
            'prompt_template': {
                'type': PromptTemplate,
                'template': {
                    'begin': [
                        {
                            'role': 'SYSTEM',
                            'fallback_role': 'HUMAN',
                            'prompt': '{system_prompt}',
                        },
                    ],
                    'round': [
                        {
                            'role': 'HUMAN',
                            'prompt': '{prompt}',
                        },
                    ],
                },
            },
            'retriever': {'type': ZeroRetriever},
            'inferencer': {'type': GenInferencer, 'max_out_len': 512},
        },
        'eval_cfg': {
            'evaluator': {
                'type': MusrEvaluator,
                'answer_index_modifier': 1,
                'self_consistency_n': 1,
            },
        },
    },
}

# Flatten the per-subset settings into the list of dataset dicts, in the
# insertion order of DATASET_CONFIGS. The nested reader/infer/eval dicts are
# passed by reference (not copied), so every entry shares them with
# DATASET_CONFIGS. The loop names `config` and `dataset` stay bound at
# module level once the loop finishes.
musr_datasets = []
for config in DATASET_CONFIGS.values():
    dataset = {
        'abbr': config['abbr'],
        'type': MusrDataset,
        'path': config['path'],
        'name': config['name'],
        'reader_cfg': config['reader_cfg'],
        'infer_cfg': config['infer_cfg'],
        'eval_cfg': config['eval_cfg'],
    }
    musr_datasets.append(dataset)