rename 0-shot

This commit is contained in:
yusun-nlp 2025-05-28 13:39:22 +00:00
parent 523f0f1f18
commit f376bb1598
3 changed files with 27 additions and 31 deletions

View File

@ -22,10 +22,10 @@ pp_acc_hint_dict = {
}
# Maps each short task key to its 'property_prediction-*' name.
# Each key is listed exactly once: repeated keys in a dict literal are
# silently overwritten by the last occurrence, which hides mistakes.
name_dict = {
    'BBBP': 'property_prediction-bbbp',
    'ClinTox': 'property_prediction-clintox',
    'HIV': 'property_prediction-hiv',
    'SIDER': 'property_prediction-sider',
}
pp_acc_datasets_0shot_instruct = []

View File

@ -32,11 +32,10 @@ class SmolInstructDataset(BaseDataset):
def extract_chemical_data(text):
other_patterns = [
"reactants and reagents are:\n```\n",
"reactants and reagents:\n```\n",
"Reactants and Reagents:**\n```\n",
"Reactants and Reagents SMILES:**\n```\n",
"Reactants and Reagents:** \n`"
'reactants and reagents are:\n```\n', 'reactants and reagents:\n```\n',
'Reactants and Reagents:**\n```\n',
'Reactants and Reagents SMILES:**\n```\n',
'Reactants and Reagents:** \n`'
]
pattern = re.compile(r'<(MOLFORMULA|SMILES|IUPAC)>(.*?)</\1>', re.DOTALL)
@ -44,7 +43,7 @@ def extract_chemical_data(text):
if not matches:
for other_pattern in other_patterns:
if other_pattern in text:
text = text.split(other_pattern)[-1].split("\n")[0]
text = text.split(other_pattern)[-1].split('\n')[0]
break
return [text]
return [match[1].strip() for match in matches]
@ -220,7 +219,7 @@ class NCElementMatchEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
# top-k predictions need to be split apart
@ -271,7 +270,7 @@ class NCExactMatchEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
predictions = [
extract_chemical_data(prediction) for prediction in predictions
@ -304,8 +303,7 @@ class NCExactMatchEvaluator(BaseEvaluator):
def extract_number(text):
    r"""Extract numbers wrapped in ``<NUMBER>`` tags or ``\boxed{}``.

    Args:
        text: The string to search.

    Returns:
        A list of floats, one per match, in order of appearance. The list
        is empty when nothing matches.
    """
    # The captured value is an optionally negative integer or decimal; at
    # least one digit is required after any decimal point. Whitespace around
    # the value is tolerated.
    pattern = re.compile(
        r'(?:<NUMBER>\s*|\\boxed\{)\s*(-?\d*\.?\d+)\s*(?:</NUMBER>|\})')
    return [float(match) for match in pattern.findall(text)]
@ -321,7 +319,7 @@ class RMSEEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
avg_score = 0
@ -338,7 +336,7 @@ class RMSEEvaluator(BaseEvaluator):
except:
raise ValueError(f'ans: {reference}')
detail = {'pred': pred, 'answer': ans}
rmse_score = np.sqrt(np.mean((np.array(pred) - np.array(ans)) ** 2))
rmse_score = np.sqrt(np.mean((np.array(pred) - np.array(ans))**2))
detail['score'] = rmse_score
avg_score += rmse_score
details.append(detail)
@ -359,7 +357,7 @@ class FTSEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
predictions = [
@ -419,7 +417,7 @@ class MeteorEvaluator(BaseEvaluator):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
'length'
}
avg_score = 0
details = []
@ -452,24 +450,22 @@ def smolinstruct_acc_postprocess(text: str) -> str:
@TEXT_POSTPROCESSORS.register_module('smolinstruct-acc-0shot')
def smolinstruct_acc_0shot_postprocess(text: str) -> str:
# Remove <think> tags if they exist
if "</think>" in text:
text = text.split("</think>")[-1].strip()
if '</think>' in text:
text = text.split('</think>')[-1].strip()
# Check for exact "yes" or "no" responses
if text.strip().lower() == 'yes':
return "<BOOLEAN> Yes </BOOLEAN>"
return '<BOOLEAN> Yes </BOOLEAN>'
elif text.strip().lower() == 'no':
return "<BOOLEAN> No </BOOLEAN>"
return '<BOOLEAN> No </BOOLEAN>'
# Define regex patterns to match various formats of "yes" or "no"
patterns = [
r'\\boxed\{\s*(yes|no)\s*\}',
r'[Th]he\s+answer\s+is\s*[\.:\'"“‘’\-]*\s*(yes|no)[\s\.,!?:;\'"”’\-]*',
r'[Aa]nswer:\s*(yes|no)\b',
r'\*\*[Aa]nswer:\*\*\s*(yes|no)\b',
r'[Aa]nswer:\s*(yes|no)\b', r'\*\*[Aa]nswer:\*\*\s*(yes|no)\b',
r'\*\*[Aa]nswer\*\*:\s*(yes|no)\b',
r'<BOOLEAN>\s*(yes|no)\s*</BOOLEAN>',
r'^\s*(yes|no)[\.\?!]?\s*$'
r'<BOOLEAN>\s*(yes|no)\s*</BOOLEAN>', r'^\s*(yes|no)[\.\?!]?\s*$'
]
for pattern in patterns:
text = text.strip()
@ -477,16 +473,16 @@ def smolinstruct_acc_0shot_postprocess(text: str) -> str:
if match:
answer = match.group(1) # modified
if answer.lower() == 'yes':
return "<BOOLEAN> Yes </BOOLEAN>"
return '<BOOLEAN> Yes </BOOLEAN>'
elif answer.lower() == 'no':
return "<BOOLEAN> No </BOOLEAN>"
return '<BOOLEAN> No </BOOLEAN>'
# If no patterns matched, check for simple "yes" or "no"
text = text.strip().lower()
if text.startswith('yes') or text.endswith('yes'):
return "<BOOLEAN> Yes </BOOLEAN>"
return '<BOOLEAN> Yes </BOOLEAN>'
elif text.startswith('no') or text.endswith('no'):
return "<BOOLEAN> No </BOOLEAN>"
return '<BOOLEAN> No </BOOLEAN>'
# If no patterns matched, return an empty string
return ""
return ''