# Source: OpenCompass/opencompass/datasets/smolinstruct.py (427 lines, 14 KiB)

# flake8: noqa: W605
import re
from collections import defaultdict
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from .base import BaseDataset
@LOAD_DATASET.register_module()
class SmolInstructDataset(BaseDataset):
    """Load a single SmolInstruct task as validation and test splits."""

    @staticmethod
    def load(path: str, name: str):
        """Build a ``DatasetDict`` holding only rows of task *name*.

        Args:
            path: Passed straight to ``load_dataset``.
            name: Task name; a row is kept when its ``'task'`` field equals it.

        Returns:
            A ``DatasetDict`` with ``'validation'`` and ``'test'`` entries.
        """
        splits = DatasetDict()
        source = load_dataset(path)
        for split_name in ('validation', 'test'):
            # Filter the split down to the requested task.
            rows = [row for row in source[split_name] if row['task'] == name]
            splits[split_name] = Dataset.from_list(rows)
        return splits
def extract_chemical_data(text):
    """Return the stripped contents of every tagged chemical string in *text*.

    Recognised tag pairs are ``<MOLFORMULA>``, ``<SMILES>`` and ``<IUPAC>``;
    the closing tag must name the same type as the opening one and the
    content may span lines. Results are in order of appearance; an empty
    list is returned when no tag pair is found.
    """
    tag_pattern = re.compile(r'<(MOLFORMULA|SMILES|IUPAC)>(.*?)</\1>',
                             re.DOTALL)
    return [body.strip() for _tag, body in tag_pattern.findall(text)]
def parse_molecule(molecular_formula):
    """Parse a molecular formula into a mapping of element -> atom count.

    The formula is first validated: it must be a run of letters, each
    optionally followed by digits, optionally ending in charge suffixes
    (e.g. ``C6H12O6``, ``NH4+``, ``SO4-2``). Anything else -- including
    bracketed groups such as ``Ca(OH)2`` -- raises ``ValueError``, so the
    bracket handling in the scan below is currently never reached.

    Elements are then read left to right as an uppercase letter, an optional
    lowercase letter and an optional count. Reading stops at the first
    character that is not an element, e.g. a charge sign. An input starting
    with a lowercase letter passes validation but yields no elements.

    If a ``+``/``-`` sign occurs, its first occurrence is stored under that
    sign as key, with its magnitude (1 when no digits follow).

    Args:
        molecular_formula: Formula string to parse.

    Returns:
        dict mapping element symbols (and optionally ``'+'``/``'-'``) to ints.

    Raises:
        ValueError: If the formula does not pass the validity check.
    """
    # Raw strings: '\d' in a plain literal is a SyntaxWarning on Python 3.12+.
    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
    if valid is None:
        raise ValueError("Molecular formula \"%s\" is not valid." %
                         molecular_formula)

    atom_re = re.compile(r'([A-Z][a-z]?)(\d+)?')
    opening_re = re.compile(r'[\(\[\{]')
    closing_re = re.compile(r'[\)\]\}](\d+)?')

    # One counter per open bracket level; stack[0] accumulates the result.
    stack = [defaultdict(int)]
    remainder = molecular_formula
    # Iterative scan: the previous implementation recursed once per token
    # and hit RecursionError on formulas with ~1000+ tokens.
    while remainder:
        atom = atom_re.match(remainder)
        opening = opening_re.match(remainder)
        closing = closing_re.match(remainder)
        if atom:
            stack[-1][atom.group(1)] += int(atom.group(2) or 1)
            token = atom.group()
        elif opening:
            stack.append(defaultdict(int))
            token = opening.group()
        elif closing:
            # Fold the innermost group into its parent, scaled by its count.
            group_counts = stack.pop()
            multiplier = int(closing.group(1) or 1)
            for element, count in group_counts.items():
                stack[-1][element] += count * multiplier
            token = closing.group()
        else:
            # Unrecognised character (e.g. a charge sign): stop scanning.
            break
        remainder = remainder[len(token):]
    result = dict(stack[0])

    charge = re.search(r'[\+\-]\d*', molecular_formula)
    if charge is not None:
        charge_str = charge.group()
        sign, digits = charge_str[0], charge_str[1:]
        result[sign] = int(digits) if digits else 1
    return result
