diff --git a/configs/datasets/ChemBench/ChemBench_gen.py b/configs/datasets/ChemBench/ChemBench_gen.py
new file mode 100644
index 00000000..9327a0da
--- /dev/null
+++ b/configs/datasets/ChemBench/ChemBench_gen.py
@@ -0,0 +1,77 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import FixKRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import ChemBenchDataset
+from opencompass.utils.text_postprocessors import first_capital_postprocess
+
+
+chembench_reader_cfg = dict(
+    input_columns=["input", "A", "B", "C", "D"],
+    output_column="target",
+    train_split='dev')
+
+chembench_all_sets = [
+    'Name_Conversion',
+    'Property_Prediction',
+    'Mol2caption',
+    'Caption2mol',
+    'Product_Prediction',
+    'Retrosynthesis',
+    'Yield_Prediction',
+    'Temperature_Prediction',
+    'Solvent_Prediction'
+]
+
+
+chembench_datasets = []
+for _name in chembench_all_sets:
+    # _hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
+    _hint = f'There is a single choice question about chemistry. Answer the question by replying A, B, C or D.'
+
+    chembench_infer_cfg = dict(
+        ice_template=dict(
+            type=PromptTemplate,
+            template=dict(round=[
+                dict(
+                    role="HUMAN",
+                    prompt=
+                    f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                ),
+                dict(role="BOT", prompt="{target}\n")
+            ]),
+        ),
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                begin="</E>",
+                round=[
+                    dict(
+                        role="HUMAN",
+                        prompt=
+                        f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
+                    ),
+                ],
+            ),
+            ice_token="</E>",
+        ),
+        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+        inferencer=dict(type=GenInferencer),
+    )
+
+    chembench_eval_cfg = dict(
+        evaluator=dict(type=AccEvaluator),
+        pred_postprocessor=dict(type=first_capital_postprocess))
+
+    chembench_datasets.append(
+        dict(
+            abbr=f"ChemBench_{_name}",
+            type=ChemBenchDataset,
+            path="./data/ChemBench/",
+            name=_name,
+            reader_cfg=chembench_reader_cfg,
+            infer_cfg=chembench_infer_cfg,
+            eval_cfg=chembench_eval_cfg,
+        ))
+
+del _name, _hint
diff --git a/configs/eval_chembench.py b/configs/eval_chembench.py
new file mode 100644
index 00000000..00be08a0
--- /dev/null
+++ b/configs/eval_chembench.py
@@ -0,0 +1,22 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .datasets.ChemBench.ChemBench_gen import chembench_datasets
+    from .models.mistral.hf_mistral_7b_instruct_v0_2 import models
+
+datasets = [*chembench_datasets]
+models = [*models]
+
+'''
+dataset                           version    metric    mode    mistral-7b-instruct-v0.2-hf
+--------------------------------  ---------  --------  ------  -----------------------------
+ChemBench_Name_Conversion         d4e6a1     accuracy  gen     45.43
+ChemBench_Property_Prediction     d4e6a1     accuracy  gen     47.11
+ChemBench_Mol2caption             d4e6a1     accuracy  gen     64.21
+ChemBench_Caption2mol             d4e6a1     accuracy  gen     35.38
+ChemBench_Product_Prediction      d4e6a1     accuracy  gen     38.67
+ChemBench_Retrosynthesis          d4e6a1     accuracy  gen     27
+ChemBench_Yield_Prediction        d4e6a1     accuracy  gen     27
+ChemBench_Temperature_Prediction  d4e6a1     accuracy  gen     26.73
+ChemBench_Solvent_Prediction      d4e6a1     accuracy  gen     32.67
+'''
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index 852d33ab..e3ead435 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -12,6 +12,7 @@ from .bustum import *  # noqa: F401, F403
 from .c3 import *  # noqa: F401, F403
 from .cb import *  # noqa: F401, F403
 from .ceval import *  # noqa: F401, F403
+from .chembench import *  # noqa: F401, F403
 from .chid import *  # noqa: F401, F403
 from .cibench import *  # noqa: F401, F403
 from .circular import *  # noqa: F401, F403
diff --git a/opencompass/datasets/chembench.py b/opencompass/datasets/chembench.py
new file mode 100644
index 00000000..887c11c9
--- /dev/null
+++ b/opencompass/datasets/chembench.py
@@ -0,0 +1,34 @@
+import json
+import os.path as osp
+
+from datasets import Dataset, DatasetDict
+
+from opencompass.registry import LOAD_DATASET
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class ChemBenchDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, name: str):
+        dataset = DatasetDict()
+        for split in ['dev', 'test']:
+            raw_data = []
+            filename = osp.join(path, split, f'{name}_benchmark.json')
+            with open(filename, 'r', encoding='utf-8') as json_file:
+                data = json.load(json_file)
+
+            for item in data:
+                raw_data.append({
+                    'input': item['question'],
+                    'A': item['A'],
+                    'B': item['B'],
+                    'C': item['C'],
+                    'D': item['D'],
+                    'target': item['answer'],
+                })
+
+            dataset[split] = Dataset.from_list(raw_data)
+        return dataset