mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
427 lines
14 KiB
Python
427 lines
14 KiB
Python
![]() |
# flake8: noqa: W605
|
|||
|
import re
|
|||
|
from collections import defaultdict
|
|||
|
|
|||
|
import numpy as np
|
|||
|
from datasets import Dataset, DatasetDict, load_dataset
|
|||
|
|
|||
|
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
|
|||
|
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
|||
|
TEXT_POSTPROCESSORS)
|
|||
|
|
|||
|
from .base import BaseDataset
|
|||
|
|
|||
|
|
|||
|
@LOAD_DATASET.register_module()
class SmolInstructDataset(BaseDataset):
    """SMolInstruct dataset restricted to a single task.

    The raw dataset is loaded with ``datasets.load_dataset`` and, for the
    ``validation`` and ``test`` splits, only rows whose ``task`` field equals
    the requested task name are kept.
    """

    @staticmethod
    def load(path: str, name: str):
        """Load ``path`` and keep only the rows of task ``name``.

        Args:
            path: Dataset location passed straight to ``load_dataset``.
            name: Task name compared against each row's ``task`` field.

        Returns:
            DatasetDict: ``validation`` and ``test`` splits holding only the
            selected task's rows.
        """
        raw_dataset = load_dataset(path)
        dataset = DatasetDict()
        for split in ('validation', 'test'):
            task_rows = [
                row for row in raw_dataset[split] if row['task'] == name
            ]
            dataset[split] = Dataset.from_list(task_rows)
        return dataset
|
|||
|
|
|||
|
|
|||
|
def extract_chemical_data(text):
    """Return the stripped bodies of all chemical tags found in ``text``.

    Recognized tags are ``<MOLFORMULA>``, ``<SMILES>`` and ``<IUPAC>``; the
    closing tag must match the opening one, and bodies may span lines.

    Args:
        text: String to search.

    Returns:
        list[str]: Tag bodies in order of appearance (empty if none).
    """
    tag_pattern = r'<(MOLFORMULA|SMILES|IUPAC)>(.*?)</\1>'
    bodies = []
    for _tag_name, body in re.findall(tag_pattern, text, flags=re.DOTALL):
        bodies.append(body.strip())
    return bodies
|
|||
|
|
|||
|
|
|||
|
def parse_molecule(molecular_formula):
    """Count the atoms in a molecular formula string.

    The formula must consist of element symbols with optional counts,
    optionally followed by charge suffixes (``+``/``-`` with optional digits),
    e.g. ``'C6H12O6'`` or ``'SO4-2'``. The validity check does not accept
    brackets, so the bracket branches of the tokenizer are never reached
    through this function's validated input.

    Args:
        molecular_formula: The formula to parse.

    Returns:
        dict: Mapping of element symbol to atom count. If a charge is
        present, the first charge found is stored under its sign (``'+'`` or
        ``'-'``) with its magnitude (1 when no digits follow the sign).

    Raises:
        ValueError: If ``molecular_formula`` does not match the allowed
            syntax.
    """
    valid = re.match(r'([A-Za-z]\d*)+([\+\-]\d*)*$', molecular_formula)
    if valid is None:
        raise ValueError("Molecular formula \"%s\" is not valid." %
                         molecular_formula)

    # Token patterns, tried in this order at the start of the remainder.
    atom_re = re.compile(r'([A-Z][a-z]?)(\d+)?')
    opening_re = re.compile(r'[\(\[\{]')
    closing_re = re.compile(r'[\)\]\}](\d+)?')

    # One counter per open bracket level; stack[0] accumulates the result.
    stack = [defaultdict(int)]
    remainder = molecular_formula
    # Iterative rather than recursive so that long formulas cannot exceed
    # the interpreter's recursion limit.
    while remainder:
        atom = atom_re.match(remainder)
        if atom:
            stack[-1][atom.group(1)] += int(atom.group(2) or 1)
            remainder = remainder[atom.end():]
            continue
        opening = opening_re.match(remainder)
        if opening:
            stack.append(defaultdict(int))
            remainder = remainder[opening.end():]
            continue
        closing = closing_re.match(remainder)
        if closing:
            # Fold the closed group into its parent, scaled by the
            # multiplier that follows the closing bracket.
            multiplier = int(closing.group(1) or 1)
            for element, count in stack.pop().items():
                stack[-1][element] += count * multiplier
            remainder = remainder[closing.end():]
            continue
        # Anything else (a charge sign, or a lowercase letter that the
        # validity check allows) ends the scan; charges are handled below.
        break

    result = dict(stack[0])

    charge = re.search(r'[\+\-]\d*', molecular_formula)
    if charge is not None:
        charge_str = charge.group()
        charge_type = charge_str[0]
        charge_num = 1 if len(charge_str) == 1 else int(charge_str[1:])
        result[charge_type] = charge_num

    return result
