OpenCompass/opencompass/datasets/healthbench/common.py
2025-05-15 08:50:05 +00:00

381 lines
11 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing.pool import ThreadPool
from typing import Any, Callable
import jinja2
import numpy as np
import requests
from tqdm import tqdm
from .types import EvalResult, Message, SamplerBase, SingleEvalResult
# Prompt for a four-option multiple-choice question. It is filled in with
# str.format (see format_multichoice_question), so the mapping supplied must
# provide the keys Question, A, B, C and D.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
# Case-insensitive regex for an "Answer: X" marker; group 1 captures the option
# letter (A-D), which may be wrapped in literal "$" characters.
ANSWER_PATTERN_MULTICHOICE = r'(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?'
# Case-insensitive regex for a free-form "Answer: ..." marker; group 1 captures
# the rest of that line.
ANSWER_PATTERN = r'(?i)Answer\s*:\s*([^\n]+)'
# Regex template for multilingual answer extraction. "{}" is replaced (via
# str.format) with one localized "Answer:" marker; group 1 then captures the
# option letter, which may be:
#   - Latin A-D (case-insensitive),
#   - a character in the Arabic range U+0623..U+062F (أ .. د),
#   - one of the Bengali letters U+0985 (অ), U+09AC (ব), U+09A1 (ড), U+09A2 (ঢ),
#   - one of the full-width Latin letters U+FF21..U+FF24 (Ａ Ｂ Ｃ Ｄ).
# The last four alternatives had degraded to empty brackets ("[]|[]"), which
# the regex engine parses as character sets matching "]", "|" and "[" rather
# than an option letter.
# NOTE(review): the full-width letters were restored to match the "Japanese"
# A-D handling in normalize_extracted_answer -- confirm against upstream.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    '(?i){}[ \t]*([A-D]|[\u0623-\u062f]|[\u0985]|[\u09ac]|[\u09a1]|[\u09a2]'
    '|[\uff21]|[\uff22]|[\uff23]|[\uff24])')
