from typing import List

import numpy as np
from sklearn.metrics import roc_auc_score

from opencompass.registry import ICL_EVALUATORS

from .icl_base_evaluator import BaseEvaluator


@ICL_EVALUATORS.register_module()
class AUCROCEvaluator(BaseEvaluator):
    """Calculate AUC-ROC scores and accuracy according to the predictions.

    For some datasets, accuracy alone cannot reveal the difference between
    models because of saturation. AUC-ROC scores can further examine a
    model's ability to distinguish between labels. For more details, refer to
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    """  # noqa

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions: List, references: List) -> dict:
        """Calculate AUC-ROC score and accuracy.

        Args:
            predictions (List): List of probabilities for each class of each
                sample.
            references (List): List of target labels for each sample.

        Returns:
            dict: Calculated scores.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths.'
            }
        # Use the probability of the positive class (column 1) as the score
        # expected by sklearn's roc_auc_score.
        auc_score = roc_auc_score(references, np.array(predictions)[:, 1])
        # Accuracy is the fraction of samples whose argmax class matches the
        # reference label.
        accuracy = sum(
            references == np.argmax(predictions, axis=1)) / len(references)
        return dict(auc_score=auc_score * 100, accuracy=accuracy * 100)
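
# --- Usage sketch (illustrative; not part of the upstream file) ---
# A rough picture of how this evaluator is typically driven, assuming a binary
# task where each prediction is a [p_class0, p_class1] probability pair; the
# second column is what roc_auc_score receives as the positive-class score.
# The toy values below are made up for illustration only:
#
#     evaluator = AUCROCEvaluator()
#     predictions = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]]
#     references = [0, 1, 1, 0]
#     evaluator.score(predictions, references)
#     # -> {'auc_score': 100.0, 'accuracy': 100.0}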