0-shot Smolinstruct

Add 0-shot evaluation and postprocess functions for Smolinstruct
This commit is contained in:
Yu Sun 2025-05-26 21:20:06 +08:00
parent c3779ebfc1
commit 830315a78c
7 changed files with 349 additions and 17 deletions

View File

@ -0,0 +1,55 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.smolinstruct import FTSEvaluator
from opencompass.datasets import SmolInstructDataset
fts_0shot_reader_cfg = dict(
input_columns=['input'],
output_column='output',
train_split='validation')
fts_hint_dict = {
'MG': """You are an expert chemist. Given the description of a molecule, your task is to generate the potential SMILES representation of the molecule.
The input contains the description of the molecule. Your reply should contain the potential SMILES representation of the molecule wrapped in <SMILES> and </SMILES> tags. Your reply must be valid and chemically reasonable.""",
'FS': """You are an expert chemist. Given the SMILES representation of reactants and reagents, your task is to predict the potential product using your chemical reaction knowledge.
The input contains both reactants and reagents, and different reactants and reagents are separated by ".". Your reply should contain the SMILES representation of the predicted product wrapped in <SMILES> and </SMILES> tags. Your reply must be valid and chemically reasonable.""",
'RS': """You are an expert chemist. Given the SMILES representation of the product, your task is to predict the potential reactants and reagents using your chemical reaction knowledge.
The input contains the SMILES representation of the product. Your reply should contain the SMILES representation of both reactants and reagents, and all reactants and reagents should be enclosed **together** within a single pair of <SMILES> and </SMILES> tags, separated by ".". Your reply must be valid and chemically reasonable.""",
}
name_dict = {
'MG': 'molecule_generation',
'FS': 'forward_synthesis',
'RS': 'retrosynthesis'
}
fts_0shot_instruct_datasets = []
for _name in name_dict:
_hint = fts_hint_dict[_name]
fts_0shot_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=f'{_hint}\nQuestion: {{input}}\nAnswer: ',
# template=f'<s>[INST] {{input}} [/INST]',
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
fts_0shot_eval_cfg = dict(
evaluator=dict(type=FTSEvaluator),
)
fts_0shot_instruct_datasets.append(
dict(
abbr=f'{_name}-0shot-instruct',
type=SmolInstructDataset,
path='osunlp/SMolInstruct',
name=name_dict[_name],
reader_cfg=fts_0shot_reader_cfg,
infer_cfg=fts_0shot_infer_cfg,
eval_cfg=fts_0shot_eval_cfg,
))
del _name

View File

@ -0,0 +1,10 @@
from mmengine.config import read_base
with read_base():
from opencompass.configs.datasets.SmolInstruct.smolinstruct_nc_0shot_instruct import nc_0shot_instruct_datasets
from opencompass.configs.datasets.SmolInstruct.smolinstruct_pp_acc_0_shot_instruct import pp_acc_datasets_0shot_instruct
from opencompass.configs.datasets.SmolInstruct.smolinstruct_rmse_0shot_instruct import pp_rmse_0shot_instruct_datasets
from opencompass.configs.datasets.SmolInstruct.smolinstruct_fts_0shot_instruct import fts_0shot_instruct_datasets
from opencompass.configs.datasets.SmolInstruct.smolinstruct_meteor_0shot_instruct import meteor_0shot_instruct_datasets
smolinstruct_datasets_0shot_instruct = nc_0shot_instruct_datasets + pp_rmse_0shot_instruct_datasets + pp_acc_datasets_0shot_instruct + meteor_0shot_instruct_datasets + fts_0shot_instruct_datasets

View File

@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.smolinstruct import MeteorEvaluator
from opencompass.datasets import SmolInstructDataset
meteor_0shot_reader_cfg = dict(
input_columns=['input'],
output_column='output',
train_split='validation')
meteor_hint_dict = {
'MC': """You are an expert chemist. Given the SMILES representation of a molecule, your task is to describe the molecule in natural language.
The input contains the SMILES representation of the molecule. Your reply should contain a natural language description of the molecule. Your reply must be valid and chemically reasonable.""",
}
name_dict = {
'MC': 'molecule_captioning',
}
meteor_0shot_instruct_datasets = []
for _name in name_dict:
_hint = meteor_hint_dict[_name]
meteor_0shot_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=f'{_hint}\nQuestion: {{input}}\nAnswer: ',
# template=f'<s>[INST] {{input}} [/INST]',
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
meteor_0shot_eval_cfg = dict(
evaluator=dict(type=MeteorEvaluator),
)
meteor_0shot_instruct_datasets.append(
dict(
abbr=f'{_name}-0shot-instruct',
type=SmolInstructDataset,
path='osunlp/SMolInstruct',
name=name_dict[_name],
reader_cfg=meteor_0shot_reader_cfg,
infer_cfg=meteor_0shot_infer_cfg,
eval_cfg=meteor_0shot_eval_cfg,
))
del _name

