mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Evaluating acc based on minimum edit distance, update SIQA (#130)
* [Feature] Support evaluating acc based on minimum edit distance, update SIQA * update
This commit is contained in:
parent
e9b7b8ab02
commit
c00179d46b
@ -1,13 +1,12 @@
|
|||||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
from opencompass.openicl.icl_evaluator import EDAccEvaluator
|
||||||
from opencompass.datasets import siqaDataset_V2
|
from opencompass.datasets import siqaDataset_V2
|
||||||
from opencompass.utils.text_postprocessors import first_capital_postprocess
|
|
||||||
|
|
||||||
# Reader settings for SIQA: the context/question plus the three answer
# candidates feed the prompt; 'all_labels' carries candidates + gold index.
siqa_reader_cfg = dict(
    input_columns=['context', 'question', 'answerA', 'answerB', 'answerC'],
    output_column='all_labels',
    test_split='validation',
)
|
||||||
|
|
||||||
siqa_infer_cfg = dict(
|
siqa_infer_cfg = dict(
|
||||||
@ -27,9 +26,8 @@ siqa_infer_cfg = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Generation-style evaluation: the raw (un-postprocessed) model output is
# matched to the nearest candidate by edit distance, then scored as accuracy.
siqa_eval_cfg = dict(
    evaluator=dict(type=EDAccEvaluator),
    pred_role='BOT',
)
||||||
|
|
||||||
siqa_datasets = [
|
siqa_datasets = [
|
||||||
|
@ -13,6 +13,15 @@ class siqaDataset_V2(BaseDataset):
|
|||||||
dataset = load_dataset(**kwargs)
|
dataset = load_dataset(**kwargs)
|
||||||
|
|
||||||
def preprocess(example):
    """Attach an 'all_labels' payload (lettered candidate strings plus the
    0-based gold index) and remap the original 1-based 'label' to A/B/C."""
    gold = int(example['label'])
    example['all_labels'] = {
        'candidates': [
            f'A. {example["answerA"]}',
            f'B. {example["answerB"]}',
            f'C. {example["answerC"]}',
        ],
        'label': gold - 1,
    }
    # ' ABC'[1..3] maps the 1-based label to its letter.
    example['label'] = ' ABC'[gold]
    return example
|
||||||
|
|
||||||
|
@ -208,3 +208,52 @@ class SquadEvaluator(HuggingfaceEvaluator):
|
|||||||
dict: postprocessed scores.
|
dict: postprocessed scores.
|
||||||
"""
|
"""
|
||||||
return scores['f1']
|
return scores['f1']
|
||||||
|
|
||||||
|
|
||||||
|
@ICL_EVALUATORS.register_module()
class EDAccEvaluator(AccEvaluator):
    """Edit distance based accuracy evaluator.

    This implementation requires the un-postprocessed outputs from the model,
    and the reference list where each item is structured as:

    .. code-block:: python

        {
            'candidates': [],  # a list of informative answer candidates
            'label': 0,  # the index of the gold answer
        }

    It always matches the model's output to a valid answer with the criterion
    as the minimum editing distance.
    """

    def __init__(self) -> None:
        super().__init__()
        # Lazy import: rapidfuzz is only needed when this evaluator is
        # actually instantiated, not on module import.
        from rapidfuzz.distance import Levenshtein
        self.dist = Levenshtein.distance

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.

        Each prediction is mapped to the index of its nearest candidate by
        Levenshtein distance; references are reduced to their gold indices.

        Args:
            predictions (List): List of predictions of each sample.
            references (List): List of targets for each sample.

        Returns:
            dict: preprocessed results with 'predictions' (chosen candidate
                indices) and 'references' (gold indices).
        """
        preds = []
        golds = []

        # Idiomatic pairwise iteration instead of indexing by range(len(...)).
        for pred, ref in zip(predictions, references):
            dists = [self.dist(pred, cand) for cand in ref['candidates']]
            preds.append(np.argmin(dists))
            golds.append(ref['label'])

        return {
            'predictions': preds,
            'references': golds,
        }
|
||||||
|
Loading…
Reference in New Issue
Block a user