# Mirror of https://github.com/open-compass/opencompass.git
# (synced 2025-05-30 16:03:24 +08:00; 381 lines, 11 KiB, Python)
import io
|
||
import os
|
||
from collections import defaultdict
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from multiprocessing.pool import ThreadPool
|
||
from typing import Any, Callable
|
||
|
||
import jinja2
|
||
import numpy as np
|
||
import requests
|
||
from tqdm import tqdm
|
||
|
||
from .types import EvalResult, Message, SamplerBase, SingleEvalResult
|
||
|
||
QUERY_TEMPLATE_MULTICHOICE = """
|
||
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
|
||
|
||
{Question}
|
||
|
||
A) {A}
|
||
B) {B}
|
||
C) {C}
|
||
D) {D}
|
||
""".strip()
|
||
|
||
# Case-insensitive pattern for "Answer: X" where X is one of A-D, optionally
# wrapped in "$" (e.g. "Answer: $C$"). Group 1 is the letter.
ANSWER_PATTERN_MULTICHOICE = r'(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?'
# Case-insensitive pattern for a free-form "Answer: ..." line. Group 1 is the
# rest of that line.
ANSWER_PATTERN = r'(?i)Answer\s*:\s*([^\n]+)'
# Template for the multilingual variant: ``{}`` is filled (via ``str.format``)
# with a localized "Answer:" regex. Group 1 is the option letter, which may be
# Latin A-D, an Arabic letter in the range alef-with-hamza..dal, one of four
# Bengali letters, or a fullwidth Latin letter (U+FF21..U+FF24).
# The fullwidth alternatives are written as escapes so they cannot be
# confused with (or silently normalized to) ASCII A-D, which ``[A-D]``
# already covers.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]'
    '|[\uff21]|[\uff22]|[\uff23]|[\uff24])')