|
|||
|
|
|||
|
|
|||
|
def calculate_single_element_match_for_list(predictions, references):
    """List-valued variant of ``calculate_single_element_match``.

    Computes exactly the same ``score`` and ``valid_score`` as
    ``calculate_single_element_match``. The only difference is that each
    detail's ``pred`` and ``score`` are wrapped in one-element lists, so that
    results for several top-k ranks can be merged per instance by extending
    those lists.

    Args:
        predictions: Raw prediction strings (one rank of a top-k output).
        references: Raw reference strings, aligned with ``predictions``.

    Returns:
        dict: ``score``, ``valid_score`` and ``details`` where every detail is
        ``{'pred': [pred], 'answer': answer, 'score': [matched]}``.
    """
    result = calculate_single_element_match(predictions, references)
    # Wrap per-instance values so callers can ``extend`` them across ranks.
    for detail in result['details']:
        detail['pred'] = [detail['pred']]
        detail['score'] = [detail['score']]
    return result
|
|||
|
|
|||
|
|
|||
|
def calculate_single_element_match(predictions, references):
    """Score predictions by whether their element counts match the reference.

    The first ``<MOLFORMULA>``/``<SMILES>``/``<IUPAC>`` tagged span is taken
    from every prediction and reference and parsed with ``parse_molecule``;
    a prediction matches when both parse to identical element counts.

    Args:
        predictions: Raw prediction strings.
        references: Raw reference strings; each must contain a tagged span.

    Returns:
        dict: ``score`` (percentage of matching predictions), ``valid_score``
        (100 minus the percentage of predictions that had a tagged span but
        failed to parse) and per-instance ``details``
        ``{'pred': ..., 'answer': ..., 'score': bool}``.

    Raises:
        IndexError: If a reference contains no tagged span.
        Exception: Any error raised while parsing a reference formula is
            propagated (only prediction parse errors are caught).
    """
    # Extract the tagged chemical strings (formula/SMILES/IUPAC).
    predictions = [
        extract_chemical_data(prediction) for prediction in predictions
    ]
    references = [extract_chemical_data(reference) for reference in references]

    ele_match_labels = []
    ele_invalid_labels = []
    details = []
    for pred_formula, gold_formula in zip(predictions, references):
        gold_formula = gold_formula[0]
        if pred_formula:
            pred_formula = pred_formula[0]
        detail = {'pred': pred_formula, 'answer': gold_formula}

        if not pred_formula:
            # No (or empty) tagged answer: counted as wrong but not invalid.
            ele_invalid_labels.append(False)
            ele_match_labels.append(False)
            detail['score'] = False
            details.append(detail)
            continue

        try:
            pred_ele = parse_molecule(pred_formula)
        except Exception:
            # Unparsable prediction: counted as wrong and invalid.
            ele_invalid_labels.append(True)
            ele_match_labels.append(False)
            detail['score'] = False
            details.append(detail)
            continue

        ele_invalid_labels.append(False)
        gold_ele = parse_molecule(gold_formula)
        ele_match = pred_ele == gold_ele
        ele_match_labels.append(ele_match)
        detail['score'] = ele_match
        details.append(detail)

    score = sum(ele_match_labels) / len(predictions) * 100
    valid_score = 100 - sum(ele_invalid_labels) / len(predictions) * 100

    return {'score': score, 'valid_score': valid_score, 'details': details}
