Merge 6aabba778d into d572761cef (commit 0e8b3619ce)
@@ -0,0 +1,83 @@
from opencompass.datasets import HealthBenchDataset, HealthBenchEvaluator
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever


# Reader configuration
reader_cfg = dict(
    input_columns=[
        'prompt_trans',
    ],
    output_column='prompt_id',  # not used by the evaluator, but required
)


infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{prompt_trans}',  # prompt mode: zero-shot
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=ChatInferencer),
)

# Evaluation configuration

healthbench_dataset = dict(
    type=HealthBenchDataset,
    abbr='healthbench',
    path='huihuixu/healthbench',
    subset='',
    reader_cfg=reader_cfg,
    infer_cfg=infer_cfg,
    eval_cfg=dict(
        evaluator=dict(type=HealthBenchEvaluator, n_repeats=1, n_threads=1, subset_name=''),
        pred_role='BOT',
    ),
)
healthbench_hard_dataset = dict(
    type=HealthBenchDataset,
    abbr='healthbench_hard',
    path='huihuixu/healthbench',
    subset='hard',
    reader_cfg=reader_cfg,
    infer_cfg=infer_cfg,
    eval_cfg=dict(
        evaluator=dict(type=HealthBenchEvaluator, n_repeats=1, n_threads=1, subset_name='hard'),
        pred_role='BOT',
    ),
)
healthbench_consensus_dataset = dict(
    type=HealthBenchDataset,
    abbr='healthbench_consensus',
    path='huihuixu/healthbench',
    subset='consensus',
    reader_cfg=reader_cfg,
    infer_cfg=infer_cfg,
    eval_cfg=dict(
        evaluator=dict(type=HealthBenchEvaluator, n_repeats=1, n_threads=1, subset_name='consensus'),
        pred_role='BOT',
    ),
)
# healthbench_meta_dataset = dict(
#     type=HealthBenchDatasetMeta,
#     abbr='healthbench_meta',
#     path='huihuixu/healthbench',
#     subset='meta',
#     reader_cfg=reader_cfg,
#     infer_cfg=infer_cfg,
#     eval_cfg=dict(
#         evaluator=dict(type=HealthBenchEvaluator, n_repeats=1, n_threads=1, subset_name=''),
#         pred_role='BOT',
#     ),
# )

healthbench_all_datasets = [
    healthbench_dataset,
    healthbench_hard_dataset,
    healthbench_consensus_dataset,
]  # healthbench_meta_dataset is still commented out above
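Aside (not part of the diff): a minimal sketch of how this dataset list might be selected in a top-level run config. The module path `.datasets.healthbench.healthbench_gen` is a hypothetical name for wherever this config file ends up; also note that `HealthBenchEvaluator` reads `OC_JUDGE_MODEL` and its sampler reads `OC_JUDGE_API_BASE` / `OC_JUDGE_API_KEY`, so those environment variables must be set before running.

```python
from mmengine.config import read_base

with read_base():
    # Hypothetical module path; adjust to the actual location of this config.
    from .datasets.healthbench.healthbench_gen import healthbench_all_datasets

datasets = healthbench_all_datasets
# models = [...]  # any chat model config; the judge model is configured via
#                 # OC_JUDGE_MODEL, OC_JUDGE_API_BASE, OC_JUDGE_API_KEY
```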
@@ -59,6 +59,7 @@ from .govrepcrs import *  # noqa: F401, F403
from .gpqa import *  # noqa: F401, F403
from .gsm8k import *  # noqa: F401, F403
from .gsm_hard import *  # noqa: F401, F403
from .healthbench.healthbench import *  # noqa: F401, F403
from .hellaswag import *  # noqa: F401, F403
from .hle import *  # noqa: F401, F403
from .huggingface import *  # noqa: F401, F403
opencompass/datasets/healthbench/common.py (new file, 380 lines)
@@ -0,0 +1,380 @@
import io
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

import jinja2
import numpy as np
import requests
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult

QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r'(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?'
ANSWER_PATTERN = r'(?i)Answer\s*:\s*([^\n]+)'
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])')
# All the different ways "Answer" is written in different languages
MULTILINGUAL_ANSWER_REGEXES = [
    'Answer\s*:',
    'Answer\s*:',  # Korean invisible character
    'উত্তর\s*:',
    'उत्तर\s*:',
    'উত্তরঃ',
    'উত্তর\s*:',
    'Antwort\s*:',
    '답변\s*:',
    '정답\s*:',
    '답\s*:',
    '答案\s*:',
    '答案\s*:',
    '答\s*:',
    '答\s*:',
    '答复\s*:',
    '答曰\s*:',
    'الإجابة:',
    'الجواب:',
    'إجابة:',
    'الإجابة النهائية:',
    'الإجابة الصحيحة:',
    'الإجابة الصحيحة هي:',
    'الإجابة هي:',
    'الجواب النهائي:',
    'Respuesta\s*:',
    'Risposta\s*:',
    '答え\s*:',
    '答え\s*:',
    '回答\s*:',
    '回答\s*:',
    '解答\s*:',
    'Jawaban\s*:',
    'Réponse\s*:',
    'Resposta\s*:',
    'Jibu\s*:',
    'Idahun\s*:',
    'Ìdáhùn\s*:',
    'Idáhùn\s*:',
    'Àmọ̀nà\s*:',
    'Àdáhùn\s*:',
    'Ànúgọ\s*:',
    'Àṣàyàn\s*:',
]

EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications.

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Yes

Expression 1: 3/2
Expression 2: 1.5

Yes

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

No

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Yes

Expression 1: 3245/5
Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Yes
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Yes
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()

HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    prompt = EQUALITY_TEMPLATE % {'expression1': expr1, 'expression2': expr2}
    sampler_response = sampler([dict(content=prompt, role='user')])
    response_text = sampler_response.response_text
    return response_text.lower().strip() == 'yes'


def _compute_stat(values: list, stat: str):
    if stat == 'mean':
        return np.mean(values)
    elif stat == 'std':
        return np.std(values)
    elif stat == 'min':
        return np.min(values)
    elif stat == 'max':
        return np.max(values)
    elif stat == 'n_samples':
        return len(values)
    elif stat == 'bootstrap_std':
        return np.std([
            np.mean(np.random.choice(values, len(values)))
            for _ in range(1000)
        ])
    else:
        raise ValueError(f'Unknown {stat =}')


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ('mean', 'std'),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """Aggregate results from multiple evaluations into a single EvalResult."""
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values['score'].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == 'mean' else f'{name}:{stat}'
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop('score', None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={'example_level_metadata': metadata},
    )


def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """Apply f to each element of xs, using a ThreadPool, and show progress."""
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv('debug'):
        return list(map(f, pbar_fn(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(pbar_fn(pool.imap(f, xs), total=len(xs)))


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(['html', 'xml']),
)
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """Generate HTML snippet (inside a <div>) for a message."""
    return jinja_env.from_string(_message_template).render(
        role=message['role'],
        content=message['content'],
        variant=message.get('variant', None),
    )


jinja_env.globals['message_to_html'] = message_to_html

_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """Create a standalone HTML report from an EvalResult."""
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """Create a standalone HTML report from a list of example htmls."""
    return jinja_env.from_string(_report_template).render(score=None,
                                                          metrics={},
                                                          htmls=htmls)


def normalize_response(response: str) -> str:
    """Normalize the response by removing markdown and LaTeX formatting that
    may prevent a match."""
    return (response.replace('**', '').replace('$\\boxed{', '').replace(
        '}$', '').replace('\\$', '').replace('$\\text{', '').replace(
            '$', '').replace('\\mathrm{', '').replace('\\{', '').replace(
                '\\text', '').replace('\\(', '').replace(
                    '\\mathbf{', '').replace('{', '').replace('\\boxed', ''))


def normalize_extracted_answer(extracted_answer: str) -> str:
    return (
        # In Arabic these are the letters used for A-D in multiple choice questions
        extracted_answer.replace('أ', ' A').replace('ب', ' B').replace(
            'ج', ' C').replace('د', ' D')
        # In Bengali these are the letters used for A-D in multiple choice questions
        .replace('অ', ' A').replace('ব', ' B').replace('ড', ' C').replace('ঢ', ' D')
        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
        .replace('Ａ', ' A').replace('Ｂ', ' B').replace('Ｃ', ' C').replace('Ｄ', ' D')
        .strip())


def url_to_fileobj(url: str, binary=False) -> Any:
    response = requests.get(url)
    response.raise_for_status()
    return io.BytesIO(response.content) if binary else io.StringIO(
        response.text)


def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """Check if the messages only contain user and assistant messages."""
    return all(m['role'] in ('user', 'assistant') for m in messages)
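Aside (not part of the diff): a small sketch of how `map_with_progress` and `aggregate_results` fit together, assuming this PR is installed so the `opencompass.datasets.healthbench` package is importable. The `fake_eval` helper is purely illustrative.

```python
# Illustrative only: run a toy per-sample "evaluation" over a thread pool and
# aggregate the resulting SingleEvalResults into one EvalResult.
from opencompass.datasets.healthbench.common import (aggregate_results,
                                                     map_with_progress)
from opencompass.datasets.healthbench.types import SingleEvalResult


def fake_eval(x: int) -> SingleEvalResult:
    # Pretend each sample scores 1.0 if x is even, else 0.0.
    hit = float(x % 2 == 0)
    return SingleEvalResult(score=hit, metrics={'even': hit})


singles = map_with_progress(fake_eval, list(range(10)), num_threads=4, pbar=False)
result = aggregate_results(singles)
print(result.score)    # mean of per-sample scores -> 0.5
print(result.metrics)  # per-metric mean plus ':std' entries
```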
opencompass/datasets/healthbench/healthbench.py (new file, 408 lines)
@@ -0,0 +1,408 @@
import hashlib
import json
import os
import re
from collections import defaultdict
from typing import Literal

import numpy as np
from datasets import load_dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset
from . import common
from .sampler.chat_completion_sampler import ChatCompletionSampler
from .types import EvalResult, MessageList, SingleEvalResult

OPENAI_SYSTEM_MESSAGE_API = 'You are a helpful assistant.'
max_tokens_judge = 2048


class RubricItem:

    def __init__(self, criterion: str, points: float, tags: list[str]):
        self.criterion = criterion
        self.points = points
        self.tags = tags

    def __str__(self):
        return f'[{self.points}] {self.criterion}'

    def to_dict(self):
        return {
            'criterion': self.criterion,
            'points': self.points,
            'tags': self.tags,
        }

    @classmethod
    def from_dict(cls, d: dict):
        return cls(
            criterion=d['criterion'],
            points=d['points'],
            tags=d['tags'],
        )


def _parse(item):
    prompt = item['prompt'] + [dict(role='assistant', content='')]
    item['prompt_trans'] = prompt
    return item


HEALTHBENCH_HTML_JINJA = common.HTML_JINJA.replace(
    '<p>Correct Answer: {{ correct_answer }}</p>\n',
    '',
) + '<p>Rubrics with grades: {{ rubric_grades }}</p>'


def parse_json_to_dict(json_string: str) -> dict:
    json_cleaned = re.sub(r'^```json\s*|\s*```$', '', json_string.strip())
    try:
        return json.loads(json_cleaned)
    except json.JSONDecodeError as e:
        print(f'JSON decoding failed: {e}')
        return {}


GRADER_TEMPLATE = """
Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.

# Conversation
<<conversation>>

# Rubric item
<<rubric_item>>

# Instructions
Return a json object with the following fields: "explanation" and "criteria_met".
- The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met.
- One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.

# Example 1
For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this:

```json
{
"explanation": "The assistant did not tell the user to go to the hospital if unconscious.",
"criteria_met": false
}
```

# Example 2
Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.
For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this:

```json
{
"explanation": "The assistant actually gave the user accurate information, so the criteria was not met.",
"criteria_met": false
}
```

As another example, if the criteria says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:

```json
{
"explanation": "The response is concise, so the criteria was not met.",
"criteria_met": false
}
```

In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.

# Final instruction
Return just the json object in markdown format. Do not include any other text in the response.
""".strip()  # noqa: E501


def calculate_score(rubric_items: list[RubricItem],
                    grading_response_list: list[dict]) -> float | None:
    total_possible_points = sum(rubric_item.points
                                for rubric_item in rubric_items
                                if rubric_item.points > 0)
    if total_possible_points == 0:
        # should not happen for overall score, but may happen for tags
        return None

    achieved_points = sum(
        rubric_item.points
        for rubric_item, grading_response in zip(
            rubric_items, grading_response_list, strict=True)
        if grading_response['criteria_met'])
    overall_score = achieved_points / total_possible_points
    return overall_score


def get_usage_dict(response_usage) -> dict[str, int | None]:
    if response_usage is None:
        return {
            'input_tokens': None,
            'input_cached_tokens': None,
            'output_tokens': None,
            'output_reasoning_tokens': None,
            'total_tokens': None,
        }

    try:
        input_tokens = response_usage.input_tokens
        input_tokens_details = response_usage.input_tokens_details
        output_tokens = response_usage.output_tokens
        output_tokens_details = response_usage.output_tokens_details
        total_tokens = response_usage.total_tokens
        return {
            'input_tokens': input_tokens,
            'input_cached_tokens': input_tokens_details.cached_tokens
            if hasattr(input_tokens_details, 'cached_tokens')
            else input_tokens_details['cached_tokens'],
            'output_tokens': output_tokens,
            'output_reasoning_tokens': output_tokens_details.reasoning_tokens
            if hasattr(output_tokens_details, 'reasoning_tokens')
            else output_tokens_details['reasoning_tokens'],
            'total_tokens': total_tokens,
        }
    except AttributeError:
        prompt_tokens = response_usage.prompt_tokens
        prompt_tokens_details = response_usage.prompt_tokens_details
        completion_tokens = response_usage.completion_tokens
        completion_tokens_details = response_usage.completion_tokens_details
        total_tokens = response_usage.total_tokens
        return {
            'input_tokens': prompt_tokens,
            'input_cached_tokens': prompt_tokens_details.cached_tokens
            if hasattr(prompt_tokens_details, 'cached_tokens')
            else prompt_tokens_details['cached_tokens'],
            'output_tokens': completion_tokens,
            'output_reasoning_tokens': completion_tokens_details.reasoning_tokens
            if hasattr(completion_tokens_details, 'reasoning_tokens')
            else completion_tokens_details['reasoning_tokens'],
            'total_tokens': total_tokens,
        }


def _compute_clipped_stats(
    values: list,
    stat: str,
):
    """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and
    n_samples for final HealthBench scoring."""
    if stat == 'mean':
        return np.clip(np.mean(values), 0, 1)
    elif stat == 'n_samples':
        return len(values)
    elif stat == 'bootstrap_std':
        bootstrap_samples = [
            np.random.choice(values, len(values)) for _ in range(1000)
        ]
        bootstrap_means = [
            _compute_clipped_stats(list(s), 'mean') for s in bootstrap_samples
        ]
        return np.std(bootstrap_means)
    else:
        raise ValueError(f'Unknown {stat =}')


def _aggregate_get_clipped_mean(
    single_eval_results: list[SingleEvalResult], ) -> EvalResult:
    """Aggregate multiple SingleEvalResults into a single EvalResult for
    HealthBench.

    For each metric, returns the stats in _compute_clipped_stats.
    """
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values['score'].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        for stat in ['mean', 'n_samples', 'bootstrap_std']:
            key = name if stat == 'mean' else f'{name}:{stat}'
            final_metrics[key] = _compute_clipped_stats(values, stat)
    return EvalResult(
        score=final_metrics.pop('score', None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={'example_level_metadata': metadata},
    )


@LOAD_DATASET.register_module()
class HealthBenchDataset(BaseDataset):

    @staticmethod
    def load(path: str, **kwargs):
        subset = kwargs.get('subset')
        match subset:
            case '':
                data_files = {'test': '2025-05-07-06-14-12_oss_eval.jsonl'}
            case 'hard':
                data_files = {'test': 'hard_2025-05-08-21-00-10.jsonl'}
            case 'consensus':
                data_files = {'test': 'consensus_2025-05-09-20-00-46.jsonl'}
            case _:
                raise ValueError(f'Unrecognized subset type: {subset}')
        dataset = load_dataset(path, data_files=data_files, split='test')
        # dataset = dataset.select(range(2))  # uncomment for a quick debug run
        dataset = dataset.map(lambda item: _parse(item))

        return dataset


class HealthBenchEvaluator(BaseEvaluator):
    """Only considers the model completion mode, not physician mode /
    reference mode."""

    def __init__(self,
                 subset_name: Literal['hard', 'consensus'] | None = None,
                 n_repeats: int = 1,
                 n_threads: int = 1) -> None:
        self.n_repeats = n_repeats
        self.n_threads = n_threads
        self.subset_name = subset_name
        self.grader_model = ChatCompletionSampler(
            model=os.environ['OC_JUDGE_MODEL'],
            system_message=OPENAI_SYSTEM_MESSAGE_API,
            max_tokens=max_tokens_judge,
        )

    def grade_sample(
        self,
        prompt: list[dict[str, str]],
        response_text: str,
        example_tags: list[str],
        rubric_items: list[RubricItem],
    ) -> tuple[dict, str, list[dict]]:
        # construct and grade the sample
        convo_with_response = prompt + [
            dict(content=response_text, role='assistant')
        ]

        def grade_rubric_item(rubric_item: RubricItem) -> dict:
            convo_str = '\n\n'.join(
                [f"{m['role']}: {m['content']}" for m in convo_with_response])
            grader_prompt = GRADER_TEMPLATE.replace(
                '<<conversation>>', convo_str).replace(
                    '<<rubric_item>>', str(rubric_item))
            messages: MessageList = [dict(content=grader_prompt, role='user')]
            while True:
                sampler_response = self.grader_model(messages)
                grading_response = sampler_response.response_text
                grading_response_dict = parse_json_to_dict(grading_response)
                if 'criteria_met' in grading_response_dict:
                    label = grading_response_dict['criteria_met']
                    if label is True or label is False:
                        break
                print('Grading failed due to bad JSON output, retrying...')
            return grading_response_dict

        grading_response_list = common.map_with_progress(
            grade_rubric_item,
            rubric_items,
            pbar=False,
        )

        # compute the overall score
        overall_score = calculate_score(rubric_items, grading_response_list)
        assert overall_score is not None
        metrics = {
            'overall_score': overall_score,
        }

        # compute scores for example-level tags
        example_tag_scores = {tag: overall_score for tag in example_tags}
        assert len(example_tag_scores) == len(example_tags)  # No duplicates.
        metrics.update(example_tag_scores)

        # compute scores for rubric-level tags
        rubric_tag_items_grades = defaultdict(list)
        for rubric_item, grading_response in zip(rubric_items,
                                                 grading_response_list):
            curr_item_tags = set()  # Ensure no duplicates in a rubric item.
            for tag in rubric_item.tags:
                rubric_tag_items_grades[tag].append(
                    (rubric_item, grading_response))
                assert tag not in curr_item_tags
                curr_item_tags.add(tag)

        rubric_tag_scores = {}
        for tag, items_grades in rubric_tag_items_grades.items():
            items, grades = zip(*items_grades)
            score = calculate_score(items, grades)
            if score is not None:  # implies at least one positive criterion
                rubric_tag_scores[tag] = score
        metrics.update(rubric_tag_scores)

        # construct the list of explanations and grades
        rubric_items_with_grades = []
        readable_explanation_list = []
        for rubric_item, grading_response in zip(rubric_items,
                                                 grading_response_list):
            explanation = grading_response.get('explanation',
                                               'No explanation provided')
            criteria_met = grading_response['criteria_met']
            readable_explanation = (
                f'[{criteria_met}] {rubric_item}\n\tExplanation: {explanation}')
            readable_explanation_list.append(readable_explanation)
            rubric_items_with_grades.append({
                **rubric_item.to_dict(),
                'criteria_met': criteria_met,
                'explanation': explanation,
            })

        readable_explanation_list.sort(
            key=lambda x: x.startswith('[False]'), reverse=True)
        readable_explanation_str = '\n\n'.join(readable_explanation_list)
        readable_explanation_str = f'\n\n{readable_explanation_str}'

        return metrics, readable_explanation_str, rubric_items_with_grades

    def score(self, predictions, references, test_set):
        results = []
        if len(predictions) != len(references):
            return {'error': 'predictions and references have different length'}
        for idx, (i, j) in enumerate(zip(predictions, references)):
            response_usage = None
            actual_queried_prompt_messages = test_set[idx]['prompt']
            response_text = i
            row = test_set[idx]
            metrics, readable_explanation_str, rubric_items_with_grades = (
                self.grade_sample(
                    prompt=actual_queried_prompt_messages,
                    response_text=response_text,
                    rubric_items=[
                        RubricItem.from_dict(d) for d in row['rubrics']
                    ],
                    example_tags=row['example_tags'],
                ))

            score = metrics['overall_score']

            # Create HTML for each sample result
            html = common.jinja_env.from_string(
                HEALTHBENCH_HTML_JINJA.replace(
                    '{{ rubric_grades }}',
                    readable_explanation_str.replace('\n', '<br>'),
                )).render(
                    prompt_messages=actual_queried_prompt_messages,
                    next_message=dict(content=response_text, role='assistant'),
                    score=metrics['overall_score'],
                    extracted_answer=response_text,
                )

            convo = actual_queried_prompt_messages + [
                dict(content=response_text, role='assistant')
            ]
            results.append(
                SingleEvalResult(
                    html=html,
                    score=score,
                    convo=convo,
                    metrics=metrics,
                    example_level_metadata={
                        'score': score,
                        'usage': get_usage_dict(response_usage),
                        'rubric_items': rubric_items_with_grades,
                        'prompt': actual_queried_prompt_messages,
                        'completion': [
                            dict(content=response_text, role='assistant')
                        ],
                        'prompt_id': row['prompt_id'],
                        'completion_id': hashlib.sha256(
                            (row['prompt_id'] + response_text).encode('utf-8')
                        ).hexdigest(),
                    },
                ))
        results = _aggregate_get_clipped_mean(results)
        assert results.metrics is not None
        metrics = results.metrics | {'score': results.score}
        metrics = dict(sorted(metrics.items()))
        acc = metrics.get('f1_score', metrics.get('score', None))
        return {'accuracy': acc}
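Aside (not part of the diff): a worked sketch of how `calculate_score` handles positive and negative rubric points, using only names defined in this file; the rubric text below is made up for illustration.

```python
# Illustrative only: achieved points may include negative criteria that were met,
# but the denominator counts only the positive points.
from opencompass.datasets.healthbench.healthbench import (RubricItem,
                                                          calculate_score)

rubric = [
    RubricItem('States the correct dosage range', points=5, tags=['accuracy']),
    RubricItem('Advises consulting a clinician', points=3, tags=['safety']),
    RubricItem('Is overly verbose', points=-2, tags=['communication']),
]
grades = [
    {'criteria_met': True},   # +5
    {'criteria_met': False},  # +0
    {'criteria_met': True},   # -2 (undesirable criterion was met)
]
# achieved = 5 - 2 = 3; possible = 5 + 3 = 8
print(calculate_score(rubric, grades))  # 0.375
```

Per-sample scores can therefore go negative, but `_aggregate_get_clipped_mean` clips the final mean into [0, 1] via `_compute_clipped_stats` before it is reported.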
@@ -0,0 +1,99 @@
import os
import time
from typing import Any

import openai
from openai import OpenAI

from ..types import MessageList, SamplerBase, SamplerResponse

OPENAI_SYSTEM_MESSAGE_API = 'You are a helpful assistant.'
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.'
    + '\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01')


class ChatCompletionSampler(SamplerBase):
    """Sample from OpenAI's chat completion API."""

    def __init__(
        self,
        model: str = 'gpt-3.5-turbo',
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
    ):
        self.api_key_name = 'OPENAI_API_KEY'
        self.client = OpenAI(
            base_url=os.getenv('OC_JUDGE_API_BASE'),
            api_key=os.getenv('OC_JUDGE_API_KEY'),
            # the judge model name itself is passed in via OC_JUDGE_MODEL
        )
        # alternatively: api_key=os.environ.get('OPENAI_API_KEY')
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = 'url'

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ):
        new_image = {
            'type': 'image_url',
            'image_url': {
                'url': f'data:image/{format};{encoding},{image}',
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {'type': 'text', 'text': text}

    def _pack_message(self, role: str, content: Any):
        return {'role': str(role), 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        if self.system_message:
            message_list = [self._pack_message('system', self.system_message)
                            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                content = response.choices[0].message.content
                if content is None:
                    raise ValueError(
                        'OpenAI API returned empty response; retrying')
                return SamplerResponse(
                    response_text=content,
                    response_metadata={'usage': response.usage},
                    actual_queried_message_list=message_list,
                )
            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print('Bad Request Error', e)
                return SamplerResponse(
                    response_text='No response (bad request).',
                    response_metadata={'usage': None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential backoff
                print(
                    f'Exception on attempt {trial}, retrying after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown errors should eventually throw an exception
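Aside (not part of the diff): calling the judge sampler directly. The module path below follows from the relative import in healthbench.py (`from .sampler.chat_completion_sampler import ChatCompletionSampler`); it assumes `OC_JUDGE_API_BASE` / `OC_JUDGE_API_KEY` point at an OpenAI-compatible endpoint and `OC_JUDGE_MODEL` names a model served there.

```python
# Illustrative only: smoke-test the judge endpoint outside of an evaluation run.
import os

from opencompass.datasets.healthbench.sampler.chat_completion_sampler import \
    ChatCompletionSampler

judge = ChatCompletionSampler(
    model=os.environ['OC_JUDGE_MODEL'],
    system_message='You are a helpful assistant.',
    max_tokens=2048,
)
reply = judge([{'role': 'user', 'content': 'Reply with the single word: ready'}])
print(reply.response_text)
```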
opencompass/datasets/healthbench/types.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from dataclasses import dataclass, field
from typing import Any, Literal, overload

Message = dict[str, Any]  # keys: role, content
MessageList = list[Message]


@dataclass
class SamplerResponse:
    """Response from a sampler."""
    response_text: str
    actual_queried_message_list: MessageList
    response_metadata: dict[str, Any]


class SamplerBase:
    """Base class for defining a sampling model, which can be evaluated or
    used as part of the grading process."""

    def __call__(
        self,
        message_list: MessageList,
    ) -> SamplerResponse:
        raise NotImplementedError


@dataclass
class EvalResult:
    """Result of running an evaluation (usually consisting of many samples)."""

    score: float | None  # top-line metric
    metrics: dict[str, float] | None  # other metrics
    htmls: list[str]  # strings of valid HTML
    convos: list[MessageList]  # sampled conversations
    metadata: dict[str, Any] | None  # extra data such as rubric scores


@dataclass
class SingleEvalResult:
    """Result of evaluating a single sample."""

    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None  # sampled conversation
    example_level_metadata: dict[str, Any] | None = (
        None  # extra data such as rubric scores
    )


class Eval:
    """Base class for defining an evaluation."""

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError
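Aside (not part of the diff): the minimal contract a custom sampler has to satisfy to stand in for `ChatCompletionSampler`, assuming the package is importable. The `EchoSampler` class is hypothetical and only for illustration.

```python
# Illustrative only: a toy SamplerBase subclass that echoes the last user message.
from opencompass.datasets.healthbench.types import (MessageList, SamplerBase,
                                                    SamplerResponse)


class EchoSampler(SamplerBase):
    """Toy sampler returning the last user message as the 'model' response."""

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        last_user = next(m['content'] for m in reversed(message_list)
                         if m['role'] == 'user')
        return SamplerResponse(
            response_text=last_user,
            actual_queried_message_list=message_list,
            response_metadata={'usage': None},
        )
```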