import io
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

import jinja2
import numpy as np
import requests
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult

# Prompt for four-option multiple choice questions; filled via str.format with
# the keys Question, A, B, C and D.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r'(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?'
ANSWER_PATTERN = r'(?i)Answer\s*:\s*([^\n]+)'

# `{}` is filled with one entry of MULTILINGUAL_ANSWER_REGEXES. The trailing
# alternatives are the full-width forms of A-D, written as escapes so that
# Unicode width-normalisation of this file cannot silently collapse them into
# the ASCII letters already covered by [A-D].
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]'
    '|[\uff21]|[\uff22]|[\uff23]|[\uff24])')

# All the different ways "Answer" is written in different languages
# Raw strings: `\s` is a regex escape, not a Python string escape.
# NOTE(review): where an entry used to appear twice in a row, the second copy
# now targets the full-width colon (regex escape \uFF1A); the exact duplicates
# looked like the result of width-normalisation -- confirm against the
# canonical list.
MULTILINGUAL_ANSWER_REGEXES = [
    r'Answer\s*:',
    r'Answer\s*:​​​​​​',  # Korean invisible character
    r'উত্তর\s*:',
    r'उत्तर\s*:',
    'উত্তরঃ',
    r'উত্তর\s*:',
    r'Antwort\s*:',
    r'답변\s*:',
    r'정답\s*:',
    r'답\s*:',
    r'答案\s*:',
    r'答案\s*\uFF1A',
    r'答\s*:',
    r'答\s*\uFF1A',
    r'答复\s*:',
    r'答曰\s*:',
    'الإجابة:',
    'الجواب:',
    'إجابة:',
    'الإجابة النهائية:',
    'الإجابة الصحيحة:',
    'الإجابة الصحيحة هي:',
    'الإجابة هي:',
    'الجواب النهائي:',
    r'Respuesta\s*:',
    r'Risposta\s*:',
    r'答え\s*:',
    r'答え\s*\uFF1A',
    r'回答\s*:',
    r'回答\s*\uFF1A',
    r'解答\s*:',
    r'Jawaban\s*:',
    r'Réponse\s*:',
    r'Resposta\s*:',
    r'Jibu\s*:',
    r'Idahun\s*:',
    r'Ìdáhùn\s*:',
    r'Idáhùn\s*:',
    r'Àmọ̀nà\s*:',
    r'Àdáhùn\s*:',
    r'Ànúgọ\s*:',
    r'Àṣàyàn\s*:',
]

# Grader prompt used by check_equality; filled with the %-operator using the
# keys expression1 and expression2.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent.
Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()

# Per-example HTML fragment. Expects prompt_messages, next_message,
# correct_answer, extracted_answer and score in the render context.
# NOTE(review): the tags below were reconstructed -- the version under review
# had lost its markup; confirm against the canonical template.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    """Fill QUERY_TEMPLATE_MULTICHOICE from a mapping.

    `row` is unpacked as keyword arguments to str.format, so it must provide
    at least the keys Question, A, B, C and D.
    """
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str) -> bool:
    """Ask `sampler` whether two expressions are equivalent.

    The sampler is called with a single user message built from
    EQUALITY_TEMPLATE. Returns True only when the stripped, lower-cased
    response text is exactly 'yes'.
    """
    prompt = EQUALITY_TEMPLATE % {'expression1': expr1, 'expression2': expr2}
    sampler_response = sampler([dict(content=prompt, role='user')])
    response_text = sampler_response.response_text
    return response_text.lower().strip() == 'yes'


def _compute_stat(values: list, stat: str):
    """Compute one named statistic over `values`.

    Supported names: mean, std, min, max, n_samples and bootstrap_std (the
    standard deviation of the mean over 1000 resamples with replacement).

    Raises:
        ValueError: if `stat` is not one of the supported names.
    """
    if stat == 'mean':
        return np.mean(values)
    elif stat == 'std':
        return np.std(values)
    elif stat == 'min':
        return np.min(values)
    elif stat == 'max':
        return np.max(values)
    elif stat == 'n_samples':
        return len(values)
    elif stat == 'bootstrap_std':
        return np.std([
            np.mean(np.random.choice(values, len(values)))
            for _ in range(1000)
        ])
    else:
        raise ValueError(f'Unknown {stat =}')


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ('mean', 'std'),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """Aggregate results from multiple evaluations into a single EvalResult.

    Every metric (and the score, when it is not None) is collected across
    results and reduced with the statistics listed for it in `name2stats`,
    falling back to `default_stats`. The mean is stored under the bare metric
    name, every other statistic under '<name>:<stat>'. The aggregated
    'score' entry is moved out of the metrics into EvalResult.score.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values['score'].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)

    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == 'mean' else f'{name}:{stat}'
            final_metrics[key] = _compute_stat(values, stat)

    return EvalResult(
        score=final_metrics.pop('score', None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={'example_level_metadata': metadata},
    )


def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """Apply f to each element of xs, using a ThreadPool, and show progress.

    Results are returned as a list in the order of `xs`. When the `debug`
    environment variable is set to a non-empty value, `f` is applied
    sequentially in the calling thread instead. An empty `xs` yields an
    empty list.
    """
    # ThreadPool rejects a worker count of 0, which is what
    # min(num_threads, 0) would produce for empty input.
    if len(xs) == 0:
        return []

    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv('debug'):
        return list(map(f, pbar_fn(xs, total=len(xs))))

    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(['html', 'xml']),
)