|
||
# All the different ways "Answer" is written in different languages.
# Every entry is a regex fragment meant to be substituted into
# MULTILINGUAL_ANSWER_PATTERN_TEMPLATE. Raw strings are used so that ``\s``
# reaches the regex engine without triggering Python's invalid-escape warning.
MULTILINGUAL_ANSWER_REGEXES = [
    r'Answer\s*:',
    # Korean invisible character: some responses put zero-width characters
    # between the colon and the letter, which the plain entry above cannot
    # match.
    # NOTE(review): the exact invisible code point is not visible in this
    # copy of the file; the common zero-width characters are accepted.
    r'Answer\s*:[\u200b\u200c\u200d\u2060\ufeff]+',
    r'উত্তর\s*:',
    r'उत्तर\s*:',
    r'উত্তরঃ',
    r'উত্তর\s*:',
    r'Antwort\s*:',
    r'답변\s*:',
    r'정답\s*:',
    r'답\s*:',
    # Chinese/Japanese text commonly uses the fullwidth colon (U+FF1A), so
    # entries that would otherwise be exact duplicates cover both colon forms.
    r'答案\s*：',
    r'答案\s*:',
    r'答\s*：',
    r'答\s*:',
    r'答复\s*[:：]',
    r'答曰\s*[:：]',
    r'الإجابة:',
    r'الجواب:',
    r'إجابة:',
    r'الإجابة النهائية:',
    r'الإجابة الصحيحة:',
    r'الإجابة الصحيحة هي:',
    r'الإجابة هي:',
    r'الجواب النهائي:',
    r'Respuesta\s*:',
    r'Risposta\s*:',
    r'答え\s*:',
    r'答え\s*：',
    r'回答\s*:',
    r'回答\s*：',
    r'解答\s*[:：]',
    r'Jawaban\s*:',
    r'Réponse\s*:',
    r'Resposta\s*:',
    r'Jibu\s*:',
    r'Idahun\s*:',
    r'Ìdáhùn\s*:',
    r'Idáhùn\s*:',
    r'Àmọ̀nà\s*:',
    r'Àdáhùn\s*:',
    r'Ànúgọ\s*:',
    r'Àṣàyàn\s*:',
]
|
||
|
||
EQUALITY_TEMPLATE = r"""
|
||
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
|
||
|
||
Examples:
|
||
|
||
Expression 1: $2x+3$
|
||
Expression 2: $3+2x$
|
||
|
||
Yes
|
||
|
||
Expression 1: 3/2
|
||
Expression 2: 1.5
|
||
|
||
Yes
|
||
|
||
Expression 1: $x^2+2x+1$
|
||
Expression 2: $y^2+2y+1$
|
||
|
||
No
|
||
|
||
Expression 1: $x^2+2x+1$
|
||
Expression 2: $(x+1)^2$
|
||
|
||
Yes
|
||
|
||
Expression 1: 3245/5
|
||
Expression 2: 649
|
||
|
||
No
|
||
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
|
||
|
||
Expression 1: 2/(-3)
|
||
Expression 2: -2/3
|
||
|
||
Yes
|
||
(trivial simplifications are allowed)
|
||
|
||
Expression 1: 72 degrees
|
||
Expression 2: 72
|
||
|
||
Yes
|
||
(give benefit of the doubt to units)
|
||
|
||
Expression 1: 64
|
||
Expression 2: 64 square feet
|
||
|
||
Yes
|
||
(give benefit of the doubt to units)
|
||
|
||
---
|
||
|
||
YOUR TASK
|
||
|
||
|
||
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
|
||
|
||
Expression 1: %(expression1)s
|
||
Expression 2: %(expression2)s
|
||
""".strip()
|
||
|
||
HTML_JINJA = """
|
||
<h3>Prompt conversation</h3>
|
||
{% for message in prompt_messages %}
|
||
{{ message_to_html(message) | safe }}
|
||
{% endfor %}
|
||
<h3>Sampled message</h3>
|
||
{{ message_to_html(next_message) | safe }}
|
||
<h3>Results</h3>
|
||
<p>Correct Answer: {{ correct_answer }}</p>
|
||
<p>Extracted Answer: {{ extracted_answer }}</p>
|
||
<p>Score: {{ score }}</p>
|
||
"""
|
||
|
||
|
||
def format_multichoice_question(row):
    """Fill the multiple-choice prompt template from a question record.

    ``row`` is unpacked as keyword arguments into
    ``QUERY_TEMPLATE_MULTICHOICE.format``, so it must be a mapping that
    provides every placeholder of that template.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
|
||
|
||
|
||
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """Ask ``sampler`` whether two expressions are equivalent.

    The two expressions are substituted into ``EQUALITY_TEMPLATE`` and sent
    to the sampler as a single user message.

    Returns:
        True when the sampler's ``response_text``, lower-cased and stripped
        of surrounding whitespace, is exactly ``'yes'``; False otherwise.
    """
    substitutions = {'expression1': expr1, 'expression2': expr2}
    messages = [dict(content=EQUALITY_TEMPLATE % substitutions, role='user')]
    verdict = sampler(messages).response_text
    return verdict.lower().strip() == 'yes'
|
||
|
||
|
||
def _compute_stat(values: list, stat: str):
|
||
if stat == 'mean':
|
||
return np.mean(values)
|
||
elif stat == 'std':
|
||
return np.std(values)
|
||
elif stat == 'min':
|
||
return np.min(values)
|
||
elif stat == 'max':
|
||
return np.max(values)
|
||
elif stat == 'n_samples':
|
||
return len(values)
|
||
elif stat == 'bootstrap_std':
|
||
return np.std([
|
||
np.mean(np.random.choice(values, len(values))) for _ in range(1000)
|
||
])
|
||
else:
|
||
raise ValueError(f'Unknown {stat =}')
|
||
|
||
|
||
def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ('mean', 'std'),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """Combine per-example results into a single EvalResult.

    Args:
        single_eval_results: The per-example results to merge.
        default_stats: Statistics computed for every metric that has no
            entry in ``name2stats``.
        name2stats: Optional per-metric override of the statistics to
            compute.

    Returns:
        An EvalResult whose metrics are keyed by the metric name for the
        ``'mean'`` statistic and by ``'<name>:<stat>'`` for any other. The
        mean of the per-example scores (if computed) becomes ``score``; the
        per-example htmls, convos and metadata are collected in input order.
    """
    stats_by_name = name2stats or {}
    values_by_name = defaultdict(list)
    htmls, convos, metadata = [], [], []

    for result in single_eval_results:
        for metric_name, metric_value in result.metrics.items():
            values_by_name[metric_name].append(metric_value)
        # The score is tracked like any other metric, under the name 'score'.
        if result.score is not None:
            values_by_name['score'].append(result.score)
        htmls.append(result.html)
        convos.append(result.convo)
        metadata.append(result.example_level_metadata)

    final_metrics = {}
    for metric_name, values in values_by_name.items():
        for stat in stats_by_name.get(metric_name, default_stats):
            if stat == 'mean':
                key = metric_name
            else:
                key = f'{metric_name}:{stat}'
            final_metrics[key] = _compute_stat(values, stat)

    # Pull the mean score out of the metrics dict; its other statistics
    # (e.g. 'score:std') stay in the metrics.
    overall_score = final_metrics.pop('score', None)
    return EvalResult(
        score=overall_score,
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={'example_level_metadata': metadata},
    )
|
||
|
||
|
||
def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """Apply f to each element of xs, using a ThreadPool, and show progress.

    Args:
        f: Function applied to every element.
        xs: Input elements; results are returned in the same order.
        num_threads: Upper bound on the number of worker threads. The pool
            never has more threads than there are elements.
        pbar: Whether to wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        A list with ``f(x)`` for each ``x`` in ``xs``. When the ``debug``
        environment variable is set to a non-empty value, the work runs
        sequentially in the calling thread instead of in a pool.
    """
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv('debug'):
        return list(map(f, pbar_fn(xs, total=len(xs))))

    if len(xs) == 0:
        # ThreadPool(0) raises ValueError, so there is nothing to schedule.
        return []

    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))
|
||
|
||
|
||
# Shared Jinja environment for all HTML rendering in this module. Templates
# are passed in as strings (BaseLoader does no file lookup), StrictUndefined
# makes a missing template variable an error rather than empty output, and
# autoescaping is configured through select_autoescape for html/xml.
jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(['html', 'xml']),
)
# Template for one chat message. It reads ``role``, ``content`` and
# ``variant``; the variant label is emitted only when ``variant`` is truthy.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
|
||
|
||
|
||
def message_to_html(message: Message) -> str:
    """Render a single message as an HTML snippet wrapped in a ``<div>``.

    ``message`` must provide ``'role'`` and ``'content'``; ``'variant'`` is
    optional and treated as ``None`` when absent.
    """
    template = jinja_env.from_string(_message_template)
    context = {
        'role': message['role'],
        'content': message['content'],
        'variant': message.get('variant', None),
    }
    return template.render(**context)
|
||
|
||
|
||
# Make message_to_html callable from inside templates rendered through
# jinja_env (HTML_JINJA calls it for every message).
jinja_env.globals['message_to_html'] = message_to_html
|
||
|
||
_report_template = """<!DOCTYPE html>
|
||
<html>
|
||
<head>
|
||
<style>
|
||
.message {
|
||
padding: 8px 16px;
|
||
margin-bottom: 8px;
|
||
border-radius: 4px;
|
||
}
|
||
.message.user {
|
||
background-color: #B2DFDB;
|
||
color: #00695C;
|
||
}
|
||
.message.assistant {
|
||
background-color: #B39DDB;
|
||
color: #4527A0;
|
||
}
|
||
.message.system {
|
||
background-color: #EEEEEE;
|
||
color: #212121;
|
||
}
|
||
.role {
|
||
font-weight: bold;
|
||
margin-bottom: 4px;
|
||
}
|
||
.variant {
|
||
color: #795548;
|
||
}
|
||
table, th, td {
|
||
border: 1px solid black;
|
||
}
|
||
pre {
|
||
white-space: pre-wrap;
|
||
}
|
||
</style>
|
||
</head>
|
||
<body>
|
||
{% if metrics %}
|
||
<h1>Metrics</h1>
|
||
<table>
|
||
<tr>
|
||
<th>Metric</th>
|
||
<th>Value</th>
|
||
</tr>
|
||
<tr>
|
||
<td><b>Score</b></td>
|
||
<td>{{ score | float | round(3) }}</td>
|
||
</tr>
|
||
{% for name, value in metrics.items() %}
|
||
<tr>
|
||
<td>{{ name }}</td>
|
||
<td>{{ value }}</td>
|
||
</tr>
|
||
{% endfor %}
|
||
</table>
|
||
{% endif %}
|
||
<h1>Examples</h1>
|
||
{% for html in htmls %}
|
||
{{ html | safe }}
|
||
<hr>
|
||
{% endfor %}
|
||
</body>
|
||
</html>
|
||
"""
|
||
|
||
|
||
def make_report(eval_result: EvalResult) -> str:
    """Render an EvalResult as a standalone HTML page.

    The result's ``score``, ``metrics`` and ``htmls`` are passed to the
    report template.
    """
    report = jinja_env.from_string(_report_template)
    context = dict(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )
    return report.render(**context)
|
||
|
||
|
||
def make_report_from_example_htmls(htmls: list[str]):
    """Render a standalone HTML page from pre-rendered example snippets.

    No score or metrics are supplied (``score=None``, ``metrics={}``), so the
    template's metrics table is skipped and only the examples are shown.
    """
    report = jinja_env.from_string(_report_template)
    context = dict(score=None, metrics={}, htmls=htmls)
    return report.render(**context)
|
||
|
||
|
||
def normalize_response(response: str) -> str:
    """Strip markdown and LaTeX markup that could stop an answer from
    matching.

    Each marker below is deleted from the text, one after another. The order
    is significant: longer markers such as ``$\\boxed{`` have to be removed
    before the shorter ones (``$``, ``{``) they contain.
    """
    markers = (
        '**',
        '$\\boxed{',
        '}$',
        '\\$',
        '$\\text{',
        '$',
        '\\mathrm{',
        '\\{',
        '\\text',
        '\\(',
        '\\mathbf{',
        '{',
        '\\boxed',
    )
    cleaned = response
    for marker in markers:
        cleaned = cleaned.replace(marker, '')
    return cleaned
|
||
|
||
|
||
# Translation table from localized option letters to Latin A-D. Each
# replacement keeps the leading space the original chained ``str.replace``
# calls inserted; surrounding whitespace is stripped afterwards. Code points
# are written as escapes so visually similar characters (notably fullwidth
# vs. ASCII letters) cannot be confused or silently normalized.
_LOCALIZED_OPTION_LETTERS = str.maketrans({
    # In Arabic these are the letters used for A-D in multiple choice
    # questions.
    '\u0623': ' A',  # alef with hamza above
    '\u0628': ' B',  # beh
    '\u062c': ' C',  # jeem
    '\u062f': ' D',  # dal
    # In Bengali these are the letters used for A-D in multiple choice
    # questions.
    '\u0985': ' A',
    '\u09ac': ' B',
    '\u09a1': ' C',
    '\u09a2': ' D',
    # In Japanese these (fullwidth Latin letters) are sometimes used for A-D
    # in multiple choice questions.
    '\uff21': ' A',
    '\uff22': ' B',
    '\uff23': ' C',
    '\uff24': ' D',
})


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map a localized option letter to its Latin equivalent.

    Arabic, Bengali and fullwidth (Japanese-style) option letters are
    replaced by ``A``-``D``; any other character is left untouched. Leading
    and trailing whitespace is removed from the result.
    """
    return extracted_answer.translate(_LOCALIZED_OPTION_LETTERS).strip()
|
||
|
||
|
||
def url_to_fileobj(url: str,
                   binary=False,
                   timeout: float | None = 30.0) -> Any:
    """Download ``url`` and return its body as an in-memory file object.

    Args:
        url: The address to fetch with an HTTP GET.
        binary: When true, return the raw bytes in an ``io.BytesIO``;
            otherwise return the decoded text in an ``io.StringIO``.
        timeout: Seconds to wait for the server to connect or to send data
            before giving up. ``None`` waits indefinitely, which is what a
            bare ``requests.get(url)`` does.

    Raises:
        requests.HTTPError: If the response status indicates an error.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    # Without a timeout a stalled server would block the caller forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    if binary:
        return io.BytesIO(response.content)
    return io.StringIO(response.text)
|
||
|
||
|
||
def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """Return True when every message has the role 'user' or 'assistant'.

    An empty list yields True. Scanning stops at the first message with any
    other role.
    """
    allowed_roles = ('user', 'assistant')
    for message in messages:
        if message['role'] not in allowed_roles:
            return False
    return True
|