View File

@ -0,0 +1,63 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.smolinstruct import NCExactMatchEvaluator, NCElementMatchEvaluator
from opencompass.datasets import SmolInstructDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess
nc_0shot_reader_cfg = dict(
input_columns=['input'],
output_column='output',
train_split='validation')
nc_hint_dict = {
'I2F': """You are an expert chemist. Given the IUPAC representation of compounds, your task is to predict the molecular formula of the compound.
The input contains the IUPAC representation of the compound. Your reply should contain only the molecular formula of the compound wrapped in <MOLFORMULA> and </MOLFORMULA> tags and no other text. Your reply must be valid and chemically reasonable.""",
'I2S': """You are an expert chemist. Given the IUPAC representation of compounds, your task is to predict the SMILES representation of the compound.
The input contains the IUPAC representation of the compound. Your reply should contain only the SMILES representation of the compound wrapped in <SMILES> and </SMILES> tags and no other text. Your reply must be valid and chemically reasonable.""",
'S2F': """You are an expert chemist. Given the SMILES representation of compounds, your task is to predict the molecular formula of the compound.
The input contains the SMILES representation of the compound. Your reply should contain only the molecular formula of the compound wrapped in <MOLFORMULA> and </MOLFORMULA> tags and no other text. Your reply must be valid and chemically reasonable.""",
'S2I': """You are an expert chemist. Given the SMILES representation of compounds, your task is to predict the IUPAC representation of the compound.
The input contains the SMILES representation of the compound. Your reply should contain only the IUPAC representation of the compound wrapped in <IUPAC> and </IUPAC> tags and no other text. Your reply must be valid and chemically reasonable.""",
}
name_dict = {
'I2F': 'name_conversion-i2f',
'I2S': 'name_conversion-i2s',
'S2F': 'name_conversion-s2f',
'S2I': 'name_conversion-s2i',
}
nc_0shot_instruct_datasets = []
for _name in name_dict:
_hint = nc_hint_dict[_name]
nc_0shot_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=f'{_hint}\nQuestion: {{input}}\nAnswer: ',
# template=f'<s>[INST] {{input}} [/INST]',
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
if _name in ['I2F', 'S2F']:
nc_0shot_eval_cfg = dict(
evaluator=dict(type=NCElementMatchEvaluator),
)
else:
nc_0shot_eval_cfg = dict(
evaluator=dict(type=NCExactMatchEvaluator),
)
nc_0shot_instruct_datasets.append(
dict(
abbr=f'NC-{_name}-0shot-instruct',
type=SmolInstructDataset,
path='osunlp/SMolInstruct',
name=name_dict[_name],
reader_cfg=nc_0shot_reader_cfg,
infer_cfg=nc_0shot_infer_cfg,
eval_cfg=nc_0shot_eval_cfg,
))
del _name

View File

@ -0,0 +1,61 @@
from opencompass.openicl import AccEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import SmolInstructDataset
from opencompass.datasets.smolinstruct import smolinstruct_acc_0shot_postprocess
pp_acc_reader_cfg = dict(
input_columns=['input'],
output_column='output',
train_split='validation')
pp_acc_hint_dict = {
'BBBP': """You are an expert chemist. Given the smiles representation of the compound, your task is to predict whether blood-brain barrier permeability (BBBP) is a property of the compound.
The input contains the compound. Your reply should only contain Yes or No. Your reply must be valid and chemically reasonable.""",
'ClinTox': """You are an expert chemist. Given the smiles representation of the compound, your task is to predict whether the compound is toxic.
The input contains the compound. Your reply should contain only Yes or No. Your reply must be valid and chemically reasonable.""",
'HIV': """You are an expert chemist. Given the smiles representation of the compound, your task is to predict whether the compound serve as an inhibitor of HIV replication.
The input contains the compound. Your reply should contain only Yes or No. Your reply must be valid and chemically reasonable.""",
'SIDER': """You are an expert chemist. Given the smiles representation of the compound, your task is to predict whether the compound has any side effects.
The input contains the compound. Your reply should contain only Yes or No. Your reply must be valid and chemically reasonable.""",
}
name_dict = {
'BBBP': "property_prediction-bbbp",
'ClinTox': "property_prediction-clintox",
'HIV': "property_prediction-hiv",
'SIDER': "property_prediction-sider",
}
pp_acc_datasets_0shot_instruct = []
for _name in pp_acc_hint_dict:
_hint = pp_acc_hint_dict[_name]
pp_acc_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=f'{_hint}\nQuestion: {{input}}\nAnswer: ',
# template=f'<s>[INST] {{input}} [/INST]',
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
pp_acc_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=smolinstruct_acc_0shot_postprocess)
)
pp_acc_datasets_0shot_instruct.append(
dict(
abbr=f'PP-{_name}-0shot-instruct',
type=SmolInstructDataset,
path='osunlp/SMolInstruct',
name=name_dict[_name],
reader_cfg=pp_acc_reader_cfg,
infer_cfg=pp_acc_infer_cfg,
eval_cfg=pp_acc_eval_cfg,
))
del _name, _hint

