mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

fix irrelevant files

parent 7278a4ed19
commit 95d8d2ba4d
@@ -1,88 +0,0 @@
"""
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
https://arxiv.org/abs/2311.12022
"""

import random
import re

import pandas

from . import common
from .common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA,
                     format_multichoice_question)
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult


class GPQAEval(Eval):

    def __init__(
        self,
        n_repeats: int = 4,
        variant: str = 'diamond',
        num_examples: int | None = None,  # restrict to a subset of the data for debugging
    ):
        df = pandas.read_csv(
            f'https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv'
        )
        examples = [row.to_dict() for _, row in df.iterrows()]
        rng = random.Random(0)
        if num_examples:
            assert n_repeats == 1, 'n_repeats only supported for num_examples = None'
            examples = rng.sample(examples, num_examples)
        examples = examples * n_repeats
        examples = [
            example | {
                'permutation': rng.sample(range(4), 4)
            } for example in examples
        ]
        self.examples = examples
        self.n_repeats = n_repeats

    def __call__(self, sampler: SamplerBase) -> EvalResult:

        def fn(row: dict):
            choices = [
                row['Correct Answer'],
                row['Incorrect Answer 1'],
                row['Incorrect Answer 2'],
                row['Incorrect Answer 3'],
            ]
            choices = [choices[i] for i in row['permutation']]
            correct_index = choices.index(row['Correct Answer'])
            correct_answer = 'ABCD'[correct_index]
            choices_dict = dict(A=choices[0],
                                B=choices[1],
                                C=choices[2],
                                D=choices[3],
                                Question=row['Question'])
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict),
                    role='user')
            ]
            sampler_response = sampler(prompt_messages)
            response_text = sampler_response.response_text
            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == correct_answer else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=actual_queried_prompt_messages,
                next_message=dict(content=response_text, role='assistant'),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = actual_queried_prompt_messages + [
                dict(content=response_text, role='assistant')
            ]
            return SingleEvalResult(html=html,
                                    score=score,
                                    convo=convo,
                                    metrics={'chars': len(response_text)})

        results = common.map_with_progress(fn, self.examples)
        return common.aggregate_results(results)
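For orientation, a minimal sketch of how an eval like the one above is driven. The sampler name is hypothetical (any SamplerBase implementation from types.py works), and the snippet assumes network access to the public GPQA CSV; it is an illustration, not part of the commit.

# Hypothetical driver sketch; `my_sampler` stands for any SamplerBase instance.
eval_obj = GPQAEval(n_repeats=1, variant='diamond', num_examples=8)
result = eval_obj(my_sampler)        # returns an EvalResult
print(result.score, result.metrics)  # aggregated via common.aggregate_results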
@@ -1,16 +1,14 @@
import json
import re

from datasets import Dataset, load_dataset
from datasets import load_dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_logger
from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset
from . import common
from .healthbench_eval import HealthBenchEval, RubricItem
from .healthbench_meta_eval import HealthBenchMetaEval
from .healthbench_eval import RubricItem
from .sampler.chat_completion_sampler import ChatCompletionSampler
from .types import SingleEvalResult
@@ -1,32 +0,0 @@
from .healthbench_eval import RubricItem, calculate_score


def test_calculate_score():
    rubric_items = [
        RubricItem(criterion='test', points=7, tags=[]),
        RubricItem(criterion='test', points=5, tags=[]),
        RubricItem(criterion='test', points=10, tags=[]),
        RubricItem(criterion='test', points=-6, tags=[]),
    ]
    grading_response_list = [
        {
            'criteria_met': True
        },
        {
            'criteria_met': False
        },
        {
            'criteria_met': True
        },
        {
            'criteria_met': True
        },
    ]
    total_possible = 7 + 5 + 10
    achieved = 7 + 0 + 10 - 6
    assert (calculate_score(rubric_items, grading_response_list) == achieved /
            total_possible)


if __name__ == '__main__':
    test_calculate_score()
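The test above pins down the rubric-scoring rule: the denominator counts only positively weighted rubric items, while the numerator adds the points (positive or negative) of every item whose criterion was met. A minimal sketch of that rule, written only to illustrate what the test asserts; the real calculate_score in healthbench_eval.py is not shown in this diff and may handle additional cases.

# Illustration only: the scoring rule implied by test_calculate_score.
def calculate_score_sketch(rubric_items, grading_response_list) -> float:
    # Denominator: sum of points over positively weighted rubric items only.
    total_possible = sum(item.points for item in rubric_items if item.points > 0)
    # Numerator: points of every item (positive or negative) whose criterion was met.
    achieved = sum(
        item.points
        for item, response in zip(rubric_items, grading_response_list)
        if response['criteria_met'])
    return achieved / total_possible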
@@ -1,262 +0,0 @@
import re

from datasets import Dataset, load_dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_logger

from ..base import BaseDataset
from .healthbench_eval import HealthBenchEval, RubricItem
from .healthbench_meta_eval import HealthBenchMetaEval


def _parse(item):
    item['rubrics'] = [RubricItem.from_dict(d) for d in item['rubrics']]
    return item


def _parse_meta(item):
    item['rubrics'] = [RubricItem.from_dict(d) for d in item['rubrics']]
    return item


@LOAD_DATASET.register_module()
class HealthBenchDataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        subset = kwargs.get('subset')
        # nrepeats=1
        # nthreads = 1
        match subset:
            case 'healthbench':
                data_files = {'test': '2025-05-07-06-14-12_oss_eval.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name=None,
                )
            case 'healthbench_hard':
                data_files = {'test': 'hard_2025-05-08-21-00-10.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name='hard',
                )
            case 'healthbench_consensus':
                data_files = {'test': 'consensus_2025-05-09-20-00-46.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name='consensus',
                )
            case 'healthbench_meta':
                data_files = {'test': '2025-05-07-06-14-12_oss_meta_eval.jsonl'}
                return HealthBenchMetaEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                )
            case _:
                raise Exception(f'Unrecognized eval type: {eval_name}')

        dataset = load_dataset(path, data_files=data_files, split='test')

        dataset = dataset.map(lambda item: _parse(item, prompt_mode))

        return dataset


class HealthBenchEvaluator(BaseEvaluator):

    def score(self, predictions, references, test_set):
        method = test_set['prompt_mode'][0]

        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}
        correct = 0
        count = 0
        details = []
        for idx, (i, j) in enumerate(zip(predictions, references)):
            i = answer_cleansing(method, i, test_set['options'][idx],
                                 test_set['label'][idx])
            detail = {'pred': i, 'answer': j, 'correct': False}
            count += 1
            if i == j:
                correct += 1
                detail['correct'] = True
            details.append(detail)
        result = {'accuracy': 100 * correct / count, 'details': details}
        return result


@TEXT_POSTPROCESSORS.register_module()
def answer_cleansing(
    method: str,
    prediction: str,
    options: list,
    label: str,
) -> str:

    # Clean up unwanted phrases in the prediction
    for unwanted_phrase in [
            'I understand',
            'A through J',
            'A through E',
            'A through D',
    ]:
        prediction = prediction.replace(unwanted_phrase, '')

    options_num = len(options)
    options = [chr(65 + i) for i in range(options_num)]
    options_str = r'\b(' + '|'.join(options) + r')\b'
    prediction = re.findall(options_str, prediction)

    if len(prediction) == 0:
        prediction = []
    else:
        # If there is a "label" and its length is 1,
        # process prediction accordingly
        if len(label) == 1:
            if method == 'few-shot':
                answer_flag = True if len(prediction) > 1 else False
                # choose the first or last element based on the answer_flag
                if answer_flag:
                    prediction = [prediction[0]]
                else:
                    prediction = [prediction[-1]]
            elif method == 'zero-shot':
                # choose the first element in list
                prediction = [prediction[0]]
            else:
                raise ValueError('Method is not properly defined ...')

        # Remove trailing period if it exists
        if prediction[0] and prediction[0].endswith('.'):
            prediction[0] = prediction[0][:-1]

    return prediction[0]


def _generic_llmjudge_postprocess(judgement: str):
    match = re.search(r'(A|B)', judgement)
    grade_letter = (match.group(0) if match else 'B'
                    )  # Default to "INCORRECT" if no match
    return grade_letter


def HealthBench_llmjudge_postprocess(
    output: dict,
    output_path: str,
    dataset: Dataset,
) -> dict:
    # Get the original dataset
    original_dataset = dataset.reader.dataset['test']

    judged_answers = []
    original_responses = []
    references = []
    details = []

    # Initialize statistics dictionaries
    stats = {'medical_task': {}, 'body_system': {}, 'question_type': {}}

    total_correct = 0
    total_count = 0

    # Process each sample
    for k, v in output.items():
        idx = int(k)  # Convert key to integer for indexing
        original_responses.append(v['prediction'])
        processed_judge = _generic_llmjudge_postprocess(v['prediction'])

        # Get category information from the dataset
        sample = original_dataset[idx]
        medical_task = sample.get('medical_task', 'unknown')
        body_system = sample.get('body_system', 'unknown')
        question_type = sample.get('question_type', 'unknown')

        # Initialize category stats if not exists
        for level, key in [
                ('medical_task', medical_task),
                ('body_system', body_system),
                ('question_type', question_type),
        ]:
            if key not in stats[level]:
                stats[level][key] = {'correct': 0, 'total': 0}

        # Record the judgment
        if processed_judge is not None:
            judged_answers.append(processed_judge)
            try:
                gold = v['gold']
                references.append(gold)
            except KeyError:
                get_logger().warning(
                    f'No gold answer for {k}, use empty string as reference!')
                gold = ''
                references.append('')

            # Check if the answer is correct (A means correct)
            is_correct = processed_judge == 'A'
            total_count += 1

            if is_correct:
                total_correct += 1
                # Update category stats
                for level, key in [
                        ('medical_task', medical_task),
                        ('body_system', body_system),
                        ('question_type', question_type),
                ]:
                    stats[level][key]['correct'] += 1

            # Update category totals
            for level, key in [
                    ('medical_task', medical_task),
                    ('body_system', body_system),
                    ('question_type', question_type),
            ]:
                stats[level][key]['total'] += 1
        # Add to details
        details.append({
            'id': k,
            'question': sample['question'],
            'options': sample['options'],
            'origin_prompt': v['origin_prompt'],
            'llm_judge': processed_judge,
            'gold': gold,
            'is_correct': is_correct,
            'medical_task': medical_task,
            'body_system': body_system,
            'question_type': question_type,
        })

    # Calculate overall accuracy with two decimal places
    overall_accuracy = (round(
        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)

    # Initialize results dictionary
    results = {
        'accuracy': overall_accuracy,
        'total_correct': total_correct,
        'total_count': total_count,
        'details': details,
    }

    # Calculate accuracy for each category and flatten into results
    for level in stats:
        for key, value in stats[level].items():
            if value['total'] > 0:
                # Calculate accuracy with two decimal places
                accuracy = round((value['correct'] / value['total'] * 100), 2)

                # Create a flattened key for the category
                flat_key = f'HealthBench-{key}'

                # Add to results
                results[flat_key] = accuracy

    return results
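A quick illustration of what answer_cleansing above does in zero-shot mode with a hypothetical model response: it rebuilds the option letters from len(options), regexes out every standalone letter, keeps the first match, and strips a trailing period.

# Hypothetical inputs, illustrating answer_cleansing above.
pred = 'Considering the symptoms, the answer is B. Aspirin.'
cleaned = answer_cleansing(
    method='zero-shot',
    prediction=pred,
    options=['A', 'B', 'C', 'D'],  # only len(options) is used here
    label='B',                     # single-character label triggers the cleanup path
)
assert cleaned == 'B'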
@@ -1,339 +0,0 @@
"""This script evaluates a grader model on grading HealthBench rubrics. It
effectively evaluates the evaluator against physician opinion, so we call it a
meta-evaluation.

To run, use the following command (working directory should contain the
simple-evals folder):
`python -m simple-evals.simple_evals --eval=healthbench_meta --model=gpt-4.1`
"""

import json
import random
from collections import defaultdict
from typing import Literal

import blobfile as bf

from . import common
from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

INPUT_PATH = 'https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl'
INDEX_STR_TEMPLATE = 'pairwise_{model_or_physician}_{metric}_{pred_str}'
CLUSTER_STR_TEMPLATE = '{cluster}: {index_str}'

HEALTHBENCH_META_HTML_JINJA = (common.HTML_JINJA.replace(
    '<p>Correct Answer: {{ correct_answer }}</p>\n',
    '',
) + "<p>Explanation for grader's label: {{ explanation }}</p>")


class HealthBenchMetaEval(Eval):

    def __init__(
        self,
        grader_model: SamplerBase,
        num_examples: int | None = None,
        n_threads: int = 120,
        n_repeats: int = 1,
    ):
        with bf.BlobFile(INPUT_PATH, 'rb') as f:
            examples = [json.loads(line) for line in f]
        print(f'Loaded {len(examples)} examples from {INPUT_PATH}')

        rng = random.Random(0)

        if num_examples is not None and len(examples) > num_examples:
            examples = rng.sample(examples, num_examples)

        self.examples = examples * n_repeats
        self.grader_model = grader_model
        self.n_threads = n_threads

    def grade_sample(
        self,
        grading_response_dict: dict,
        physician_labels: list[bool],
        category: str,
    ) -> tuple[dict, bool | None, str]:
        metrics = {
            'num_physician_labels': len(physician_labels),
            'percent_physician_pos':
            sum(physician_labels) / len(physician_labels),
        }

        grader_label = grading_response_dict['criteria_met']
        assert grader_label is True or grader_label is False
        metrics['model_predicted_positive'] = grader_label
        explanation = grading_response_dict.get('explanation',
                                                'No explanation provided')

        category_metrics = {f'{category}: {k}': v for k, v in metrics.items()}
        metrics = {**metrics, **category_metrics}
        return metrics, grader_label, explanation

    def __call__(self, sampler: SamplerBase) -> EvalResult:

        def fn(row: dict) -> tuple[SingleEvalResult, bool | None]:
            convo_with_response = row['prompt'] + [
                dict(content=row['completion'], role='assistant')
            ]
            prompt_str = '\n\n'.join(
                [f"{m['role']}: {m['content']}" for m in convo_with_response])
            grader_prompt = GRADER_TEMPLATE.replace('<<conversation>>',
                                                    prompt_str)
            grader_prompt = grader_prompt.replace('<<rubric_item>>',
                                                  row['rubric'])
            grader_convo = [dict(content=grader_prompt, role='user')]

            while True:
                sampler_response = sampler(grader_convo)
                response_text = sampler_response.response_text
                actual_queried_grader_convo = (
                    sampler_response.actual_queried_message_list)
                grading_response_dict = parse_json_to_dict(response_text)
                if 'criteria_met' in grading_response_dict:
                    label = grading_response_dict['criteria_met']
                    if label is True or label is False:
                        break
                print('Grading failed due to bad JSON output, retrying...')

            metrics, grader_label, explanation = self.grade_sample(
                grading_response_dict=grading_response_dict,
                physician_labels=row['binary_labels'],
                category=row['category'],
            )
            score = metrics['model_predicted_positive']

            # Create HTML for each sample result
            html = common.jinja_env.from_string(
                HEALTHBENCH_META_HTML_JINJA).render(
                    prompt_messages=actual_queried_grader_convo,
                    next_message=dict(content=response_text, role='assistant'),
                    score=metrics['model_predicted_positive'],
                    extracted_answer=response_text,
                    explanation=explanation,
                )
            convo = actual_queried_grader_convo + [
                dict(content=response_text, role='assistant')
            ]
            return (
                SingleEvalResult(html=html,
                                 score=score,
                                 convo=convo,
                                 metrics=metrics),
                grader_label,
            )

        # Run evaluation and collect results
        all_outputs = common.map_with_progress(fn, self.examples,
                                               self.n_threads)
        results: list[SingleEvalResult]
        grader_labels: list[bool]
        results, grader_labels = zip(*all_outputs)

        # model pairwise agreement metrics
        model_agreement_metrics = compute_metrics_for_rater_by_class(
            self_pred_list=grader_labels,
            other_preds_list=[x['binary_labels'] for x in self.examples],
            cluster_list=[x['category'] for x in self.examples],
            model_or_physician='model',
        )

        # physicians:
        physician_rating_lists = defaultdict(lambda: ([], [], []))
        for example in self.examples:
            for i in range(len(example['binary_labels'])):
                physician_id = example['anonymized_physician_ids'][i]
                self_pred = example['binary_labels'][i]
                other_preds = (example['binary_labels'][:i] +
                               example['binary_labels'][i + 1:])
                cluster = example['category']
                physician_rating_lists[physician_id][0].append(self_pred)
                physician_rating_lists[physician_id][1].append(other_preds)
                physician_rating_lists[physician_id][2].append(cluster)

        physician_agreement_metric_lists = defaultdict(dict)
        for physician_id, (
                physician_rating_list,
                other_preds_list,
                cluster_list,
        ) in physician_rating_lists.items():
            physician_agreement_metrics = compute_metrics_for_rater_by_class(
                self_pred_list=physician_rating_list,
                other_preds_list=other_preds_list,
                cluster_list=cluster_list,
                model_or_physician='physician',
            )
            for k, v in physician_agreement_metrics.items():
                physician_agreement_metric_lists[k][physician_id] = v

        # consolidate final metrics and add agreement metrics
        final_metrics = common.aggregate_results(
            results, default_stats=('mean', 'n_samples', 'bootstrap_std'))
        model_agreement_metrics_condensed: dict[str, float] = {
            k: v['value']
            for k, v in model_agreement_metrics.items()
            if v['value'] is not None
        }
        assert final_metrics.metrics is not None
        final_metrics.metrics.update(model_agreement_metrics_condensed)
        final_metrics.score = final_metrics.metrics[
            'pairwise_model_f1_balanced']

        final_metrics.metadata = {
            'model_agreement_metrics': model_agreement_metrics,
            'physician_agreement_metric_lists':
            physician_agreement_metric_lists,
        }
        return final_metrics


def compute_metrics_for_rater_by_class(
    self_pred_list: list[bool],
    other_preds_list: list[list[bool]],
    cluster_list: list[str],
    model_or_physician: Literal['model', 'physician'],
) -> dict[str, dict[str, float | None]]:
    # get all the metrics for each cluster
    metric_lists = defaultdict(list)
    for self_pred, other_preds, cluster in zip(self_pred_list,
                                               other_preds_list,
                                               cluster_list,
                                               strict=True):
        self_pred_str = 'pos' if self_pred else 'neg'
        for other_pred in other_preds:
            # precision. based on the grader's labels -
            # i.e., calculated as TP / (TP + FP)
            # so a prediction should be recorded whenever self_pred is True
            precision_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='precision',
                pred_str=self_pred_str,
            )
            metric_lists[precision_index_str].append(self_pred == other_pred)
            precision_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=precision_index_str)
            metric_lists[precision_cluster_str].append(self_pred == other_pred)

            # recall. based on the ground truth labels -
            # i.e., calculated as TP / (TP + FN)
            # so a prediction should be recorded whenever other_pred is True
            other_pred_str = 'pos' if other_pred else 'neg'
            recall_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='recall',
                pred_str=other_pred_str,
            )
            metric_lists[recall_index_str].append(self_pred == other_pred)
            recall_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=recall_index_str)
            metric_lists[recall_cluster_str].append(self_pred == other_pred)

    metrics: dict[str, dict[str, float | None]] = {}
    for index_str, metric_list in metric_lists.items():
        n = len(metric_list)
        metric = sum(metric_list) / n if n > 0 else None
        metrics[index_str] = {
            'n': n,
            'value': metric,
        }

    f1_metrics = get_f1_metrics(metrics)
    metrics.update(f1_metrics)

    balanced_metrics = get_balanced_metrics(metrics)
    metrics.update(balanced_metrics)

    return metrics


def get_f1_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    f1_metrics: dict[str, dict[str, float | None]] = {}
    for precision_key_name in metrics:
        if 'precision' in precision_key_name:
            recall_key_name = precision_key_name.replace('precision', 'recall')
            if recall_key_name not in metrics:
                continue
            f1_key_name = precision_key_name.replace('precision', 'f1')
            assert f1_key_name not in metrics
            f1_metrics[f1_key_name] = compute_f1_metric(
                precision=metrics[precision_key_name],
                recall=metrics[recall_key_name],
            )

    return f1_metrics


def compute_f1_metric(
    precision: dict[str, float | None],
    recall: dict[str, float | None],
) -> dict[str, float | None]:
    precision_n = precision['n']
    recall_n = recall['n']
    assert precision_n is not None and recall_n is not None, 'n_pos or n_neg is None'

    precision_metric = precision['value']
    recall_metric = recall['value']
    if precision_metric is None or recall_metric is None:
        f1_metric = None
        n_f1 = (
            precision_n + recall_n
        )  # precision_metric is None iff precision_n = 0 and recall_metric is None iff recall_n = 0, so if either is zero this gives TP + FN + FP without double counting
    elif precision_metric == 0 and recall_metric == 0:
        f1_metric = 0.0
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP
    else:
        f1_metric = (2 * (precision_metric * recall_metric) /
                     (precision_metric + recall_metric))
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP

    return {
        'n': n_f1,
        'value': f1_metric,
    }


def get_balanced_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    balanced_metrics: dict[str, dict[str, float | None]] = {}
    for pos_key_name in metrics:
        if 'pos' in pos_key_name:
            neg_key_name = pos_key_name.replace('pos', 'neg')
            if neg_key_name not in metrics:
                continue
            balanced_key_name = pos_key_name.replace('pos', 'balanced')
            assert balanced_key_name not in metrics
            balanced_metrics[balanced_key_name] = compute_balanced_metric(
                metric_pos=metrics[pos_key_name],
                metric_neg=metrics[neg_key_name],
            )

    return balanced_metrics


def compute_balanced_metric(
    metric_pos: dict[str, float | None],
    metric_neg: dict[str, float | None],
) -> dict[str, float | None]:
    n_pos = metric_pos['n']
    n_neg = metric_neg['n']
    assert n_pos is not None and n_neg is not None, 'n_pos or n_neg is None'

    pos_metric = metric_pos['value']
    neg_metric = metric_neg['value']
    if pos_metric is None or neg_metric is None:
        metric = None
    else:
        metric = (pos_metric + neg_metric) / 2

    return {
        'n': n_pos + n_neg,
        # note: this overcounts samples going towards the balanced F1
        'value': metric,
    }
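To make the n_f1 bookkeeping in compute_f1_metric above concrete, a small worked example with hypothetical numbers: precision 2/3 over 3 grader-positive pairs and recall 2/4 over 4 truly-positive pairs implies TP = precision * precision_n = 2, so n_f1 = 3 + 4 - 2 = 5 distinct pairs and F1 = 2PR / (P + R) = 4/7.

# Hypothetical worked example for compute_f1_metric above.
precision = {'n': 3, 'value': 2 / 3}  # TP + FP = 3, with TP = 2 and FP = 1
recall = {'n': 4, 'value': 2 / 4}     # TP + FN = 4, with TP = 2 and FN = 2
out = compute_f1_metric(precision=precision, recall=recall)
assert abs(out['value'] - 4 / 7) < 1e-9  # 2 * (2/3) * (1/2) / (2/3 + 1/2)
assert abs(out['n'] - 5) < 1e-9          # TP + FP + FN, TP counted once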
@@ -1,165 +0,0 @@
from . import healthbench_meta_eval


def test_compute_agreement_for_rater_by_class():
    self_pred_list = [True, False, True]
    other_preds_list = [[True, True, False], [True, False], [False]]
    cluster_list = ['a', 'a', 'b']
    model_or_physician = 'model'
    metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class(
        self_pred_list, other_preds_list, cluster_list, model_or_physician
    )

    # precision overall
    index_str_pos_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='pos'
    )
    index_str_neg_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='neg'
    )
    overall_pos_precision = metrics[index_str_pos_precision]
    overall_neg_precision = metrics[index_str_neg_precision]
    expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1)
    expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0)
    assert overall_pos_precision['value'] == expected_overall_pos_precision
    assert overall_neg_precision['value'] == expected_overall_neg_precision
    assert overall_pos_precision['n'] == 4
    assert overall_neg_precision['n'] == 2

    # recall overall
    index_str_pos_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='pos'
    )
    index_str_neg_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='neg'
    )
    overall_pos_recall = metrics[index_str_pos_recall]
    overall_neg_recall = metrics[index_str_neg_recall]
    expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0)
    expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1)
    assert overall_pos_recall['value'] == expected_overall_pos_recall
    assert overall_neg_recall['value'] == expected_overall_neg_recall
    assert overall_pos_recall['n'] == 3
    assert overall_neg_recall['n'] == 3

    # f1 overall
    index_str_pos_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='pos'
    )
    index_str_neg_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='neg'
    )
    overall_pos_f1 = metrics[index_str_pos_f1]
    overall_neg_f1 = metrics[index_str_neg_f1]
    expected_overall_pos_f1 = (
        2
        * expected_overall_pos_precision
        * expected_overall_pos_recall
        / (expected_overall_pos_precision + expected_overall_pos_recall)
    )
    expected_overall_neg_f1 = (
        2
        * expected_overall_neg_precision
        * expected_overall_neg_recall
        / (expected_overall_neg_precision + expected_overall_neg_recall)
    )
    assert overall_pos_f1['value'] == expected_overall_pos_f1
    assert overall_neg_f1['value'] == expected_overall_neg_f1

    # balanced f1
    index_str_balanced_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='balanced'
    )
    balanced_f1 = metrics[index_str_balanced_f1]
    expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2
    assert balanced_f1['value'] == expected_balanced_f1

    # by cluster
    # precision
    cluster_a_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_precision
    )
    cluster_a_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_precision
    )
    cluster_a_pos_precision = metrics[cluster_a_str_pos_precision]
    cluster_a_neg_precision = metrics[cluster_a_str_neg_precision]
    assert cluster_a_pos_precision['value'] == (
        # example 1, 2 in order
        (2 + 0) / (3 + 0)
    )
    assert cluster_a_neg_precision['value'] == (
        # example 1, 2 in order
        (0 + 1) / (0 + 2)
    )
    assert cluster_a_pos_precision['n'] == 3
    assert cluster_a_neg_precision['n'] == 2

    # recall
    cluster_a_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_recall
    )
    cluster_a_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_recall
    )
    cluster_a_pos_recall = metrics[cluster_a_str_pos_recall]
    cluster_a_neg_recall = metrics[cluster_a_str_neg_recall]
    assert cluster_a_pos_recall['value'] == (
        # example 1, 2 in order
        (2 + 0) / (2 + 1)
    )
    assert cluster_a_neg_recall['value'] == (
        # example 1, 2 in order
        (0 + 1) / (1 + 1)
    )
    assert cluster_a_pos_recall['n'] == 3
    assert cluster_a_neg_recall['n'] == 2

    # cluster B
    # precision
    cluster_b_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_precision
    )
    cluster_b_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_precision
    )
    cluster_b_str_pos_precision = metrics[cluster_b_str_pos_precision]
    assert cluster_b_str_neg_precision not in metrics
    assert cluster_b_str_pos_precision['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_str_pos_precision['n'] == 1

    # recall
    cluster_b_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_recall
    )
    cluster_b_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_recall
    )
    assert cluster_b_str_pos_recall not in metrics
    cluster_b_neg_recall = metrics[cluster_b_str_neg_recall]
    assert cluster_b_neg_recall['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_neg_recall['n'] == 1

    # f1
    index_str_pos_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_f1
    )
    index_str_neg_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_f1
    )
    index_str_balanced_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_balanced_f1
    )
    assert index_str_pos_f1 not in metrics
    assert index_str_neg_f1 not in metrics
    assert index_str_balanced_f1 not in metrics


if __name__ == '__main__':
    test_compute_agreement_for_rater_by_class()
@@ -1,103 +0,0 @@
import os
import time

import anthropic

from .. import common
from ..types import MessageList, SamplerBase, SamplerResponse

CLAUDE_SYSTEM_MESSAGE_LMSYS = (
    'The assistant is Claude, created by Anthropic. The current date is '
    "{currentDateTime}. Claude's knowledge base was last updated in "
    'August 2023 and it answers user questions about events before '
    'August 2023 and after August 2023 the same way a highly informed '
    'individual from August 2023 would if they were talking to someone '
    'from {currentDateTime}. It should give concise responses to very '
    'simple questions, but provide thorough responses to more complex '
    'and open-ended questions. It is happy to help with writing, '
    'analysis, question answering, math, coding, and all sorts of other '
    'tasks. It uses markdown for coding. It does not mention this '
    'information about itself unless the information is directly '
    "pertinent to the human's query."
).format(currentDateTime='2024-04-01')
# reference: https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894


class ClaudeCompletionSampler(SamplerBase):

    def __init__(
        self,
        model: str,
        system_message: str | None = None,
        temperature: float = 0.0,  # default in Anthropic example
        max_tokens: int = 4096,
    ):
        self.client = anthropic.Anthropic()
        self.api_key = os.environ.get('ANTHROPIC_API_KEY')  # please set your API_KEY
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = 'base64'

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ):
        new_image = {
            'type': 'image',
            'source': {
                'type': encoding,
                'media_type': f'image/{format}',
                'data': image,
            },
        }
        return new_image

    def _handle_text(self, text):
        return {'type': 'text', 'text': text}

    def _pack_message(self, role, content):
        return {'role': str(role), 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        trial = 0
        while True:
            try:
                if not common.has_only_user_assistant_messages(message_list):
                    raise ValueError(f'Claude sampler only supports user and assistant messages, got {message_list}')
                if self.system_message:
                    response_message = self.client.messages.create(
                        model=self.model,
                        system=self.system_message,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        messages=message_list,
                    )
                    claude_input_messages: MessageList = [{'role': 'system', 'content': self.system_message}] + message_list
                else:
                    response_message = self.client.messages.create(
                        model=self.model,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        messages=message_list,
                    )
                    claude_input_messages = message_list
                response_text = response_message.content[0].text
                return SamplerResponse(
                    response_text=response_text,
                    response_metadata={},
                    actual_queried_message_list=claude_input_messages,
                )
            except anthropic.RateLimitError as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
@@ -1,78 +0,0 @@
import time
from typing import Any

import openai
from openai import OpenAI

from ..types import MessageList, SamplerBase, SamplerResponse


class OChatCompletionSampler(SamplerBase):
    """Sample from OpenAI's chat completion API for o series models."""

    def __init__(
        self,
        *,
        reasoning_effort: str | None = None,
        model: str = 'o1-mini',
    ):
        self.api_key_name = 'OPENAI_API_KEY'
        self.client = OpenAI()
        # using api_key=os.environ.get("OPENAI_API_KEY")  # please set your API_KEY
        self.model = model
        self.image_format = 'url'
        self.reasoning_effort = reasoning_effort

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ):
        new_image = {
            'type': 'image_url',
            'image_url': {
                'url': f'data:image/{format};{encoding},{image}',
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {'type': 'text', 'text': text}

    def _pack_message(self, role: str, content: Any):
        return {'role': str(role), 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    reasoning_effort=self.reasoning_effort,
                )
                content = response.choices[0].message.content
                return SamplerResponse(
                    response_text=content,
                    response_metadata={'usage': response.usage},
                    actual_queried_message_list=message_list,
                )
            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print('Bad Request Error', e)
                return SamplerResponse(
                    response_text='',
                    response_metadata={'usage': None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
@@ -1,97 +0,0 @@
import os
import time
from typing import Any

import openai
from openai import OpenAI

from ..types import MessageList, SamplerBase, SamplerResponse


class ResponsesSampler(SamplerBase):
    """Sample from OpenAI's responses API."""

    def __init__(
        self,
        model: str = 'gpt-4.1',
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        reasoning_model: bool = False,
        reasoning_effort: str | None = None,
    ):
        self.api_key_name = 'OPENAI_API_KEY'
        assert os.environ.get('OPENAI_API_KEY'), 'Please set OPENAI_API_KEY'
        self.client = OpenAI()
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = 'url'
        self.reasoning_model = reasoning_model
        self.reasoning_effort = reasoning_effort

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ) -> dict[str, Any]:
        new_image = {
            'type': 'input_image',
            'image_url': f'data:image/{format};{encoding},{image}',
        }
        return new_image

    def _handle_text(self, text: str) -> dict[str, Any]:
        return {'type': 'input_text', 'text': text}

    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
        return {'role': role, 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        if self.system_message:
            message_list = [
                self._pack_message('developer', self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                if self.reasoning_model:
                    reasoning = ({
                        'effort': self.reasoning_effort
                    } if self.reasoning_effort else None)
                    response = self.client.responses.create(
                        model=self.model,
                        input=message_list,
                        reasoning=reasoning,
                    )
                else:
                    response = self.client.responses.create(
                        model=self.model,
                        input=message_list,
                        temperature=self.temperature,
                        max_output_tokens=self.max_tokens,
                    )
                return SamplerResponse(
                    response_text=response.output_text,
                    response_metadata={'usage': response.usage},
                    actual_queried_message_list=message_list,
                )
            except openai.BadRequestError as e:
                print('Bad Request Error', e)
                return SamplerResponse(
                    response_text='',
                    response_metadata={'usage': None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
opencompass/datasets/healthbench/types.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from dataclasses import dataclass, field
from typing import Any, Literal, overload

Message = dict[str, Any]  # keys role, content
MessageList = list[Message]


@dataclass
class SamplerResponse:
    """Response from a sampler."""
    response_text: str
    actual_queried_message_list: MessageList
    response_metadata: dict[str, Any]


class SamplerBase:
    """Base class for defining a sampling model, which can be evaluated, or
    used as part of the grading process."""

    def __call__(
        self,
        message_list: MessageList,
    ) -> SamplerResponse:
        raise NotImplementedError


@dataclass
class EvalResult:
    """Result of running an evaluation (usually consisting of many samples)"""

    score: float | None  # top-line metric
    metrics: dict[str, float] | None  # other metrics
    htmls: list[str]  # strings of valid HTML
    convos: list[MessageList]  # sampled conversations
    metadata: dict[str, Any] | None  # Extra data such as rubric scores or sollen


@dataclass
class SingleEvalResult:
    """Result of evaluating a single sample."""

    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None  # sampled conversation
    example_level_metadata: dict[str, Any] | None = (
        None  # Extra data such as rubric scores or sollen
    )


class Eval:
    """Base class for defining an evaluation."""

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError
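As a quick illustration of how these types fit together: a SamplerBase subclass only has to return a SamplerResponse, which is what the eval classes above consume. The echo behaviour below is hypothetical and exists only to show the shape of the interface, not to call a real model.

# Hypothetical, minimal sampler showing how the types above compose.
class EchoSampler(SamplerBase):

    def _pack_message(self, role: str, content: str) -> Message:
        return {'role': role, 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        # Echo the last user message back instead of querying a model.
        last = message_list[-1]['content'] if message_list else ''
        return SamplerResponse(
            response_text=f'Answer: {last}',
            actual_queried_message_list=message_list,
            response_metadata={},
        )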