OpenCompass/opencompass/datasets/healthbench/drop_eval.py
2025-05-15 08:50:05 +00:00

332 lines
12 KiB
Python

"""
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner
https://arxiv.org/abs/1903.00161
"""
import gzip
import json
import random
import re
import string
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
from scipy.optimize import linear_sum_assignment
from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
"""
From here through _normalize_answer was originally copied from:
https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
Then cleaned up and modified a bit.
The rest was originally copied from https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc
/eval/drop_eval.py
"""
def _remove_articles(text: str) -> str:
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def _white_space_fix(text: str) -> str:
return ' '.join(text.split())
EXCLUDE = set(string.punctuation)
def _remove_punc(text: str) -> str:
if not _is_number(text):
return ''.join(ch for ch in text if ch not in EXCLUDE)
else:
return text
def _lower(text: str) -> str:
return text.lower()
def _tokenize(text: str) -> List[str]:
return re.split(' |-', text)
def _normalize_answer(text: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""
parts = [
_white_space_fix(
_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
for token in _tokenize(text)
]
parts = [part for part in parts if part.strip()]
normalized = ' '.join(parts).strip()
return normalized
def _is_number(text: str) -> bool:
try:
float(text)
return True
except ValueError:
return False
def _normalize_number(text: str) -> str:
if _is_number(text):
return str(float(text))
else:
return text
def _answer_to_bags(
answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans: List[str] = []
token_bags = []
for raw_span in raw_spans:
normalized_span = _normalize_answer(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(predicted: List[Set[str]],
gold: List[Set[str]]) -> List[float]:
"""Takes gold and predicted answer sets and first finds the optimal 1-1
alignment between them and gets maximum metric values over all the
answers."""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if _match_numbers_if_present(gold_item, pred_item):
scores[gold_index,
pred_index] = _compute_f1(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores
def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
intersection = len(gold_bag.intersection(predicted_bag))
if not predicted_bag:
precision = 1.0
else:
precision = intersection / float(len(predicted_bag))
if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = ((2 * precision * recall) / (precision + recall)
if not (precision == 0.0 and recall == 0.0) else 0.0) * 100
return f1
def _match_numbers_if_present(gold_bag: Set[str],
predicted_bag: Set[str]) -> bool:
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
if _is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if _is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True
return False
def get_drop_metrics(
predicted: Union[str, List[str], Tuple[str, ...]],
gold: Union[str, List[str], Tuple[str, ...]]) -> Tuple[float, float]:
"""Takes a predicted answer and a gold answer (that are both either a
string or a list of strings), and returns exact match and the DROP F1
metric for the prediction.
If you are
writing a script for evaluating objects in memory (say, the output of predictions during
validation, or while training), this is the function you want to call, after using
:func:`answer_json_to_strings` when reading the gold answer from the released data file.
"""
predicted_bags = _answer_to_bags(predicted)
gold_bags = _answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]) == len(gold_bags[0]):
exact_match = 1.0
else:
exact_match = 0.0
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
f1 = np.mean(f1_per_bag)
f1 = round(f1, 2)
return exact_match, f1
def answer_json_to_strings(
answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
"""Takes an answer JSON blob from the DROP data release and converts it
into strings used for evaluation."""
if 'number' in answer and answer['number']:
return tuple([str(answer['number'])]), 'number'
elif 'spans' in answer and answer['spans']:
return tuple(
answer['spans']), 'span' if len(answer['spans']) == 1 else 'spans'
elif 'date' in answer:
return (
tuple([
'{0} {1} {2}'.format(answer['date']['day'],
answer['date']['month'],
answer['date']['year']).strip()
]),
'date',
)
else:
raise ValueError(
f'Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}'
)
def answer_json_to_string(answer_json):
return json.dumps(answer_json_to_strings(answer_json))
def normalize(s: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""
s = s.lower()
exclude = set(string.punctuation)
s = ''.join(char for char in s if char not in exclude)
s = re.sub(r'\b(a|an|the)\b', ' ', s)
s = ' '.join(s.split())
return s
def fuzzy_match(s1: str, s2: str) -> bool:
s1 = normalize(s1)
s2 = normalize(s2)
if s1 == '' or s2 == '':
return s1 == s2
return s1 in s2 or s2 in s1
def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
em_scores = []
f1_scores = []
for answer in reference:
if answer.strip() != '':
em, f1 = get_drop_metrics(sample, answer)
em_scores.append(em)
f1_scores.append(f1)
return (max(em_scores), max(f1_scores))
class DropEval(Eval):
def __init__(self,
num_examples: int | None = None,
train_samples_per_prompt: int = 3):
self.seed = 42
self._num_examples = num_examples
self._train_samples_per_prompt = train_samples_per_prompt
self.train_jsonl = (
'https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz'
)
self.test_jsonl = (
'https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz'
)
with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl,
binary=True),
mode='rb') as f:
self.train_samples = list(map(json.loads, f.readlines()))
with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl,
binary=True),
mode='rb') as f:
self.test_samples = list(map(json.loads, f.readlines()))
if self._num_examples:
self.test_samples = random.Random(self.seed).sample(
self.test_samples, self._num_examples)
def __call__(self, sampler: SamplerBase) -> EvalResult:
rng = random.Random(self.seed)
def fn(example: dict[str, str]):
stuffing = rng.sample(self.train_samples,
self._train_samples_per_prompt)
# prompt = """TASK: Read the provided passage, then identify the correct answer to questions below."""
prompt = """You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below."""
prompt += '\n\n# Examples'
samples = stuffing + [example]
for i, sample in enumerate(samples):
is_test = i == len(stuffing)
prompt += '\n# Your Task\n' if is_test else ''
prompt += f"""
---
{sample["context"]} """
a = sample['completion']
correct_answers = sample['ref_text'].split('|')
if not is_test:
prompt += a + '\n'
else:
prompt += """\n
Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.
"""
prompt_messages = [
sampler._pack_message(content=prompt, role='user')
]
sampler_response = sampler(prompt_messages)
response_text = sampler_response.response_text
actual_queried_prompt_messages = sampler_response.actual_queried_message_list
match = re.search(ANSWER_PATTERN, response_text)
extracted_answer = match.group(
1) if match else response_text
em_score, f1_score = drop_metric(extracted_answer,
correct_answers)
matches = [
fuzzy_match(extracted_answer, correct_answer)
for correct_answer in correct_answers
]
extracted_answers = [
extracted_answer for i in range(len(correct_answers))
if matches[i]
]
score = True in matches
html = common.jinja_env.from_string(HTML_JINJA).render(
prompt_messages=actual_queried_prompt_messages,
next_message=dict(content=extracted_answer,
role='assistant'),
score=score,
correct_answer=correct_answers,
extracted_answer=extracted_answers,
)
convo = actual_queried_prompt_messages + [
dict(content=extracted_answer, role='assistant')
]
return SingleEvalResult(
html=html,
score=score,
convo=convo,
metrics={
'em_score': em_score,
'f1_score': f1_score
},
)
results = common.map_with_progress(fn, self.test_samples)
return common.aggregate_results(results)