# flake8: noqa: W605
import re
from collections import defaultdict

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset

from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
                                  TEXT_POSTPROCESSORS)

from .base import BaseDataset


@LOAD_DATASET.register_module()
class SmolInstructDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
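        """Load SmolInstruct with ``datasets.load_dataset`` and keep, for the
        validation and test splits, only the examples whose ``task`` field
        equals ``name``."""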
        dataset = DatasetDict()
        raw_dataset = load_dataset(path)
        for split in ['validation', 'test']:
            raw_data = []
            for data in raw_dataset[split]:
                if data['task'] == name:
                    raw_data.append(data)
            dataset[split] = Dataset.from_list(raw_data)
        return dataset


def extract_chemical_data(text):
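    """Extract the contents of <MOLFORMULA>, <SMILES> or <IUPAC> tags.

    Returns one stripped string per tag occurrence, or an empty list when no
    tag is found."""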
    pattern = re.compile(r'<(MOLFORMULA|SMILES|IUPAC)>(.*?)</\1>', re.DOTALL)
    matches = pattern.findall(text)
    if not matches:
        return []
    return [match[1].strip() for match in matches]


def parse_molecule(molecular_formula):
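    """Parse a molecular formula into a dict of element counts.

    Brackets are handled recursively and an optional trailing charge is
    stored under the '+' or '-' key, e.g.
    ``parse_molecule('C6H12O6') == {'C': 6, 'H': 12, 'O': 6}``.
    Raises ValueError for formulas that do not match the expected pattern."""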
    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
    if valid is None:
        raise ValueError('Molecular formula "%s" is not valid.' %
                         molecular_formula)

    stack = [defaultdict(int)]

    def _parse_formula(formula, _stack):

        # Set the remainder to None by default.
        r = None

        # Regular expression matching for each of the three cases:
        atom = re.match(r'([A-Z][a-z]?)(\d+)?', formula)
        opening = re.match(r'[\(\[\{]', formula)
        closing = re.match(r'[\)\]\}](\d+)?', formula)

        # If an atom is identified:
        if atom:
            r = formula[len(atom.group()):]
            _stack[-1][atom.group(1)] += int(atom.group(2) or 1)

        # If an opening bracket is encountered:
        elif opening:
            # The remainder is everything after the opening bracket.
            r = formula[len(opening.group()):]
            _stack.append(defaultdict(int))

        # If a closing bracket is encountered:
        elif closing:
            # The remainder is everything after the closing bracket.
            r = formula[len(closing.group()):]
            for k, v in _stack.pop().items():
                # Multiply each count in the bracketed group by its
                # multiplier, depending on nesting.
                _stack[-1][k] += v * int(closing.group(1) or 1)

        # If anything remains, process it recursively as a nested formula:
        if r:
            _parse_formula(r, _stack)

        return dict(_stack[0])

    result = _parse_formula(molecular_formula, stack)

    charge = re.search(r'[\+\-]\d*', molecular_formula)
    if charge is not None:
        charge_str = charge.group()
        charge_type = charge_str[0]
        if len(charge_str) == 1:
            charge_num = 1
        else:
            charge_num = int(charge_str[1:])
        result[charge_type] = charge_num

    return result


def calculate_single_element_match_for_list(predictions, references):
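    """Score one candidate per instance when predictions are top-k lists.

    Element counts of the predicted and gold formulas are compared with
    ``parse_molecule``; ``detail['pred']`` and ``detail['score']`` are kept as
    lists so that results for several candidates can be merged later."""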
    # Extract the chemical formulas from the tagged text.
    predictions = [
        extract_chemical_data(prediction) for prediction in predictions
    ]
    references = [extract_chemical_data(reference) for reference in references]

    ele_match_labels = []
    ele_invalid_labels = []
    details = []
    for pred_formula, gold_formula in zip(predictions, references):
        gold_formula = gold_formula[0]
        if pred_formula:
            pred_formula = pred_formula[0]
        detail = {'pred': [pred_formula], 'answer': gold_formula}
        if not pred_formula:
            ele_invalid_labels.append(False)
            ele_match_labels.append(False)
            detail['score'] = [False]
            details.append(detail)
            continue
        try:
            pred_ele = parse_molecule(pred_formula)
        except KeyboardInterrupt:
            raise
        except Exception:
            ele_invalid_labels.append(True)
            ele_match_labels.append(False)
            detail['score'] = [False]
            details.append(detail)
            continue
        ele_invalid_labels.append(False)
        ele_match = False
        gold_ele = parse_molecule(gold_formula)
        if pred_ele == gold_ele:
            ele_match = True
        ele_match_labels.append(ele_match)
        detail['score'] = [ele_match]
        details.append(detail)

    score = sum(ele_match_labels) / len(predictions) * 100
    valid_score = 100 - sum(ele_invalid_labels) / len(predictions) * 100

    return {'score': score, 'valid_score': valid_score, 'details': details}


def calculate_single_element_match(predictions, references):
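    """Score single-string predictions against references.

    Same element-count comparison as
    ``calculate_single_element_match_for_list``, but ``detail['pred']`` and
    ``detail['score']`` hold scalars instead of per-candidate lists."""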
    # Extract the chemical formulas from the tagged text.
    predictions = [
        extract_chemical_data(prediction) for prediction in predictions
    ]
    references = [extract_chemical_data(reference) for reference in references]

    ele_match_labels = []
    ele_invalid_labels = []
    details = []
    for pred_formula, gold_formula in zip(predictions, references):
        gold_formula = gold_formula[0]
        if pred_formula:
            pred_formula = pred_formula[0]
        detail = {'pred': pred_formula, 'answer': gold_formula}
        if not pred_formula:
            ele_invalid_labels.append(False)
            ele_match_labels.append(False)
            detail['score'] = False
            details.append(detail)
            continue
        try:
            pred_ele = parse_molecule(pred_formula)
        except KeyboardInterrupt:
            raise
        except Exception:
            ele_invalid_labels.append(True)
            ele_match_labels.append(False)
            detail['score'] = False
            details.append(detail)
            continue
        ele_invalid_labels.append(False)
        ele_match = False
        gold_ele = parse_molecule(gold_formula)
        if pred_ele == gold_ele:
            ele_match = True
        ele_match_labels.append(ele_match)
        detail['score'] = ele_match
        details.append(detail)

    score = sum(ele_match_labels) / len(predictions) * 100
    valid_score = 100 - sum(ele_invalid_labels) / len(predictions) * 100

    return {'score': score, 'valid_score': valid_score, 'details': details}


@ICL_EVALUATORS.register_module()
class NCElementMatchEvaluator(BaseEvaluator):
    """Element match evaluator for name conversion."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
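        """Compute element-match accuracy.

        Single-string predictions are scored directly; top-k predictions
        (lists of strings) are scored per candidate, and an instance counts
        as correct when any candidate matches."""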
        print('len(predictions):', len(predictions))
        print('len(references):', len(references))
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        # Top-k predictions come in as lists and are scored per candidate.
        if isinstance(predictions[0], str):
            return calculate_single_element_match(predictions, references)
        else:
            num_k = len(predictions[0])
            scores = []
            for i in range(num_k):
                pred = [prediction[i] for prediction in predictions]
                ref = references
                score = calculate_single_element_match_for_list(pred, ref)
                scores.append(score)
            # Merge the per-candidate results into one detail dict per instance.
            final_details = scores[0]['details']
            final_scores = [scores[0]['score']]
            final_valid_scores = [scores[0]['valid_score']]
            for _k in scores[1:]:
                for i, _d in enumerate(_k['details']):
                    final_details[i]['pred'].extend(_d['pred'])
                    final_details[i]['score'].extend(_d['score'])
                final_scores.append(_k['score'])
                final_valid_scores.append(_k['valid_score'])
            avg_score = []
            for _d in final_details:
                if True in _d['score']:
                    avg_score.append(1)
                else:
                    avg_score.append(0)
            max_score = sum(avg_score) / len(avg_score) * 100
            return {
                'score': max_score,
                'all_score': final_scores,
                'valid_score': final_valid_scores,
                'details': final_details,
            }


@ICL_EVALUATORS.register_module()
class NCExactMatchEvaluator(BaseEvaluator):
    """Exact match evaluator for name conversion."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
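        """Compute exact-match accuracy on the extracted tag contents, plus
        the fraction of predictions that contain a parsable tag."""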
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }
        predictions = [
            extract_chemical_data(prediction) for prediction in predictions
        ]
        references = [
            extract_chemical_data(reference) for reference in references
        ]

        cnt = 0
        valid_cnt = 0
        details = []
        for pred, ans in zip(predictions, references):
            ans = ans[0]
            if pred:
                pred = pred[0]
                valid_cnt += 1
            detail = {'pred': pred, 'answer': ans}
            if pred and pred.strip() == ans.strip():
                cnt += 1
                detail['correct'] = True
            else:
                detail['correct'] = False
            details.append(detail)

        score = cnt / len(predictions) * 100
        valid_score = valid_cnt / len(predictions) * 100

        return {'score': score, 'valid_score': valid_score, 'details': details}


def extract_number(text):
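    """Extract the numbers wrapped in <NUMBER>...</NUMBER> tags as floats."""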
    pattern = re.compile(r'<NUMBER>\s*(-?\d*\.?\d+)\s*</NUMBER>')
    matches = pattern.findall(text)
    return [float(match) for match in matches]


@ICL_EVALUATORS.register_module()
class RMSEEvaluator(BaseEvaluator):
    """RMSE evaluator for numeric predictions."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
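        """Average the per-instance RMSE between the first <NUMBER> value
        extracted from each prediction and reference; a prediction without a
        number is scored as 0."""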
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        avg_score = 0
        details = []
        for prediction, reference in zip(predictions, references):
            pred = extract_number(prediction)
            ans = extract_number(reference)
            if not pred:
                pred = 0
            else:
                pred = pred[0]
            try:
                ans = ans[0]
            except IndexError:
                raise ValueError(f'ans: {reference}')
            detail = {'pred': pred, 'answer': ans}
            rmse_score = np.sqrt(np.mean((np.array(pred) - np.array(ans))**2))
            detail['score'] = rmse_score
            avg_score += rmse_score
            details.append(detail)

        score = avg_score / len(predictions)

        return {'score': score, 'details': details}


@ICL_EVALUATORS.register_module()
class FTSEvaluator(BaseEvaluator):
    """Fingerprint Tanimoto similarity evaluator for molecule outputs."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
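        """Score SMILES predictions by the Tanimoto similarity of Morgan
        fingerprints (radius 2, 2048 bits); unparsable or missing predictions
        score 0."""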
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        predictions = [
            extract_chemical_data(prediction) for prediction in predictions
        ]
        references = [
            extract_chemical_data(reference) for reference in references
        ]

        # Import rdkit lazily: it is only needed when this evaluator runs.
        from rdkit import Chem
        from rdkit.Chem import DataStructs
        from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

        avg_score = 0
        valid_cnt = 0
        details = []
        for pred, ans in zip(predictions, references):
            ans = ans[0]
            if not pred:
                detail = {'pred': '', 'answer': ans, 'score': 0}
                details.append(detail)
                continue
            pred = pred[0]
            detail = {'pred': pred, 'answer': ans}
            # Convert the SMILES strings into RDKit molecule objects.
            mol1 = Chem.MolFromSmiles(pred)
            mol2 = Chem.MolFromSmiles(ans)
            if mol1 is None or mol2 is None:
                detail['score'] = 0
                details.append(detail)
                continue
            valid_cnt += 1
            # Generate Morgan fingerprints (equivalent to ECFP4).
            generator = GetMorganGenerator(radius=2, fpSize=2048)
            fp1 = generator.GetFingerprint(mol1)
            fp2 = generator.GetFingerprint(mol2)
            similarity = DataStructs.TanimotoSimilarity(fp1, fp2) * 100
            detail['score'] = similarity
            avg_score += similarity
            details.append(detail)

        score = avg_score / len(predictions)
        valid_score = valid_cnt / len(predictions) * 100

        return {'score': score, 'valid_score': valid_score, 'details': details}


@ICL_EVALUATORS.register_module()
class MeteorEvaluator(BaseEvaluator):
    """METEOR score evaluator for text generation outputs."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
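        """Average the sentence-level METEOR score of each prediction against
        its reference, using whitespace-tokenized inputs."""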
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        # METEOR comes from NLTK; import lazily to avoid a hard dependency
        # at module import time.
        from nltk.translate.meteor_score import meteor_score

        avg_score = 0
        details = []
        for pred, ans in zip(predictions, references):
            score = meteor_score([ans.split()], pred.split())
            avg_score += score
            detail = {'pred': pred, 'answer': ans, 'score': score}
            details.append(detail)

        score = avg_score / len(predictions)

        return {'score': score, 'details': details}


@TEXT_POSTPROCESSORS.register_module('smolinstruct-acc')
def smolinstruct_acc_postprocess(text: str) -> str:
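    """Normalize a free-form yes/no answer into the tagged
    '<BOOLEAN> ... </BOOLEAN>' format; 'yes' takes precedence over 'no'."""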
    if 'yes' in text.lower():
        return '<BOOLEAN> Yes </BOOLEAN>'
    elif 'no' in text.lower():
        return '<BOOLEAN> No </BOOLEAN>'