from typing import List

import numpy as np
from sklearn.metrics import roc_auc_score

from opencompass.registry import ICL_EVALUATORS

from .icl_base_evaluator import BaseEvaluator


@ICL_EVALUATORS.register_module()
class AUCROCEvaluator(BaseEvaluator):
    """Calculate AUC-ROC scores and accuracy according to the predictions.

    For some datasets, accuracy alone cannot reveal the difference between
    models because of saturation. AUC-ROC scores can further examine a
    model's ability to distinguish between labels. For more details, refer to
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    """  # noqa

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions: List, references: List) -> dict:
        """Calculate AUC-ROC score and accuracy.

        Args:
            predictions (List): List of probabilities for each class of each
                sample.
            references (List): List of target labels for each sample.

        Returns:
            dict: Calculated scores.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length.'
            }
        # AUC-ROC is computed from the predicted probability of the positive
        # class (column 1), which assumes binary classification.
        auc_score = roc_auc_score(references, np.array(predictions)[:, 1])
        # Accuracy is the fraction of samples whose highest-probability class
        # matches the reference label.
        accuracy = sum(
            references == np.argmax(predictions, axis=1)) / len(references)
        return dict(auc_score=auc_score * 100, accuracy=accuracy * 100)
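

# Minimal usage sketch (added for illustration; not part of the upstream
# file). It assumes binary classification: each prediction row holds the
# class probabilities [p_class0, p_class1] and each reference is a 0/1
# label. Because of the relative import above, execute this file as a
# module within its package (``python -m ...``) rather than directly.
if __name__ == '__main__':
    evaluator = AUCROCEvaluator()
    toy_predictions = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]]
    toy_references = [0, 1, 1, 0]
    # Every positive sample is ranked above every negative one and every
    # argmax matches its label, so both scores come out as 100.0 here.
    print(evaluator.score(toy_predictions, toy_references))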