View File

@ -0,0 +1,52 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.smolinstruct import RMSEEvaluator
from opencompass.datasets import SmolInstructDataset
pp_rmse_0shot_reader_cfg = dict(
input_columns=['input'],
output_column='output',
train_split='validation')
pp_rmse_hint_dict = {
'ESOL': """You are an expert chemist. Given the SMILES representation of compounds, your task is to predict the log solubility of the compound.
The input contains the SMILES representation of the compound. Your reply should contain the log solubility of the compound wrapped in \\boxed{}. Your reply must be valid and chemically reasonable.""",
'Lipo': """You are an expert chemist. Given the SMILES representation of compounds, your task is to predict the octanol/water partition coefficient of the compound.
The input contains the SMILES representation of the compound. Your reply should contain the octanol/water partition coefficient of the compound wrapped in \\boxed{}. Your reply must be valid and chemically reasonable."""
}
name_dict = {
'ESOL': 'property_prediction-esol',
'Lipo': 'property_prediction-lipo'
}
pp_rmse_0shot_instruct_datasets = []
for _name in name_dict:
_hint = pp_rmse_hint_dict[_name]
pp_rmse_0shot_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=f'{_hint}\nQuestion: {{input}}\nAnswer: ',
# template=f'<s>[INST] {{input}} [/INST]',
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
pp_rmse_0shot_eval_cfg = dict(
evaluator=dict(type=RMSEEvaluator),
)
pp_rmse_0shot_instruct_datasets.append(
dict(
abbr=f'PP-{_name}-0shot-instruct',
type=SmolInstructDataset,
path='osunlp/SMolInstruct',
name=name_dict[_name],
reader_cfg=pp_rmse_0shot_reader_cfg,
infer_cfg=pp_rmse_0shot_infer_cfg,
eval_cfg=pp_rmse_0shot_eval_cfg,
))
del _name

View File