def _element_match(predictions, references, wrap_in_list):
    """Shared implementation of the two element-match scorers below.

    For each pair, the first tagged chemical string is extracted from the
    prediction and the reference. Both are parsed with ``parse_molecule``,
    and a hit is recorded when the resulting element-count dicts are equal.

    Args:
        predictions: Model outputs, one string per instance.
        references: Gold outputs aligned with *predictions*; each must
            contain at least one tagged formula (else ``IndexError``).
        wrap_in_list: When True, each detail's ``'pred'`` and ``'score'``
            are single-element lists so several runs can be merged with
            ``list.extend``; otherwise they are scalars.

    Returns:
        dict with ``'score'`` (percentage of matches), ``'valid_score'``
        (100 minus the percentage of predictions that failed to parse) and
        per-instance ``'details'``.

    Raises:
        ZeroDivisionError: If *predictions* is empty.
    """
    def _pack(value):
        return [value] if wrap_in_list else value

    pred_tags = [extract_chemical_data(p) for p in predictions]
    gold_tags = [extract_chemical_data(r) for r in references]

    match_labels = []
    invalid_labels = []
    details = []
    for pred_found, gold_found in zip(pred_tags, gold_tags):
        gold_formula = gold_found[0]
        # A prediction without any tag pair is recorded as ''.
        pred_formula = pred_found[0] if pred_found else ''
        invalid = False
        match = False
        if pred_formula:
            try:
                pred_ele = parse_molecule(pred_formula)
            except Exception:
                # Unparseable prediction: counts against valid_score.
                invalid = True
            else:
                # A malformed reference is a data error; let it propagate.
                match = pred_ele == parse_molecule(gold_formula)
        # NOTE: a missing/empty prediction is a miss but is not counted as
        # invalid, so it does not lower valid_score.
        match_labels.append(match)
        invalid_labels.append(invalid)
        details.append({
            'pred': _pack(pred_formula),
            'answer': gold_formula,
            'score': _pack(match),
        })

    total = len(predictions)
    score = sum(match_labels) / total * 100
    valid_score = 100 - sum(invalid_labels) / total * 100
    return {'score': score, 'valid_score': valid_score, 'details': details}


def calculate_single_element_match_for_list(predictions, references):
    """Element-match scoring whose details hold list-valued pred/score.

    Each detail's ``'pred'`` and ``'score'`` are one-element lists, so the
    details of several runs (e.g. one per top-k candidate) can be merged
    with ``list.extend``.
    """
    return _element_match(predictions, references, wrap_in_list=True)


def calculate_single_element_match(predictions, references):
    """Element-match scoring whose details hold scalar pred/score values."""
    return _element_match(predictions, references, wrap_in_list=False)
