Mirror of https://github.com/open-compass/opencompass.git
[Feature] Use local metrics from the HF implementations (#416)
* use local metric scripts copied from the HF implementations
* add a fallback to loading from HF when no local script exists
parent 94755f8e2f
commit ae0cd8752f
@@ -2,7 +2,8 @@ exclude: |
     (?x)^(
         tests/data/|
         opencompass/models/internal/|
-        opencompass/utils/internal/
+        opencompass/utils/internal/|
+        opencompass/openicl/icl_evaluator/hf_metrics/
     )
 repos:
   - repo: https://gitee.com/openmmlab/mirrors-flake8
@@ -2,7 +2,8 @@ exclude: |
     (?x)^(
         tests/data/|
         opencompass/models/internal/|
-        opencompass/utils/internal/
+        opencompass/utils/internal/|
+        opencompass/openicl/icl_evaluator/hf_metrics/
     )
 repos:
   - repo: https://github.com/PyCQA/flake8
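Both hunks add the same exclusion so that the vendored metric scripts below are skipped by the lint hooks. As a quick sanity check (a sketch, not part of the commit), the verbose-mode exclude pattern can be exercised directly:

import re

# Sketch only: mirrors the exclude pattern above to show which paths the hooks skip.
exclude = re.compile(r"""(?x)^(
    tests/data/|
    opencompass/models/internal/|
    opencompass/utils/internal/|
    opencompass/openicl/icl_evaluator/hf_metrics/
)""")

assert exclude.match("opencompass/openicl/icl_evaluator/hf_metrics/accuracy.py")
assert exclude.match("tests/data/example.json")
assert not exclude.match("opencompass/openicl/icl_evaluator/icl_hf_evaluator.py")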
opencompass/openicl/icl_evaluator/hf_metrics/accuracy.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric."""

import datasets
from sklearn.metrics import accuracy_score

import evaluate


_DESCRIPTION = """
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights. Defaults to None.

Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input if `normalize` is set to `False`. A higher score means higher accuracy.

Examples:

    Example 1-A simple example
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
        >>> print(results)
        {'accuracy': 0.5}

    Example 2-The same as Example 1, except with `normalize` set to `False`.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
        >>> print(results)
        {'accuracy': 3.0}

    Example 3-The same as Example 1, except with `sample_weight` set.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
        >>> print(results)
        {'accuracy': 0.8778625954198473}
"""


_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("int32"),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _compute(self, predictions, references, normalize=True, sample_weight=None):
        return {
            "accuracy": float(
                accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
            )
        }
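Since this script is now vendored in the repo, it can be loaded by path instead of by name, which avoids any Hub download. A minimal usage sketch, assuming it is run from the repository root:

import evaluate

# Sketch: load the bundled script directly instead of fetching "accuracy" from the Hub.
accuracy = evaluate.load("opencompass/openicl/icl_evaluator/hf_metrics/accuracy.py")
print(accuracy.compute(references=[0, 1, 2, 0, 1, 2],
                       predictions=[0, 1, 1, 2, 1, 0]))
# {'accuracy': 0.5}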
opencompass/openicl/icl_evaluator/hf_metrics/rouge.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ROUGE metric from Google Research github repo. """

# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
import absl  # Here to have a nice missing dependency error message early on
import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy  # Here to have a nice missing dependency error message early on
import six  # Here to have a nice missing dependency error message early on
from rouge_score import rouge_scorer, scoring

import evaluate


_CITATION = """\
@inproceedings{lin-2004-rouge,
    title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
    author = "Lin, Chin-Yew",
    booktitle = "Text Summarization Branches Out",
    month = jul,
    year = "2004",
    address = "Barcelona, Spain",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W04-1013",
    pages = "74--81",
}
"""

_DESCRIPTION = """\
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
evaluating automatic summarization and machine translation software in natural language processing.
The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.

Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.

This metric is a wrapper around the Google Research reimplementation of ROUGE:
https://github.com/google-research/google-research/tree/master/rouge
"""

_KWARGS_DESCRIPTION = """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of references for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_aggregator: Return aggregates if this is set to True
Returns:
    rouge1: rouge_1 (f1),
    rouge2: rouge_2 (f1),
    rougeL: rouge_l (f1),
    rougeLsum: rouge_lsum (f1)
Examples:

    >>> rouge = evaluate.load('rouge')
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> results = rouge.compute(predictions=predictions, references=references)
    >>> print(results)
    {'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}
"""


class Tokenizer:
    """Helper class to wrap a callable into a class with a `tokenize` method as used by rouge-score."""

    def __init__(self, tokenizer_func):
        self.tokenizer_func = tokenizer_func

    def tokenize(self, text):
        return self.tokenizer_func(text)


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Rouge(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
            reference_urls=[
                "https://en.wikipedia.org/wiki/ROUGE_(metric)",
                "https://github.com/google-research/google-research/tree/master/rouge",
            ],
        )

    def _compute(
        self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False, tokenizer=None
    ):
        if rouge_types is None:
            rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

        multi_ref = isinstance(references[0], list)

        if tokenizer is not None:
            tokenizer = Tokenizer(tokenizer)

        scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer, tokenizer=tokenizer)
        if use_aggregator:
            aggregator = scoring.BootstrapAggregator()
        else:
            scores = []

        for ref, pred in zip(references, predictions):
            if multi_ref:
                score = scorer.score_multi(ref, pred)
            else:
                score = scorer.score(ref, pred)
            if use_aggregator:
                aggregator.add_scores(score)
            else:
                scores.append(score)

        if use_aggregator:
            result = aggregator.aggregate()
            for key in result:
                result[key] = result[key].mid.fmeasure
        else:
            result = {}
            for key in scores[0]:
                result[key] = list(score[key].fmeasure for score in scores)

        return result
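The Tokenizer wrapper above lets callers pass a plain callable as the tokenizer. A hedged usage sketch, again loading the vendored script by path (the dependencies listed in the imports must be installed):

import evaluate

rouge = evaluate.load("opencompass/openicl/icl_evaluator/hf_metrics/rouge.py")
results = rouge.compute(
    predictions=["hello there", "general kenobi"],
    references=["hello there", "general kenobi"],
    tokenizer=lambda text: text.split(),  # wrapped by the Tokenizer helper above
    use_aggregator=False,                 # per-example f-measures instead of bootstrap aggregates
)
# e.g. {'rouge1': [1.0, 1.0], 'rouge2': [1.0, 1.0], 'rougeL': [1.0, 1.0], 'rougeLsum': [1.0, 1.0]}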
opencompass/openicl/icl_evaluator/hf_metrics/sacrebleu.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SACREBLEU metric. """

import datasets
import sacrebleu as scb
from packaging import version

import evaluate


_CITATION = """\
@inproceedings{post-2018-call,
    title = "A Call for Clarity in Reporting {BLEU} Scores",
    author = "Post, Matt",
    booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
    month = oct,
    year = "2018",
    address = "Belgium, Brussels",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W18-6319",
    pages = "186--191",
}
"""

_DESCRIPTION = """\
SacreBLEU provides hassle-free computation of shareable, comparable, and reproducible BLEU scores.
Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text.
It also knows all the standard test sets and handles downloading, processing, and tokenization for you.

See the [README.md] file at https://github.com/mjpost/sacreBLEU for more information.
"""

_KWARGS_DESCRIPTION = """
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions (`list` of `str`): list of translations to score. Each translation should be tokenized into a list of tokens.
    references (`list` of `list` of `str`): A list of lists of references. The contents of the first sub-list are the references for the first prediction, the contents of the second sub-list are for the second prediction, etc. Note that there must be the same number of references for each prediction (i.e. all sub-lists must be of the same length).
    smooth_method (`str`): The smoothing method to use, defaults to `'exp'`. Possible values are:
        - `'none'`: no smoothing
        - `'floor'`: increment zero counts
        - `'add-k'`: increment num/denom by k for n>1
        - `'exp'`: exponential decay
    smooth_value (`float`): The smoothing value. Only valid when `smooth_method='floor'` (in which case `smooth_value` defaults to `0.1`) or `smooth_method='add-k'` (in which case `smooth_value` defaults to `1`).
    tokenize (`str`): Tokenization method to use for BLEU. If not provided, defaults to `'zh'` for Chinese, `'ja-mecab'` for Japanese and `'13a'` (mteval) otherwise. Possible values are:
        - `'none'`: No tokenization.
        - `'zh'`: Chinese tokenization.
        - `'13a'`: mimics the `mteval-v13a` script from Moses.
        - `'intl'`: International tokenization, mimics the `mteval-v14` script from Moses
        - `'char'`: Language-agnostic character-level tokenization.
        - `'ja-mecab'`: Japanese tokenization. Uses the [MeCab tokenizer](https://pypi.org/project/mecab-python3).
    lowercase (`bool`): If `True`, lowercases the input, enabling case-insensitivity. Defaults to `False`.
    force (`bool`): If `True`, insists that your tokenized input is actually detokenized. Defaults to `False`.
    use_effective_order (`bool`): If `True`, stops including n-gram orders for which precision is 0. This should be `True`, if sentence-level BLEU will be computed. Defaults to `False`.

Returns:
    'score': BLEU score,
    'counts': Counts,
    'totals': Totals,
    'precisions': Precisions,
    'bp': Brevity penalty,
    'sys_len': predictions length,
    'ref_len': reference length,

Examples:

    Example 1:
        >>> predictions = ["hello there general kenobi", "foo bar foobar"]
        >>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]
        >>> sacrebleu = evaluate.load("sacrebleu")
        >>> results = sacrebleu.compute(predictions=predictions, references=references)
        >>> print(list(results.keys()))
        ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len']
        >>> print(round(results["score"], 1))
        100.0

    Example 2:
        >>> predictions = ["hello there general kenobi",
        ...                "on our way to ankh morpork"]
        >>> references = [["hello there general kenobi", "hello there !"],
        ...                ["goodbye ankh morpork", "ankh morpork"]]
        >>> sacrebleu = evaluate.load("sacrebleu")
        >>> results = sacrebleu.compute(predictions=predictions,
        ...                             references=references)
        >>> print(list(results.keys()))
        ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len']
        >>> print(round(results["score"], 1))
        39.8
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Sacrebleu(evaluate.Metric):
    def _info(self):
        if version.parse(scb.__version__) < version.parse("1.4.12"):
            raise ImportWarning(
                "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n"
                'You can install it with `pip install "sacrebleu>=1.4.12"`.'
            )
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/mjpost/sacreBLEU",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/mjpost/sacreBLEU"],
            reference_urls=[
                "https://github.com/mjpost/sacreBLEU",
                "https://en.wikipedia.org/wiki/BLEU",
                "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        smooth_method="exp",
        smooth_value=None,
        force=False,
        lowercase=False,
        tokenize=None,
        use_effective_order=False,
    ):
        # if only one reference is provided make sure we still use list of lists
        if isinstance(references[0], str):
            references = [[ref] for ref in references]

        references_per_prediction = len(references[0])
        if any(len(refs) != references_per_prediction for refs in references):
            raise ValueError("Sacrebleu requires the same number of references for each prediction")
        transformed_references = [[refs[i] for refs in references] for i in range(references_per_prediction)]
        output = scb.corpus_bleu(
            predictions,
            transformed_references,
            smooth_method=smooth_method,
            smooth_value=smooth_value,
            force=force,
            lowercase=lowercase,
            use_effective_order=use_effective_order,
            **(dict(tokenize=tokenize) if tokenize else {}),
        )
        output_dict = {
            "score": output.score,
            "counts": output.counts,
            "totals": output.totals,
            "precisions": output.precisions,
            "bp": output.bp,
            "sys_len": output.sys_len,
            "ref_len": output.ref_len,
        }
        return output_dict
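One step in _compute above that is easy to misread is the reference reshaping: references arrive as one list per prediction, while sacrebleu.corpus_bleu expects one stream per reference slot. A small sketch of what that comprehension produces:

# Sketch of the transposition performed before calling scb.corpus_bleu.
references = [["hello there general kenobi", "hello there !"],  # refs for prediction 1
              ["goodbye ankh morpork", "ankh morpork"]]          # refs for prediction 2
references_per_prediction = len(references[0])
transformed = [[refs[i] for refs in references]
               for i in range(references_per_prediction)]
# transformed == [['hello there general kenobi', 'goodbye ankh morpork'],
#                 ['hello there !', 'ankh morpork']]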
opencompass/openicl/icl_evaluator/hf_metrics/squad.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD metric. """

import datasets

import evaluate

from .compute_score import compute_score


_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
    title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
    author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
    booktitle={EMNLP},
    year={2016}
}
"""

_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
"""

_KWARGS_DESCRIPTION = """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
            }
            Note that answer_start values are not taken into account to compute the metric.
Returns:
    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
    'f1': The F-score of predicted tokens versus the gold answer
Examples:

    >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
    >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
    >>> squad_metric = evaluate.load("squad")
    >>> results = squad_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 100.0, 'f1': 100.0}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references):
        pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
        dataset = [
            {
                "paragraphs": [
                    {
                        "qas": [
                            {
                                "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
                                "id": ref["id"],
                            }
                            for ref in references
                        ]
                    }
                ]
            }
        ]
        score = compute_score(dataset=dataset, predictions=pred_dict)
        return score
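For reference, _compute rebuilds the minimal SQuAD v1 dataset structure that the official compute_score script expects. A short walk-through using the docstring's example inputs:

predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
references = [{"answers": {"answer_start": [97], "text": ["1976"]},
               "id": "56e10a3be3433e1400422b22"}]

# Same reshaping as Squad._compute above: id -> answer text, plus a one-article dataset.
pred_dict = {p["id"]: p["prediction_text"] for p in predictions}
dataset = [{
    "paragraphs": [{
        "qas": [{"answers": [{"text": t} for t in ref["answers"]["text"]],
                 "id": ref["id"]}
                for ref in references],
    }],
}]
# compute_score(dataset=dataset, predictions=pred_dict) -> {'exact_match': 100.0, 'f1': 100.0}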
@@ -1,3 +1,4 @@
+import os
 import random
 from typing import List
 
@@ -72,7 +73,13 @@ class HuggingfaceEvaluator(BaseEvaluator):
                 f'length. len(predictions): {len(predictions)}, '
                 f'len(references): {len(references)}'
             }
-        metric = evaluate.load(self.metric)
+        # use metric scripts pre-downloaded into the opencompass repo to avoid downloading
+        local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                  'hf_metrics', self.metric + '.py')
+        if os.path.exists(local_path):
+            metric = evaluate.load(local_path)
+        else:
+            metric = evaluate.load(self.metric)
         scores = metric.compute(**self._preprocess(predictions, references))
         result = self._postprocess(scores)
         random.setstate(random_state)
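The hunk above is the core of the change: prefer the vendored metric script, fall back to the Hub. A standalone sketch of the same logic, with a hypothetical helper name that is not part of the commit:

import os
import evaluate

def load_metric_with_local_fallback(metric_name):
    """Hypothetical helper mirroring the logic added to HuggingfaceEvaluator."""
    local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'hf_metrics', metric_name + '.py')
    if os.path.exists(local_path):
        return evaluate.load(local_path)   # offline: use the pre-downloaded script
    return evaluate.load(metric_name)      # otherwise fetch from the HuggingFace Hub

# metric = load_metric_with_local_fallback('accuracy')
# metric.compute(predictions=[0, 1, 1], references=[0, 1, 2])  # -> {'accuracy': 0.666...}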