|
|||
|
|
|||
|
|
|||
|
@ICL_EVALUATORS.register_module()
class NCElementMatchEvaluator(BaseEvaluator):
    """Element match evaluator for name conversion.

    A prediction is correct when the first tagged chemical span it contains
    parses to the same element counts as the reference's span.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Compute the element-match score.

        Args:
            predictions: Either a list of prediction strings, or a list of
                per-instance candidate lists (top-k outputs). In the latter
                case the number of candidates is taken from the first
                instance; every instance is assumed to have at least that
                many.
            references: Raw reference strings, aligned with ``predictions``.

        Returns:
            dict: For string predictions, the result of
            ``calculate_single_element_match``. For top-k predictions,
            ``score`` is the percentage of instances where any candidate
            matched, ``all_score``/``valid_score`` hold the per-rank scores,
            and ``details`` merges every rank's ``pred``/``score`` per
            instance. ``{'error': ...}`` when the input lengths differ.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        if isinstance(predictions[0], str):
            return calculate_single_element_match(predictions, references)

        # Top-k predictions: score each rank separately, then merge the
        # per-instance details across ranks.
        num_k = len(predictions[0])
        scores = [
            calculate_single_element_match_for_list(
                [prediction[rank] for prediction in predictions], references)
            for rank in range(num_k)
        ]

        final_details = scores[0]['details']
        for rank_result in scores[1:]:
            for merged, detail in zip(final_details, rank_result['details']):
                merged['pred'].extend(detail['pred'])
                merged['score'].extend(detail['score'])

        # An instance counts as solved when any of its candidates matched.
        hits = sum(1 for detail in final_details if any(detail['score']))
        max_score = hits / len(final_details) * 100
        return {
            'score': max_score,
            'all_score': [result['score'] for result in scores],
            'valid_score': [result['valid_score'] for result in scores],
            'details': final_details,
        }
|
|||
|
|
|||
|
|
|||
|
@ICL_EVALUATORS.register_module()
class NCExactMatchEvaluator(BaseEvaluator):
    """Exact match evaluator for name conversion.

    Compares the first tagged chemical span of each prediction with the
    first tagged span of its reference, ignoring surrounding whitespace.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return exact-match and extraction rates as percentages.

        Returns:
            dict: ``score`` (percentage of exact matches), ``valid_score``
            (percentage of predictions with at least one tagged span) and
            per-instance ``details``; or ``{'error': ...}`` when the inputs
            differ in length.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        pred_spans = [extract_chemical_data(text) for text in predictions]
        ref_spans = [extract_chemical_data(text) for text in references]

        total = len(predictions)
        hits = 0
        extracted = 0
        details = []
        for candidates, answers in zip(pred_spans, ref_spans):
            answer = answers[0]
            # When nothing was extracted, ``guess`` stays the empty list and
            # that is what gets recorded as ``pred``.
            guess = candidates[0] if candidates else candidates
            if candidates:
                extracted += 1
            is_correct = bool(guess) and guess.strip() == answer.strip()
            if is_correct:
                hits += 1
            details.append({
                'pred': guess,
                'answer': answer,
                'correct': is_correct
            })

        return {
            'score': hits / total * 100,
            'valid_score': extracted / total * 100,
            'details': details,
        }
|
|||
|
|
|||
|
|
|||
|
def extract_number(text):
    """Return every number wrapped in ``<NUMBER>...</NUMBER>`` tags.

    Accepts an optional leading minus sign and an optional decimal point
    (e.g. ``-1.5``, ``.25``, ``3``); whitespace inside the tags is ignored.

    Args:
        text: String to search.

    Returns:
        list[float]: The parsed numbers in order of appearance.
    """
    number_tag = r'<NUMBER>\s*(-?\d*\.?\d+)\s*</NUMBER>'
    return list(map(float, re.findall(number_tag, text)))
