"""
|
|
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
|
|
Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner
|
|
https://arxiv.org/abs/1903.00161
|
|
"""
|
|
|
|
import gzip
import json
import random
import re
import string
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

"""
|
|
From here through _normalize_answer was originally copied from:
|
|
https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
|
|
Then cleaned up and modified a bit.
|
|
|
|
The rest was originally copied from https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc
|
|
/eval/drop_eval.py
|
|
"""
|
|
|
|
|
|
def _remove_articles(text: str) -> str:
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)


def _white_space_fix(text: str) -> str:
    return ' '.join(text.split())


EXCLUDE = set(string.punctuation)


def _remove_punc(text: str) -> str:
    if not _is_number(text):
        return ''.join(ch for ch in text if ch not in EXCLUDE)
    else:
        return text


def _lower(text: str) -> str:
    return text.lower()


def _tokenize(text: str) -> List[str]:
    return re.split(' |-', text)


def _normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    parts = [
        _white_space_fix(
            _remove_articles(_normalize_number(_remove_punc(_lower(token)))))
        for token in _tokenize(text)
    ]
    parts = [part for part in parts if part.strip()]
    normalized = ' '.join(parts).strip()
    return normalized

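
# Illustrative check of the normalization above (example added here, not from
# the upstream eval): articles are dropped, numbers are canonicalized via
# float(), and tokens are split on spaces and hyphens, so for instance
#   _normalize_answer('The 7-yard touchdown') == '7.0 yard touchdown'
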

def _is_number(text: str) -> bool:
    try:
        float(text)
        return True
    except ValueError:
        return False


def _normalize_number(text: str) -> str:
    if _is_number(text):
        return str(float(text))
    else:
        return text


def _answer_to_bags(
    answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
    if isinstance(answer, (list, tuple)):
        raw_spans = answer
    else:
        raw_spans = [answer]
    normalized_spans: List[str] = []
    token_bags = []
    for raw_span in raw_spans:
        normalized_span = _normalize_answer(raw_span)
        normalized_spans.append(normalized_span)
        token_bags.append(set(normalized_span.split()))
    return normalized_spans, token_bags

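
# Illustrative example (added here, not from the upstream code): a multi-span
# answer becomes a list of normalized strings plus per-span token bags, e.g.
#   _answer_to_bags(['Tom Brady', '7 yards'])
#   == (['tom brady', '7.0 yards'], [{'tom', 'brady'}, {'7.0', 'yards'}])
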

def _align_bags(predicted: List[Set[str]],
                gold: List[Set[str]]) -> List[float]:
    """Takes gold and predicted answer sets, finds the optimal 1-1 alignment
    between them, and returns the per-answer metric values under that
    alignment."""
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index,
                       pred_index] = _compute_f1(pred_item, gold_item)
    # linear_sum_assignment minimizes total cost, so negate the scores to get
    # the assignment that maximizes total F1.
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores

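
# Worked example (illustrative, added here): with gold bags
# [{'7.0'}, {'brady'}] and a single predicted bag [{'brady'}], the assignment
# pairs the prediction with the matching gold bag, giving per-answer scores
# [0.0, 100.0]; the unmatched gold span contributes 0, so the mean F1 reported
# by get_drop_metrics would be 50.0.
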

def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = ((2 * precision * recall) / (precision + recall)
          if not (precision == 0.0 and recall == 0.0) else 0.0) * 100
    return f1

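
# Illustrative example (added here): _compute_f1({'brady'}, {'tom', 'brady'})
# has precision 1.0 and recall 0.5, so it returns roughly 66.67.
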

def _match_numbers_if_present(gold_bag: Set[str],
                              predicted_bag: Set[str]) -> bool:
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if _is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if _is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False

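
# Illustrative example (added here): if the gold bag contains '7.0' but the
# predicted bag only contains '8.0', the number sets do not intersect, the
# bags are not allowed to match, and their score stays 0 in _align_bags.
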

def get_drop_metrics(
        predicted: Union[str, List[str], Tuple[str, ...]],
        gold: Union[str, List[str], Tuple[str, ...]]) -> Tuple[float, float]:
    """Takes a predicted answer and a gold answer (each either a string or a
    list of strings) and returns exact match and the DROP F1 metric for the
    prediction.

    If you are writing a script for evaluating objects in memory (say, the
    output of predictions during validation, or while training), this is the
    function you want to call, after using :func:`answer_json_to_strings`
    when reading the gold answer from the released data file.
    """
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)

    if set(predicted_bags[0]) == set(gold_bags[0]) and len(
            predicted_bags[0]) == len(gold_bags[0]):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = np.mean(f1_per_bag)
    f1 = round(f1, 2)
    return exact_match, f1

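
# Example usage (illustrative, added here): answers are compared after
# normalization, so numeric spans match across formats:
#   get_drop_metrics('7 yards', '7.0 yards')  -> (1.0, 100.0)
#   get_drop_metrics('Brady', ['Tom Brady'])  -> (0.0, 66.67)
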

def answer_json_to_strings(
        answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
    """Takes an answer JSON blob from the DROP data release and converts it
    into strings used for evaluation."""
    if 'number' in answer and answer['number']:
        return tuple([str(answer['number'])]), 'number'
    elif 'spans' in answer and answer['spans']:
        return tuple(
            answer['spans']), 'span' if len(answer['spans']) == 1 else 'spans'
    elif 'date' in answer:
        return (
            tuple([
                '{0} {1} {2}'.format(answer['date']['day'],
                                     answer['date']['month'],
                                     answer['date']['year']).strip()
            ]),
            'date',
        )
    else:
        raise ValueError(
            f'Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}'
        )


def answer_json_to_string(answer_json):
    return json.dumps(answer_json_to_strings(answer_json))

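
# Illustrative examples (hypothetical answer blobs, not from the dataset):
#   answer_json_to_strings({'number': '3', 'spans': [], 'date': {}})
#       -> (('3',), 'number')
#   answer_json_to_strings({'number': '', 'spans': ['Tom Brady'], 'date': {}})
#       -> (('Tom Brady',), 'span')
#   answer_json_to_strings({'number': '', 'spans': [],
#                           'date': {'day': '7', 'month': 'May',
#                                    'year': '1863'}})
#       -> (('7 May 1863',), 'date')
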

def normalize(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    s = s.lower()
    exclude = set(string.punctuation)
    s = ''.join(char for char in s if char not in exclude)
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    s = ' '.join(s.split())
    return s


def fuzzy_match(s1: str, s2: str) -> bool:
    s1 = normalize(s1)
    s2 = normalize(s2)

    if s1 == '' or s2 == '':
        return s1 == s2

    return s1 in s2 or s2 in s1

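
# Illustrative example (added here): fuzzy_match is a symmetric substring
# check on normalized strings, so fuzzy_match('The Eagles', 'eagles') is True,
# while fuzzy_match('eagles', '') is False because only one side is empty.
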

def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
    em_scores = []
    f1_scores = []
    for answer in reference:
        if answer.strip() != '':
            em, f1 = get_drop_metrics(sample, answer)
            em_scores.append(em)
            f1_scores.append(f1)
    return (max(em_scores), max(f1_scores))

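
# Illustrative example (added here): with several reference answers, the best
# score wins, e.g. drop_metric('Brady', ['Tom Brady', 'Brady']) -> (1.0, 100.0).
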

class DropEval(Eval):

    def __init__(self,
                 num_examples: int | None = None,
                 train_samples_per_prompt: int = 3):
        self.seed = 42
        self._num_examples = num_examples
        self._train_samples_per_prompt = train_samples_per_prompt
        self.train_jsonl = (
            'https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz'
        )
        self.test_jsonl = (
            'https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz'
        )
        # Download and decompress the gzipped JSONL train/dev splits.
        with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl,
                                                         binary=True),
                           mode='rb') as f:
            self.train_samples = list(map(json.loads, f.readlines()))
        with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl,
                                                         binary=True),
                           mode='rb') as f:
            self.test_samples = list(map(json.loads, f.readlines()))
        # Optionally subsample the test set with a fixed seed so runs are
        # reproducible.
        if self._num_examples:
            self.test_samples = random.Random(self.seed).sample(
                self.test_samples, self._num_examples)

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        rng = random.Random(self.seed)

        def fn(example: dict[str, str]):
            # Build a few-shot prompt: sampled train examples followed by the
            # test example.
            stuffing = rng.sample(self.train_samples,
                                  self._train_samples_per_prompt)

            # prompt = """TASK: Read the provided passage, then identify the correct answer to questions below."""
            prompt = """You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below."""
            prompt += '\n\n# Examples'
            samples = stuffing + [example]
            for i, sample in enumerate(samples):
                is_test = i == len(stuffing)
                prompt += '\n# Your Task\n' if is_test else ''
                prompt += f"""
---
{sample["context"]} """

                a = sample['completion']
                correct_answers = sample['ref_text'].split('|')

                if not is_test:
                    prompt += a + '\n'
                else:
                    prompt += """\n
Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.
"""
                    prompt_messages = [
                        sampler._pack_message(content=prompt, role='user')
                    ]
                    sampler_response = sampler(prompt_messages)
                    response_text = sampler_response.response_text
                    actual_queried_prompt_messages = sampler_response.actual_queried_message_list
                    # Extract the final "Answer: ..." line; fall back to the
                    # full response if the pattern is missing.
                    match = re.search(ANSWER_PATTERN, response_text)
                    extracted_answer = match.group(
                        1) if match else response_text
                    em_score, f1_score = drop_metric(extracted_answer,
                                                     correct_answers)
                    matches = [
                        fuzzy_match(extracted_answer, correct_answer)
                        for correct_answer in correct_answers
                    ]
                    extracted_answers = [
                        extracted_answer for i in range(len(correct_answers))
                        if matches[i]
                    ]
                    score = True in matches
                    html = common.jinja_env.from_string(HTML_JINJA).render(
                        prompt_messages=actual_queried_prompt_messages,
                        next_message=dict(content=extracted_answer,
                                          role='assistant'),
                        score=score,
                        correct_answer=correct_answers,
                        extracted_answer=extracted_answers,
                    )
                    convo = actual_queried_prompt_messages + [
                        dict(content=extracted_answer, role='assistant')
                    ]
                    return SingleEvalResult(
                        html=html,
                        score=score,
                        convo=convo,
                        metrics={
                            'em_score': em_score,
                            'f1_score': f1_score
                        },
                    )

        results = common.map_with_progress(fn, self.test_samples)
        return common.aggregate_results(results)
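

# Minimal usage sketch (illustrative, not part of this module): assuming a
# concrete SamplerBase implementation named `my_sampler`, the eval can be run
# over a small random subset and the aggregated result inspected:
#
#   drop_eval = DropEval(num_examples=10, train_samples_per_prompt=3)
#   result = drop_eval(my_sampler)   # EvalResult with aggregated EM/F1 metrics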