# All the different ways "Answer" is written in different languages. Each
# entry is a regex fragment for a localized "Answer:" marker, intended to be
# substituted into MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.
#
# Raw strings are used so that "\s" is a regex escape rather than an invalid
# string escape (the resulting string values are unchanged).
#
# NOTE(review): several entries had lost characters and degenerated into
# patterns such as '\s*' (matches the empty string) or '\s*:' (matches a bare
# colon), and entries that sit next to an ASCII-colon twin had lost their
# full-width colon. The leading 답 / 答 and the full-width "：" below were
# restored by analogy with the neighbouring entries -- confirm against
# upstream.
MULTILINGUAL_ANSWER_REGEXES = [
    r'Answer\s*:',
    # NOTE(review): this entry is meant to carry invisible characters after
    # the colon; as written it is indistinguishable from the entry above.
    r'Answer\s*:',  # Korean invisible character
    r'উত্তর\s*:',
    r'उत्तर\s*:',
    r'উত্তরঃ',
    r'উত্তর\s*:',
    r'Antwort\s*:',
    r'답변\s*:',
    r'정답\s*:',
    r'답\s*:',
    r'答案\s*：',
    r'答案\s*:',
    r'答\s*：',
    r'答\s*:',
    r'答复\s*：',
    r'答曰\s*：',
    r'الإجابة:',
    r'الجواب:',
    r'إجابة:',
    r'الإجابة النهائية:',
    r'الإجابة الصحيحة:',
    r'الإجابة الصحيحة هي:',
    r'الإجابة هي:',
    r'الجواب النهائي:',
    r'Respuesta\s*:',
    r'Risposta\s*:',
    r'答え\s*:',
    r'答え\s*：',
    r'回答\s*:',
    r'回答\s*：',
    r'解答\s*:',
    r'Jawaban\s*:',
    r'Réponse\s*:',
    r'Resposta\s*:',
    r'Jibu\s*:',
    r'Idahun\s*:',
    r'Ìdáhùn\s*:',
    r'Idáhùn\s*:',
    r'Àmọ̀nà\s*:',
    r'Àdáhùn\s*:',
    r'Ànúgọ\s*:',
    r'Àṣàyàn\s*:',
]
# Few-shot grader prompt asking whether two math answers are equivalent. It is
# filled in with the "%" operator (see check_equality) and needs the mapping
# keys "expression1" and "expression2". The grader is told to reply with only
# "Yes" or "No".
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Jinja template for one example's HTML: the prompt conversation, the sampled
# message, and the grading results. It references the template variables
# prompt_messages, next_message, correct_answer, extracted_answer and score,
# and calls message_to_html, which must be reachable from the template.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
def format_multichoice_question(row):
    """Build the multiple-choice prompt for one question.

    ``row`` is unpacked as keyword arguments into ``QUERY_TEMPLATE_MULTICHOICE``,
    so it must be a mapping that supplies every placeholder the template uses.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """Ask ``sampler`` whether two expressions are equivalent.

    The two expressions are substituted into ``EQUALITY_TEMPLATE`` and sent as
    a single user message. Returns True only when the response text, lowercased
    and stripped of surrounding whitespace, is exactly ``'yes'``.
    """
    grader_prompt = EQUALITY_TEMPLATE % dict(expression1=expr1,
                                             expression2=expr2)
    reply = sampler([{'content': grader_prompt, 'role': 'user'}])
    verdict = reply.response_text.lower().strip()
    return verdict == 'yes'
def _compute_stat(values: list, stat: str):
if stat == 'mean':
return np.mean(values)
elif stat == 'std':
return np.std(values)
elif stat == 'min':
return np.min(values)
elif stat == 'max':
return np.max(values)
elif stat == 'n_samples':
return len(values)
elif stat == 'bootstrap_std':
return np.std([
np.mean(np.random.choice(values, len(values))) for _ in range(1000)
])
else:
raise ValueError(f'Unknown {stat =}')
def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ('mean', 'std'),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """Combine per-example results into one ``EvalResult``.

    Every metric value (and every non-None score, under the name ``'score'``)
    is collected per metric name and summarised with ``_compute_stat``. The
    statistics used for a name come from ``name2stats`` when it has an entry
    for that name, otherwise from ``default_stats``. The ``'mean'`` statistic
    is stored under the bare metric name; any other statistic is stored under
    ``'<name>:<stat>'``.

    The aggregated ``'score'`` entry (if any) is removed from the metrics and
    returned as the result's ``score``. Per-example html, convo and
    example-level metadata are passed through in input order.
    """
    stats_by_name = name2stats or {}
    samples_by_name = defaultdict(list)
    example_htmls, example_convos, example_metadata = [], [], []

    for result in single_eval_results:
        for metric_name, metric_value in result.metrics.items():
            samples_by_name[metric_name].append(metric_value)
        if result.score is not None:
            samples_by_name['score'].append(result.score)
        example_htmls.append(result.html)
        example_convos.append(result.convo)
        example_metadata.append(result.example_level_metadata)

    summary = {}
    for metric_name, samples in samples_by_name.items():
        for stat in stats_by_name.get(metric_name, default_stats):
            if stat == 'mean':
                label = metric_name
            else:
                label = f'{metric_name}:{stat}'
            summary[label] = _compute_stat(samples, stat)

    # Pull the headline score out so it is not duplicated among the metrics.
    overall_score = summary.pop('score', None)
    return EvalResult(
        score=overall_score,
        metrics=summary,
        htmls=example_htmls,
        convos=example_convos,
        metadata={'example_level_metadata': example_metadata},
    )
def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """Apply ``f`` to each element of ``xs`` and return the results in order.

    Args:
        f: Function applied to every element.
        xs: Input items; must support ``len()``.
        num_threads: Upper bound on worker threads. The pool never uses more
            threads than there are items.
        pbar: When true, wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        A list with ``f(x)`` for every ``x`` in ``xs``, in input order. An
        empty input yields an empty list.

    When the ``debug`` environment variable is set to a non-empty value, the
    work runs sequentially in the calling thread instead of a thread pool.
    """
    total = len(xs)
    if total == 0:
        # ThreadPool(0) raises ValueError, so short-circuit: nothing to map.
        return []
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x
    if os.getenv('debug'):
        return list(map(f, pbar_fn(xs, total=total)))
    with ThreadPool(min(num_threads, total)) as pool:
        return list(pbar_fn(pool.imap(f, xs), total=total))
# Jinja environment shared by every template rendered in this module.
# Templates are supplied as strings (via from_string), so a bare BaseLoader is
# used. StrictUndefined turns use of an undefined template variable into an
# error instead of rendering it as empty text.
jinja_env = jinja2.Environment(
loader=jinja2.BaseLoader(),
undefined=jinja2.StrictUndefined,
autoescape=jinja2.select_autoescape(['html', 'xml']),
)
# HTML snippet for a single conversation message. Uses the template variables
# role, content and variant; the variant label is only emitted when variant is
# truthy.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
def message_to_html(message: Message) -> str:
    """Render one message as an HTML snippet (wrapped in a <div>).

    ``message`` must provide ``'role'`` and ``'content'``; ``'variant'`` is
    optional and is passed to the template as ``None`` when absent.
    """
    template = jinja_env.from_string(_message_template)
    fields = {
        'role': message['role'],
        'content': message['content'],
        'variant': message.get('variant'),
    }
    return template.render(**fields)
# Register message_to_html as a template global so templates rendered through
# jinja_env can call it directly (HTML_JINJA does).
jinja_env.globals['message_to_html'] = message_to_html
# Jinja template for the standalone HTML report page. It uses the template
# variables score, metrics and htmls: the metrics table (including the score
# row) is only rendered when metrics is truthy, and each entry of htmls is
# inserted as-is (marked safe) followed by a horizontal rule.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
def make_report(eval_result: EvalResult) -> str:
    """Render ``eval_result`` as a standalone HTML report.

    The result's score, metrics and per-example htmls are passed to
    ``_report_template``.
    """
    report = jinja_env.from_string(_report_template)
    context = {
        'score': eval_result.score,
        'metrics': eval_result.metrics,
        'htmls': eval_result.htmls,
    }
    return report.render(**context)
def make_report_from_example_htmls(htmls: list[str]):
    """Render a standalone HTML report that contains only example htmls.

    No score or metrics are supplied (``score=None``, ``metrics={}``), so the
    template's metrics table is skipped.
    """
    report = jinja_env.from_string(_report_template)
    context = {'score': None, 'metrics': {}, 'htmls': htmls}
    return report.render(**context)
def normalize_response(response: str) -> str:
    """Strip markdown and LaTeX markup that could prevent an answer match.

    Each token below is deleted from ``response`` in turn. The order is
    significant: multi-character wrappers are removed before the bare ``$``
    and ``{`` characters they contain.
    """
    markup_tokens = (
        '**',
        '$\\boxed{',
        '}$',
        '\\$',
        '$\\text{',
        '$',
        '\\mathrm{',
        '\\{',
        '\\text',
        '\\(',
        '\\mathbf{',
        '{',
        '\\boxed',
    )
    cleaned = response
    for token in markup_tokens:
        cleaned = cleaned.replace(token, '')
    return cleaned
# Translation table from localized option letters to " A" .. " D". The leading
# space keeps the Latin letter separated from neighbouring text; the caller's
# final strip() removes it when the letter stands alone.
_LOCALIZED_OPTION_LETTERS = str.maketrans({
    # In arabic these are the letters used for A-D in multiple choice questions
    '\u0623': ' A',  # أ
    '\u0628': ' B',  # ب
    '\u062c': ' C',  # ج
    '\u062f': ' D',  # د
    # In Bengali these are the letters used for A-D in multiple choice questions
    '\u0985': ' A',  # অ
    '\u09ac': ' B',  # ব
    '\u09a1': ' C',  # ড
    '\u09a2': ' D',  # ঢ
    # In Japanese these are the letters sometimes used for A-D in multiple
    # choice questions (full-width Latin).
    # NOTE(review): restored as full-width Ａ-Ｄ -- confirm against upstream.
    '\uff21': ' A',  # Ａ
    '\uff22': ' B',  # Ｂ
    '\uff23': ' C',  # Ｃ
    '\uff24': ' D',  # Ｄ
})


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map localized option letters to Latin A-D and trim whitespace.

    Arabic, Bengali and full-width Latin option letters are each replaced by a
    space followed by the corresponding Latin letter, then the result is
    stripped. Text without any of those letters is returned unchanged apart
    from the strip.

    The Bengali and full-width replacements previously searched for the empty
    string, which inserts the replacement between every character and
    corrupts every answer (including plain "A"-"D").
    """
    return extracted_answer.translate(_LOCALIZED_OPTION_LETTERS).strip()
def url_to_fileobj(
    url: str,
    binary: bool = False,
    timeout: float | tuple[float, float] | None = 60.0,
) -> Any:
    """Download ``url`` and return its body as an in-memory file object.

    Args:
        url: Address to fetch with an HTTP GET.
        binary: When true, return ``io.BytesIO`` over the raw bytes; otherwise
            return ``io.StringIO`` over the decoded text.
        timeout: Passed to ``requests.get``. Without a timeout the request can
            block indefinitely on an unresponsive server. Pass ``None`` to
            wait without limit.

    Raises:
        requests.HTTPError: if the response status indicates an error
            (via ``raise_for_status``).
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    if binary:
        return io.BytesIO(response.content)
    return io.StringIO(response.text)
def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """Return True when every message's role is 'user' or 'assistant'.

    An empty list yields True.
    """
    allowed_roles = ('user', 'assistant')
    for msg in messages:
        if msg['role'] not in allowed_roles:
            return False
    return True