|
|||
|
|
|||
|
|
|||
|
@ICL_EVALUATORS.register_module()
class RMSEEvaluator(BaseEvaluator):
    """Root-mean-square error evaluator for numeric answers.

    Uses the first ``<NUMBER>...</NUMBER>`` value of each prediction and
    reference. A prediction without a parsable number is scored as ``0``.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Compute the RMSE over all prediction/reference pairs.

        Returns:
            dict: ``score`` = sqrt(mean((pred - answer) ** 2)) over the
            dataset, and per-instance ``details`` whose ``score`` is the
            absolute error of that instance; or ``{'error': ...}`` when the
            input lengths differ.

        Raises:
            ValueError: If a reference contains no ``<NUMBER>`` value.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        squared_error_sum = 0.0
        details = []
        for prediction, reference in zip(predictions, references):
            pred_values = extract_number(prediction)
            ans_values = extract_number(reference)
            # A missing prediction is treated as the value 0.
            pred = pred_values[0] if pred_values else 0
            if not ans_values:
                raise ValueError(f'ans: {reference}')
            ans = ans_values[0]

            error = pred - ans
            squared_error_sum += error**2
            details.append({'pred': pred, 'answer': ans, 'score': abs(error)})

        # Root of the mean squared error over the whole dataset (averaging
        # per-instance roots would give the mean absolute error instead).
        score = float(np.sqrt(squared_error_sum / len(predictions)))

        return {'score': score, 'details': details}
|
|||
|
|
|||
|
|
|||
|
@ICL_EVALUATORS.register_module()
class FTSEvaluator(BaseEvaluator):
    """Fingerprint Tanimoto similarity evaluator for SMILES answers.

    The first tagged span of each prediction and reference is parsed with
    RDKit, converted to a radius-2, 2048-bit Morgan fingerprint, and compared
    with Tanimoto similarity on a 0-100 scale.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Average the Tanimoto similarity over all instances.

        Missing predictions and pairs where RDKit cannot parse either SMILES
        contribute a similarity of 0.

        Returns:
            dict: ``score`` (mean similarity, 0-100), ``valid_score``
            (percentage of pairs where both SMILES parsed) and per-instance
            ``details``; or ``{'error': ...}`` when the input lengths differ.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        # RDKit is only needed by this evaluator, so it is imported lazily,
        # once per call rather than once per instance.
        from rdkit import Chem
        from rdkit.Chem import DataStructs
        from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

        predictions = [
            extract_chemical_data(prediction) for prediction in predictions
        ]
        references = [
            extract_chemical_data(reference) for reference in references
        ]

        # Morgan fingerprint, radius 2 (ECFP4-equivalent); the generator is
        # loop-invariant, so build it once.
        generator = GetMorganGenerator(radius=2, fpSize=2048)

        avg_score = 0
        valid_cnt = 0
        details = []
        for pred, ans in zip(predictions, references):
            ans = ans[0]
            if not pred:
                details.append({'pred': '', 'answer': ans, 'score': 0})
                continue
            pred = pred[0]
            detail = {'pred': pred, 'answer': ans}

            # Convert SMILES to RDKit molecules; None means unparsable.
            mol1 = Chem.MolFromSmiles(pred)
            mol2 = Chem.MolFromSmiles(ans)
            if mol1 is None or mol2 is None:
                detail['score'] = 0
                details.append(detail)
                continue

            valid_cnt += 1
            fp1 = generator.GetFingerprint(mol1)
            fp2 = generator.GetFingerprint(mol2)
            similarity = DataStructs.TanimotoSimilarity(fp1, fp2) * 100
            detail['score'] = similarity
            avg_score += similarity
            details.append(detail)

        score = avg_score / len(predictions)
        valid_score = valid_cnt / len(predictions) * 100

        return {'score': score, 'valid_score': valid_score, 'details': details}
|
|||
|
|
|||
|
|
|||
|
@ICL_EVALUATORS.register_module()
class MeteorEvaluator(BaseEvaluator):
    """Evaluator that averages a per-sample ``meteor_score`` over all pairs.

    Predictions and references are whitespace-tokenized before scoring.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Return the mean per-sample METEOR score.

        Args:
            predictions: Prediction strings.
            references: Reference strings, aligned with ``predictions``.

        Returns:
            dict: ``score`` (mean of the per-sample scores) and per-instance
            ``details``; or ``{'error': ...}`` when the input lengths differ.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        # Running sum of per-sample scores; divided by the count below.
        avg_score = 0
        details = []
        for pred, ans in zip(predictions, references):
            # NOTE(review): `meteor_score` is neither defined nor imported in
            # this file, so this line raises NameError at runtime. It looks
            # like `nltk.translate.meteor_score.meteor_score` (a list of
            # tokenized references plus a tokenized hypothesis) -- confirm
            # and add the import.
            score = meteor_score([ans.split()], pred.split())
            avg_score += score
            detail = {'pred': pred, 'answer': ans, 'score': score}
            details.append(detail)

        score = avg_score / len(predictions)

        return {'score': score, 'details': details}
|
|||
|
|
|||
|
|
|||
|
@TEXT_POSTPROCESSORS.register_module('smolinstruct-acc')
def smolinstruct_acc_postprocess(text: str) -> str:
    """Map a free-form yes/no answer onto the SMolInstruct boolean tag.

    Matching is a case-insensitive substring check, and ``'yes'`` takes
    precedence when both words occur. Because it is a substring check,
    ``'no'`` also fires on words such as ``'not'`` or ``'none'``.

    Args:
        text: Raw model output.

    Returns:
        str: ``'<BOOLEAN> Yes </BOOLEAN>'`` or ``'<BOOLEAN> No </BOOLEAN>'``;
        text containing neither word is returned unchanged so that callers
        always receive a string.
    """
    lowered = text.lower()
    if 'yes' in lowered:
        return '<BOOLEAN> Yes </BOOLEAN>'
    if 'no' in lowered:
        return '<BOOLEAN> No </BOOLEAN>'
    return text