@ICL_EVALUATORS.register_module()
class NCElementMatchEvaluator(BaseEvaluator):
    """Element match evaluator for name conversion.

    Predictions may be plain strings, or per-instance sequences of candidate
    strings (top-k). In the top-k case each candidate position is scored
    separately and an instance counts as correct if any candidate matches.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Score predictions against references.

        Returns:
            For string predictions, the dict from
            ``calculate_single_element_match``. For top-k predictions, a
            dict with ``'score'`` (percentage of instances where any
            candidate matches), ``'all_score'`` and ``'valid_score'``
            (one entry per candidate position) and merged ``'details'``.
            An ``{'error': ...}`` dict on length mismatch or empty input.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            return {'error': 'predictions is empty'}
        if isinstance(predictions[0], str):
            return calculate_single_element_match(predictions, references)

        # Top-k predictions: score each candidate position on its own.
        num_k = len(predictions[0])
        per_k = [
            calculate_single_element_match_for_list(
                [candidates[k] for candidates in predictions], references)
            for k in range(num_k)
        ]

        # Merge per-position results into one detail record per instance.
        final_details = per_k[0]['details']
        for result in per_k[1:]:
            for merged, detail in zip(final_details, result['details']):
                merged['pred'].extend(detail['pred'])
                merged['score'].extend(detail['score'])

        hits = [any(detail['score']) for detail in final_details]
        return {
            'score': sum(hits) / len(hits) * 100,
            'all_score': [result['score'] for result in per_k],
            'valid_score': [result['valid_score'] for result in per_k],
            'details': final_details,
        }
@ICL_EVALUATORS.register_module()
class NCExactMatchEvaluator(BaseEvaluator):
    """Exact match evaluator for name conversion.

    Compares the first tagged chemical string (``<SMILES>``, ``<IUPAC>`` or
    ``<MOLFORMULA>``) of each prediction with that of its reference.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return exact-match accuracy and tag coverage.

        Returns:
            dict with ``'score'`` (percentage of exact matches),
            ``'valid_score'`` (percentage of predictions containing at least
            one tag pair) and per-instance ``'details'``; or an
            ``{'error': ...}`` dict on length mismatch or empty input.

        Raises:
            IndexError: If a reference contains no tagged string.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            return {'error': 'predictions is empty'}
        pred_tags = [extract_chemical_data(p) for p in predictions]
        gold_tags = [extract_chemical_data(r) for r in references]

        correct = 0
        tagged = 0
        details = []
        for pred_found, gold_found in zip(pred_tags, gold_tags):
            ans = gold_found[0]
            # An untagged prediction is recorded as ''.
            pred = pred_found[0] if pred_found else ''
            if pred_found:
                tagged += 1
            is_correct = bool(pred) and pred.strip() == ans.strip()
            if is_correct:
                correct += 1
            details.append({'pred': pred, 'answer': ans, 'correct': is_correct})

        total = len(predictions)
        return {
            'score': correct / total * 100,
            'valid_score': tagged / total * 100,
            'details': details,
        }
def extract_number(text):
    """Return every number wrapped in ``<NUMBER> ... </NUMBER>`` as a float.

    Whitespace around the number inside the tags is ignored. Accepted forms
    are an optional sign, digits with an optional decimal point (``5``,
    ``-3.22``, ``.5``, ``5.``) and an optional exponent (``1.2e-3``).
    Tag pairs whose content is not a single such number are skipped.
    """
    number_re = re.compile(
        r'<NUMBER>\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)\s*</NUMBER>')
    return [float(match) for match in number_re.findall(text)]
