import io
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

import jinja2
import numpy as np
import requests
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult

QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r'(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?'
ANSWER_PATTERN = r'(?i)Answer\s*:\s*([^\n]+)'
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r'(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])')
# All the different ways "Answer" is written in different languages
MULTILINGUAL_ANSWER_REGEXES = [
    r'Answer\s*:',
    r'Answer\s*:',  # Korean invisible character
    r'উত্তর\s*:',
    r'उत्तर\s*:',
    r'উত্তরঃ',
    r'উত্তর\s*:',
    r'Antwort\s*:',
    r'답변\s*:',
    r'정답\s*:',
    r'답\s*:',
    r'答案\s*:',
    r'答案\s*：',
    r'答\s*:',
    r'答\s*：',
    r'答复\s*：',
    r'答曰\s*：',
    r'الإجابة:',
    r'الجواب:',
    r'إجابة:',
    r'الإجابة النهائية:',
    r'الإجابة الصحيحة:',
    r'الإجابة الصحيحة هي:',
    r'الإجابة هي:',
    r'الجواب النهائي:',
    r'Respuesta\s*:',
    r'Risposta\s*:',
    r'答え\s*:',
    r'答え\s*：',
    r'回答\s*:',
    r'回答\s*：',
    r'解答\s*:',
    r'Jawaban\s*:',
    r'Réponse\s*:',
    r'Resposta\s*:',
    r'Jibu\s*:',
    r'Idahun\s*:',
    r'Ìdáhùn\s*:',
    r'Idáhùn\s*:',
    r'Àmọ̀nà\s*:',
    r'Àdáhùn\s*:',
    r'Ànúgọ\s*:',
    r'Àṣàyàn\s*:',
]
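
# Illustrative sketch (not part of the original module): one way the
# multilingual patterns above are typically combined to pull a choice letter
# out of a model response.  ``response_text`` is a hypothetical string, ``re``
# would need to be imported by the caller, and ``normalize_response`` /
# ``normalize_extracted_answer`` are defined later in this module.
#
#   for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
#       pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
#       match = re.search(pattern, normalize_response(response_text))
#       if match:
#           extracted_answer = normalize_extracted_answer(match.group(1))
#           break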

EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()

HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    prompt = EQUALITY_TEMPLATE % {'expression1': expr1, 'expression2': expr2}
    sampler_response = sampler([dict(content=prompt, role='user')])
    response_text = sampler_response.response_text
    return response_text.lower().strip() == 'yes'
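

# Illustrative usage sketch (not part of the original module).  ``judge`` is a
# hypothetical SamplerBase-compatible sampler used as an equality judge; any
# sampler whose return value exposes ``response_text`` would do.
#
#   judge = SomeSampler(...)  # hypothetical sampler instance
#   check_equality(judge, '$(x+1)^2$', '$x^2+2x+1$')  # True if judge says "Yes"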


def _compute_stat(values: list, stat: str):
    if stat == 'mean':
        return np.mean(values)
    elif stat == 'std':
        return np.std(values)
    elif stat == 'min':
        return np.min(values)
    elif stat == 'max':
        return np.max(values)
    elif stat == 'n_samples':
        return len(values)
    elif stat == 'bootstrap_std':
        # Standard deviation of the mean over 1000 bootstrap resamples.
        return np.std([
            np.mean(np.random.choice(values, len(values)))
            for _ in range(1000)
        ])
    else:
        raise ValueError(f'Unknown {stat =}')


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ('mean', 'std'),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """Aggregate results from multiple evaluations into a single EvalResult."""
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values['score'].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # The mean keeps the bare metric name; other stats get a suffix.
            key = name if stat == 'mean' else f'{name}:{stat}'
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop('score', None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={'example_level_metadata': metadata},
    )
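

# Illustrative sketch (not part of the original module), assuming
# SingleEvalResult accepts these fields as keyword arguments:
#
#   results = [
#       SingleEvalResult(score=1.0, metrics={'accuracy': 1.0}),
#       SingleEvalResult(score=0.0, metrics={'accuracy': 0.0}),
#   ]
#   eval_result = aggregate_results(results)
#   # eval_result.score is the mean score (0.5); eval_result.metrics holds
#   # 'accuracy' and 'accuracy:std' under the default ('mean', 'std') stats.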


def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """Apply f to each element of xs, using a ThreadPool, and show progress."""
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv('debug'):
        # Run serially (no thread pool) when the 'debug' environment variable
        # is set.
        return list(map(f, pbar_fn(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(pbar_fn(pool.imap(f, xs), total=len(xs)))
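

# Illustrative usage sketch (not part of the original module).  ``grade_row``
# and ``rows`` are hypothetical; the call maps the function over the rows in a
# thread pool while showing a tqdm progress bar.
#
#   def grade_row(row):
#       return row['score']
#
#   scores = map_with_progress(grade_row, rows, num_threads=8)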


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(['html', 'xml']),
)
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """Generate HTML snippet (inside a <div>) for a message."""
    return jinja_env.from_string(_message_template).render(
        role=message['role'],
        content=message['content'],
        variant=message.get('variant', None),
    )


jinja_env.globals['message_to_html'] = message_to_html

_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """Create a standalone HTML report from an EvalResult."""
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """Create a standalone HTML report from a list of example htmls."""
    return jinja_env.from_string(_report_template).render(score=None,
                                                          metrics={},
                                                          htmls=htmls)
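

# Illustrative sketch (not part of the original module): writing a rendered
# report to disk.  ``eval_result`` is a hypothetical EvalResult, e.g. the
# output of aggregate_results above.
#
#   with open('report.html', 'w', encoding='utf-8') as f:
#       f.write(make_report(eval_result))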


def normalize_response(response: str) -> str:
    """Normalize the response by removing markdown and LaTeX formatting that
    may prevent a match."""
    return (response.replace('**', '')
            .replace('$\\boxed{', '')
            .replace('}$', '')
            .replace('\\$', '')
            .replace('$\\text{', '')
            .replace('$', '')
            .replace('\\mathrm{', '')
            .replace('\\{', '')
            .replace('\\text', '')
            .replace('\\(', '')
            .replace('\\mathbf{', '')
            .replace('{', '')
            .replace('\\boxed', ''))


def normalize_extracted_answer(extracted_answer: str) -> str:
    return (
        # In Arabic these are the letters used for A-D in multiple choice
        # questions
        extracted_answer.replace('أ', ' A')
        .replace('ب', ' B')
        .replace('ج', ' C')
        .replace('د', ' D')
        # In Bengali these are the letters used for A-D in multiple choice
        # questions
        .replace('অ', ' A')
        .replace('ব', ' B')
        .replace('ড', ' C')
        .replace('ঢ', ' D')
        # In Japanese these are the full-width letters sometimes used for A-D
        # in multiple choice questions
        .replace('Ａ', ' A')
        .replace('Ｂ', ' B')
        .replace('Ｃ', ' C')
        .replace('Ｄ', ' D')
        .strip())


def url_to_fileobj(url: str, binary=False) -> Any:
    response = requests.get(url)
    response.raise_for_status()
    return io.BytesIO(response.content) if binary else io.StringIO(
        response.text)


def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """Check if the messages only contain user and assistant messages."""
    return all(m['role'] in ('user', 'assistant') for m in messages)