# NOTE(review): the tags below were reconstructed -- the version under review
# had lost its markup; confirm against the canonical template.
_message_template = """
<div class="message {{ role }}">
    <div class="role">
    {{ role }}
    {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
    <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """Generate HTML snippet (inside a <div>) for a message.

    Reads the 'role' and 'content' keys of `message` and the optional
    'variant' key.
    """
    return jinja_env.from_string(_message_template).render(
        role=message['role'],
        content=message['content'],
        variant=message.get('variant', None),
    )


# Make the helper callable from inside templates (HTML_JINJA relies on it).
jinja_env.globals['message_to_html'] = message_to_html

# The per-metric row must stay inside the `for` loop: the environment uses
# StrictUndefined, so referencing `name`/`value` outside it raises.
# NOTE(review): the tags below were reconstructed -- the version under review
# had lost its markup; confirm against the canonical template.
_report_template = """<!DOCTYPE html>
<html>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """Create a standalone HTML report from an EvalResult."""
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """Create a standalone HTML report from a list of example htmls.

    No score or metrics are passed, so the metrics section is omitted.
    """
    return jinja_env.from_string(_report_template).render(score=None,
                                                          metrics={},
                                                          htmls=htmls)


def normalize_response(response: str) -> str:
    """Normalize the response by removing markdown and LaTeX formatting that may prevent a match.

    The tokens are removed one after another in the order listed; the order
    matters because earlier removals change what later tokens can match.
    """
    tokens = (
        '**',
        '$\\boxed{',
        '}$',
        '\\$',
        '$\\text{',
        '$',
        '\\mathrm{',
        '\\{',
        '\\text',
        '\\(',
        '\\mathbf{',
        '{',
        '\\boxed',
    )
    for token in tokens:
        response = response.replace(token, '')
    return response


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map localized option letters to ' A'..' D' and strip the result."""
    replacements = (
        # In arabic these are the letters used for A-D in multiple choice questions
        ('أ', ' A'),
        ('ب', ' B'),
        ('ج', ' C'),
        ('د', ' D'),
        # In Bengali these are the letters used for A-D in multiple choice questions
        ('অ', ' A'),
        ('ব', ' B'),
        ('ড', ' C'),
        ('ঢ', ' D'),
        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
        # (full-width forms, written as escapes so they survive Unicode
        # width-normalisation of this file).
        ('\uff21', ' A'),
        ('\uff22', ' B'),
        ('\uff23', ' C'),
        ('\uff24', ' D'),
    )
    for localized, latin in replacements:
        extracted_answer = extracted_answer.replace(localized, latin)
    return extracted_answer.strip()


def url_to_fileobj(url: str, binary=False) -> Any:
    """Download `url` and return its body as an in-memory file object.

    Returns an io.BytesIO over the raw content when `binary` is true,
    otherwise an io.StringIO over the decoded text. HTTP error statuses are
    raised via Response.raise_for_status().
    """
    response = requests.get(url)
    response.raise_for_status()
    if binary:
        return io.BytesIO(response.content)
    return io.StringIO(response.text)


def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """Check if the messages only contain user and assistant messages."""
    return all(m['role'] in ('user', 'assistant') for m in messages)