@ICL_EVALUATORS.register_module()
class RMSEEvaluator(BaseEvaluator):
    """Root-mean-square-error evaluator for numeric answers.

    Numbers are read from ``<NUMBER>`` tags. A prediction without a tagged
    number is scored as if it had predicted 0.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return the dataset RMSE plus per-instance absolute errors.

        Each detail's ``'score'`` is that instance's absolute error (the RMSE
        of a single pair). The aggregate ``'score'`` is
        ``sqrt(mean(squared errors))``. Previously the per-instance absolute
        errors were averaged, which yields the MAE rather than the RMSE.

        Returns:
            dict with ``'score'`` and ``'details'``; or an ``{'error': ...}``
            dict on length mismatch or empty input.

        Raises:
            ValueError: If a reference contains no ``<NUMBER>`` tag.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            return {'error': 'predictions is empty'}

        abs_errors = []
        details = []
        for prediction, reference in zip(predictions, references):
            pred_numbers = extract_number(prediction)
            ans_numbers = extract_number(reference)
            # Untagged predictions fall back to 0 so they still count.
            pred = pred_numbers[0] if pred_numbers else 0
            if not ans_numbers:
                raise ValueError(f'ans: {reference}')
            ans = ans_numbers[0]
            abs_error = abs(pred - ans)
            abs_errors.append(abs_error)
            details.append({'pred': pred, 'answer': ans, 'score': abs_error})

        score = float(np.sqrt(np.mean(np.square(abs_errors))))
        return {'score': score, 'details': details}
@ICL_EVALUATORS.register_module()
class FTSEvaluator(BaseEvaluator):
    """Fingerprint Tanimoto similarity (FTS) evaluator for SMILES outputs.

    The first tagged string of each prediction and reference is parsed as
    SMILES. Both are fingerprinted with Morgan fingerprints (radius 2, 2048
    bits), and the Tanimoto similarity is reported as a percentage.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return the mean fingerprint similarity and SMILES validity rate.

        Returns:
            dict with:
            - ``'score'``: mean similarity (0-100) over all instances;
              untagged or unparseable pairs contribute 0.
            - ``'valid_score'``: percentage of pairs where both SMILES parsed.
            - ``'details'``: per-instance records.

            An ``{'error': ...}`` dict is returned on length mismatch or
            empty input.

        Raises:
            IndexError: If a reference contains no tagged string.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            return {'error': 'predictions is empty'}

        # RDKit is imported here rather than at module level (the file has
        # no top-level RDKit import); done once per call, not per instance.
        from rdkit import Chem
        from rdkit.Chem import DataStructs
        from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

        # Radius-2 Morgan fingerprints (ECFP4-like); build the generator once.
        generator = GetMorganGenerator(radius=2, fpSize=2048)

        pred_tags = [extract_chemical_data(p) for p in predictions]
        gold_tags = [extract_chemical_data(r) for r in references]

        total_similarity = 0
        valid_cnt = 0
        details = []
        for pred_found, gold_found in zip(pred_tags, gold_tags):
            ans = gold_found[0]
            if not pred_found:
                details.append({'pred': '', 'answer': ans, 'score': 0})
                continue
            pred = pred_found[0]
            detail = {'pred': pred, 'answer': ans}
            # MolFromSmiles returns None for SMILES RDKit cannot parse.
            mol_pred = Chem.MolFromSmiles(pred)
            mol_ans = Chem.MolFromSmiles(ans)
            if mol_pred is None or mol_ans is None:
                detail['score'] = 0
                details.append(detail)
                continue
            valid_cnt += 1
            similarity = DataStructs.TanimotoSimilarity(
                generator.GetFingerprint(mol_pred),
                generator.GetFingerprint(mol_ans)) * 100
            detail['score'] = similarity
            total_similarity += similarity
            details.append(detail)

        total = len(predictions)
        return {
            'score': total_similarity / total,
            'valid_score': valid_cnt / total * 100,
            'details': details,
        }
@ICL_EVALUATORS.register_module()
class MeteorEvaluator(BaseEvaluator):
    """METEOR evaluator for free-text outputs.

    Each prediction/reference pair is whitespace-tokenised and scored with
    ``meteor_score``; the reported score is the mean over all pairs.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return the mean METEOR score and per-instance details.

        Returns an ``{'error': ...}`` dict on length mismatch or empty input.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            return {'error': 'predictions is empty'}
        # NOTE(review): `meteor_score` is not imported anywhere in this file
        # as shown; presumably nltk.translate.meteor_score.meteor_score --
        # confirm and add the import, otherwise this raises NameError.
        total = 0
        details = []
        for pred, ans in zip(predictions, references):
            # One tokenised reference per hypothesis, split on whitespace.
            pair_score = meteor_score([ans.split()], pred.split())
            total += pair_score
            details.append({'pred': pred, 'answer': ans, 'score': pair_score})
        return {'score': total / len(predictions), 'details': details}
@TEXT_POSTPROCESSORS.register_module('smolinstruct-acc')
def smolinstruct_acc_postprocess(text: str) -> str:
    """Normalise a yes/no answer into the ``<BOOLEAN>`` answer format.

    Matching is a case-insensitive substring test with ``'yes'`` checked
    first, so words merely containing ``no`` (e.g. ``not``, ``know``) map to
    ``No``. Returns ``None`` when neither substring occurs.
    """
    lowered = text.lower()
    if 'yes' in lowered:
        return '<BOOLEAN> Yes </BOOLEAN>'
    if 'no' in lowered:
        return '<BOOLEAN> No </BOOLEAN>'
    return None