@ -20,7 +20,7 @@ class SmolInstructDataset(BaseDataset):
@staticmethod
def load(path: str, name: str):
dataset = DatasetDict()
raw_dataset = load_dataset(path)
raw_dataset = load_dataset(path, trust_remote_code=True)
for split in ['validation', 'test']:
raw_data = []
for data in raw_dataset[split]:
@ -31,10 +31,22 @@ class SmolInstructDataset(BaseDataset):
def extract_chemical_data(text):
other_patterns = [
"reactants and reagents are:\n```\n",
"reactants and reagents:\n```\n",
"Reactants and Reagents:**\n```\n",
"Reactants and Reagents SMILES:**\n```\n",
"Reactants and Reagents:** \n`"
]
pattern = re.compile(r'<(MOLFORMULA|SMILES|IUPAC)>(.*?)</\1>', re.DOTALL)
matches = pattern.findall(text)
if not matches:
return []
for other_pattern in other_patterns:
if other_pattern in text:
text = text.split(other_pattern)[-1].split("\n")[0]
break
return [text]
return [match[1].strip() for match in matches]
@ -110,9 +122,9 @@ def calculate_single_element_match_for_list(predictions, references):
ele_invalid_labels = []
details = []
for pred_formula, gold_formula in zip(predictions, references):
gold_formula = gold_formula[0]
gold_formula = gold_formula[-1]
if pred_formula:
pred_formula = pred_formula[0]
pred_formula = pred_formula[-1]
detail = {'pred': [pred_formula], 'answer': gold_formula}
if not pred_formula or not pred_formula:
ele_invalid_labels.append(False)
@ -158,9 +170,9 @@ def calculate_single_element_match(predictions, references):
ele_invalid_labels = []
details = []
for pred_formula, gold_formula in zip(predictions, references):
gold_formula = gold_formula[0]
gold_formula = gold_formula[-1]
if pred_formula:
pred_formula = pred_formula[0]
pred_formula = pred_formula[-1]
detail = {'pred': pred_formula, 'answer': gold_formula}
if not pred_formula or not pred_formula:
ele_invalid_labels.append(False)
@ -208,7 +220,7 @@ class NCElementMatchEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
# topk的prediction要拆开
@ -259,7 +271,7 @@ class NCExactMatchEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
predictions = [
extract_chemical_data(prediction) for prediction in predictions
@ -272,9 +284,9 @@ class NCExactMatchEvaluator(BaseEvaluator):
valid_cnt = 0
details = []
for pred, ans in zip(predictions, references):
ans = ans[0]
ans = ans[-1]
if pred:
pred = pred[0]
pred = pred[-1]
valid_cnt += 1
detail = {'pred': pred, 'answer': ans}
if pred and pred.strip() == ans.strip():
@ -291,7 +303,9 @@ class NCExactMatchEvaluator(BaseEvaluator):
def extract_number(text):
pattern = re.compile(r'<NUMBER>\s*(-?\d*\.?\d+)\s*</NUMBER>')
pattern = re.compile(
r'(?:<NUMBER>\s*|\\boxed\{)\s*(-?\d*\.?\d+)\s*(?:</NUMBER>|\})'
)
matches = pattern.findall(text)
return [float(match) for match in matches]
@ -307,7 +321,7 @@ class RMSEEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
avg_score = 0
@ -324,7 +338,7 @@ class RMSEEvaluator(BaseEvaluator):
except:
raise ValueError(f'ans: {reference}')
detail = {'pred': pred, 'answer': ans}
rmse_score = np.sqrt(np.mean((np.array(pred) - np.array(ans))**2))
rmse_score = np.sqrt(np.mean((np.array(pred) - np.array(ans)) ** 2))
detail['score'] = rmse_score
avg_score += rmse_score
details.append(detail)
@ -345,7 +359,7 @@ class FTSEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
predictions = [
@ -359,12 +373,12 @@ class FTSEvaluator(BaseEvaluator):
valid_cnt = 0
details = []
for pred, ans in zip(predictions, references):
ans = ans[0]
ans = ans[-1]
if not pred:
detail = {'pred': '', 'answer': ans, 'score': 0}
details.append(detail)
continue
pred = pred[0]
pred = pred[-1]
detail = {'pred': pred, 'answer': ans}
# 将 SMILES 转换为 RDKit 分子对象
from rdkit import Chem
@ -405,7 +419,7 @@ class MeteorEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
avg_score = 0
details = []
@ -433,3 +447,31 @@ def smolinstruct_acc_postprocess(text: str) -> str:
return '<BOOLEAN> Yes </BOOLEAN>'
elif 'no' in text.lower():
return '<BOOLEAN> No </BOOLEAN>'
@TEXT_POSTPROCESSORS.register_module('smolinstruct-acc-0shot')
def smolinstruct_acc_0shot_postprocess(text: str) -> str:
if text.strip().lower() == 'yes':
return "<BOOLEAN> Yes </BOOLEAN>"
elif text.strip().lower() == 'no':
return "<BOOLEAN> No </BOOLEAN>"
# Define regex patterns to match various formats of "yes" or "no"
patterns = [
r'\\boxed\{\s*(yes|no)\s*\}',
r'[Th]he\s+answer\s+is\s*[\.:\'"“‘’\-]*\s*(yes|no)[\s\.,!?:;\'"”’\-]*',
r'[Aa]nswer:\s*(yes|no)\b',
r'\*\*[Aa]nswer:\*\*\s*(yes|no)\b',
r'\*\*[Aa]nswer\*\*:\s*(yes|no)\b',
r'<BOOLEAN>\s*(yes|no)\s*</BOOLEAN>'
]
for pattern in patterns:
text = text.strip()
match = re.search(pattern, text, flags=re.IGNORECASE)
if match:
answer = match.group(1) # modified
if answer.lower() == 'yes' or answer.lower() == 'yes.':
return "<BOOLEAN> Yes </BOOLEAN>"
elif answer.lower() == 'no' or answer.lower() == 'no.':
return "<BOOLEAN> No </BOOLEAN>"
return ""