Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

fix irrelevant files

This commit is contained in: parent 7278a4ed19, commit 95d8d2ba4d
@@ -1,88 +0,0 @@
"""
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
https://arxiv.org/abs/2311.12022
"""

import random
import re

import pandas

from . import common
from .common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA,
                     format_multichoice_question)
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult


class GPQAEval(Eval):

    def __init__(
        self,
        n_repeats: int = 4,
        variant: str = 'diamond',
        num_examples: int | None = None,  # restrict to a subset of the data for debugging
    ):
        df = pandas.read_csv(
            f'https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv'
        )
        examples = [row.to_dict() for _, row in df.iterrows()]
        rng = random.Random(0)
        if num_examples:
            assert n_repeats == 1, 'n_repeats only supported for num_examples = None'
            examples = rng.sample(examples, num_examples)
        examples = examples * n_repeats
        examples = [
            example | {
                'permutation': rng.sample(range(4), 4)
            } for example in examples
        ]
        self.examples = examples
        self.n_repeats = n_repeats

    def __call__(self, sampler: SamplerBase) -> EvalResult:

        def fn(row: dict):
            choices = [
                row['Correct Answer'],
                row['Incorrect Answer 1'],
                row['Incorrect Answer 2'],
                row['Incorrect Answer 3'],
            ]
            choices = [choices[i] for i in row['permutation']]
            correct_index = choices.index(row['Correct Answer'])
            correct_answer = 'ABCD'[correct_index]
            choices_dict = dict(A=choices[0],
                                B=choices[1],
                                C=choices[2],
                                D=choices[3],
                                Question=row['Question'])
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict),
                    role='user')
            ]
            sampler_response = sampler(prompt_messages)
            response_text = sampler_response.response_text
            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == correct_answer else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=actual_queried_prompt_messages,
                next_message=dict(content=response_text, role='assistant'),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = actual_queried_prompt_messages + [
                dict(content=response_text, role='assistant')
            ]
            return SingleEvalResult(html=html,
                                    score=score,
                                    convo=convo,
                                    metrics={'chars': len(response_text)})

        results = common.map_with_progress(fn, self.examples)
        return common.aggregate_results(results)
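For orientation, a minimal sketch (not part of the diff) of how the removed GPQAEval would typically be driven end to end. The module path and the ChatCompletionSampler constructor arguments below are assumptions for illustration; the sampler class itself is only known here from the import list elsewhere in this commit.

# Hypothetical usage sketch: evaluate a chat-completion sampler on a small GPQA subset.
from .gpqa_eval import GPQAEval                              # assumed module path
from .sampler.chat_completion_sampler import ChatCompletionSampler

sampler = ChatCompletionSampler(model='gpt-4.1')             # constructor args are an assumption
gpqa = GPQAEval(n_repeats=1, variant='diamond', num_examples=8)
result = gpqa(sampler)                                       # returns an EvalResult
print(result.score, result.metrics)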
@@ -1,16 +1,14 @@
 import json
 import re
 
-from datasets import Dataset, load_dataset
+from datasets import load_dataset
 
 from opencompass.openicl import BaseEvaluator
-from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
-from opencompass.utils import get_logger
+from opencompass.registry import LOAD_DATASET
 
 from ..base import BaseDataset
 from . import common
-from .healthbench_eval import HealthBenchEval, RubricItem
-from .healthbench_meta_eval import HealthBenchMetaEval
+from .healthbench_eval import RubricItem
 from .sampler.chat_completion_sampler import ChatCompletionSampler
 from .types import SingleEvalResult
 
@@ -1,32 +0,0 @@
from .healthbench_eval import RubricItem, calculate_score


def test_calculate_score():
    rubric_items = [
        RubricItem(criterion='test', points=7, tags=[]),
        RubricItem(criterion='test', points=5, tags=[]),
        RubricItem(criterion='test', points=10, tags=[]),
        RubricItem(criterion='test', points=-6, tags=[]),
    ]
    grading_response_list = [
        {
            'criteria_met': True
        },
        {
            'criteria_met': False
        },
        {
            'criteria_met': True
        },
        {
            'criteria_met': True
        },
    ]
    total_possible = 7 + 5 + 10
    achieved = 7 + 0 + 10 - 6
    assert (calculate_score(rubric_items, grading_response_list) == achieved /
            total_possible)


if __name__ == '__main__':
    test_calculate_score()
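Read alongside the assertions above, the removed test pins down the scoring rule: every rubric item whose criteria were met contributes its points (negative ones included), while only positively weighted items define the maximum attainable. A small standalone check of that arithmetic, under this reading of calculate_score (not the original implementation):

# Worked check of the test's arithmetic under the assumed scoring rule.
rubric_points = [7, 5, 10, -6]
criteria_met = [True, False, True, True]

total_possible = sum(p for p in rubric_points if p > 0)                   # 22
achieved = sum(p for p, met in zip(rubric_points, criteria_met) if met)   # 7 + 10 - 6 = 11
print(achieved / total_possible)                                          # 0.5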
@@ -1,262 +0,0 @@
import re

from datasets import Dataset, load_dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_logger

from ..base import BaseDataset
from .healthbench_eval import HealthBenchEval, RubricItem
from .healthbench_meta_eval import HealthBenchMetaEval


def _parse(item):
    item['rubrics'] = [RubricItem.from_dict(d) for d in item['rubrics']]
    return item


def _parse_meta(item):
    item['rubrics'] = [RubricItem.from_dict(d) for d in item['rubrics']]
    return item


@LOAD_DATASET.register_module()
class HealthBenchDataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        subset = kwargs.get('subset')
        # nrepeats=1
        # nthreads = 1
        match subset:
            case 'healthbench':
                data_files = {'test': '2025-05-07-06-14-12_oss_eval.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name=None,
                )
            case 'healthbench_hard':
                data_files = {'test': 'hard_2025-05-08-21-00-10.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name='hard',
                )
            case 'healthbench_consensus':
                data_files = {'test': 'consensus_2025-05-09-20-00-46.jsonl'}
                return HealthBenchEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                    subset_name='consensus',
                )
            case 'healthbench_meta':
                data_files = {'test': '2025-05-07-06-14-12_oss_meta_eval.jsonl'}
                return HealthBenchMetaEval(
                    grader_model=grading_sampler,
                    n_repeats=1,
                    n_threads=1,
                )
            case _:
                raise Exception(f'Unrecognized eval type: {eval_name}')

        dataset = load_dataset(path, data_files=data_files, split='test')

        dataset = dataset.map(lambda item: _parse(item, prompt_mode))

        return dataset


class HealthBenchEvaluator(BaseEvaluator):

    def score(self, predictions, references, test_set):
        method = test_set['prompt_mode'][0]

        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}
        correct = 0
        count = 0
        details = []
        for idx, (i, j) in enumerate(zip(predictions, references)):
            i = answer_cleansing(method, i, test_set['options'][idx],
                                 test_set['label'][idx])
            detail = {'pred': i, 'answer': j, 'correct': False}
            count += 1
            if i == j:
                correct += 1
                detail['correct'] = True
            details.append(detail)
        result = {'accuracy': 100 * correct / count, 'details': details}
        return result


@TEXT_POSTPROCESSORS.register_module()
def answer_cleansing(
    method: str,
    prediction: str,
    options: list,
    label: str,
) -> str:

    # Clean up unwanted phrases in the prediction
    for unwanted_phrase in [
            'I understand',
            'A through J',
            'A through E',
            'A through D',
    ]:
        prediction = prediction.replace(unwanted_phrase, '')

    options_num = len(options)
    options = [chr(65 + i) for i in range(options_num)]
    options_str = r'\b(' + '|'.join(options) + r')\b'
    prediction = re.findall(options_str, prediction)

    if len(prediction) == 0:
        prediction = []
    else:
        # If there is a "label" and its length is 1,
        # process prediction accordingly
        if len(label) == 1:
            if method == 'few-shot':
                answer_flag = True if len(prediction) > 1 else False
                # choose the first or last element based on the answer_flag
                if answer_flag:
                    prediction = [prediction[0]]
                else:
                    prediction = [prediction[-1]]
            elif method == 'zero-shot':
                # choose the first element in list
                prediction = [prediction[0]]
            else:
                raise ValueError('Method is not properly defined ...')

    # Remove trailing period if it exists
    if prediction[0] and prediction[0].endswith('.'):
        prediction[0] = prediction[0][:-1]

    return prediction[0]


def _generic_llmjudge_postprocess(judgement: str):
    match = re.search(r'(A|B)', judgement)
    grade_letter = (match.group(0) if match else 'B'
                    )  # Default to "INCORRECT" if no match
    return grade_letter


def HealthBench_llmjudge_postprocess(
    output: dict,
    output_path: str,
    dataset: Dataset,
) -> dict:
    # Get the original dataset
    original_dataset = dataset.reader.dataset['test']

    judged_answers = []
    original_responses = []
    references = []
    details = []

    # Initialize statistics dictionaries
    stats = {'medical_task': {}, 'body_system': {}, 'question_type': {}}

    total_correct = 0
    total_count = 0

    # Process each sample
    for k, v in output.items():
        idx = int(k)  # Convert key to integer for indexing
        original_responses.append(v['prediction'])
        processed_judge = _generic_llmjudge_postprocess(v['prediction'])

        # Get category information from the dataset
        sample = original_dataset[idx]
        medical_task = sample.get('medical_task', 'unknown')
        body_system = sample.get('body_system', 'unknown')
        question_type = sample.get('question_type', 'unknown')

        # Initialize category stats if not exists
        for level, key in [
                ('medical_task', medical_task),
                ('body_system', body_system),
                ('question_type', question_type),
        ]:
            if key not in stats[level]:
                stats[level][key] = {'correct': 0, 'total': 0}

        # Record the judgment
        if processed_judge is not None:
            judged_answers.append(processed_judge)
            try:
                gold = v['gold']
                references.append(gold)
            except KeyError:
                get_logger().warning(
                    f'No gold answer for {k}, use empty string as reference!')
                gold = ''
                references.append('')

            # Check if the answer is correct (A means correct)
            is_correct = processed_judge == 'A'
            total_count += 1

            if is_correct:
                total_correct += 1
                # Update category stats
                for level, key in [
                        ('medical_task', medical_task),
                        ('body_system', body_system),
                        ('question_type', question_type),
                ]:
                    stats[level][key]['correct'] += 1

            # Update category totals
            for level, key in [
                    ('medical_task', medical_task),
                    ('body_system', body_system),
                    ('question_type', question_type),
            ]:
                stats[level][key]['total'] += 1
            # Add to details
            details.append({
                'id': k,
                'question': sample['question'],
                'options': sample['options'],
                'origin_prompt': v['origin_prompt'],
                'llm_judge': processed_judge,
                'gold': gold,
                'is_correct': is_correct,
                'medical_task': medical_task,
                'body_system': body_system,
                'question_type': question_type,
            })

    # Calculate overall accuracy with two decimal places
    overall_accuracy = (round(
        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)

    # Initialize results dictionary
    results = {
        'accuracy': overall_accuracy,
        'total_correct': total_correct,
        'total_count': total_count,
        'details': details,
    }

    # Calculate accuracy for each category and flatten into results
    for level in stats:
        for key, value in stats[level].items():
            if value['total'] > 0:
                # Calculate accuracy with two decimal places
                accuracy = round((value['correct'] / value['total'] * 100), 2)

                # Create a flattened key for the category
                flat_key = f'HealthBench-{key}'

                # Add to results
                results[flat_key] = accuracy

    return results
@@ -1,339 +0,0 @@
"""This script evaluates a grader model on grading HealthBench rubrics. It
effectively evaluates the evaluator against physician opinion, so we call it a
meta-evaluation.

To run, use the following command (working directory should contain simple-
evals folder): `python -m simple-evals.simple_evals --eval=healthbench_meta
--model=gpt-4.1`
"""

import json
import random
from collections import defaultdict
from typing import Literal

import blobfile as bf

from . import common
from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

INPUT_PATH = 'https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl'
INDEX_STR_TEMPLATE = 'pairwise_{model_or_physician}_{metric}_{pred_str}'
CLUSTER_STR_TEMPLATE = '{cluster}: {index_str}'

HEALTHBENCH_META_HTML_JINJA = (common.HTML_JINJA.replace(
    '<p>Correct Answer: {{ correct_answer }}</p>\n',
    '',
) + "<p>Explanation for grader's label: {{ explanation }}</p>")


class HealthBenchMetaEval(Eval):

    def __init__(
        self,
        grader_model: SamplerBase,
        num_examples: int | None = None,
        n_threads: int = 120,
        n_repeats: int = 1,
    ):
        with bf.BlobFile(INPUT_PATH, 'rb') as f:
            examples = [json.loads(line) for line in f]
        print(f'Loaded {len(examples)} examples from {INPUT_PATH}')

        rng = random.Random(0)

        if num_examples is not None and len(examples) > num_examples:
            examples = rng.sample(examples, num_examples)

        self.examples = examples * n_repeats
        self.grader_model = grader_model
        self.n_threads = n_threads

    def grade_sample(
        self,
        grading_response_dict: dict,
        physician_labels: list[bool],
        category: str,
    ) -> tuple[dict, bool | None, str]:
        metrics = {
            'num_physician_labels': len(physician_labels),
            'percent_physician_pos':
            sum(physician_labels) / len(physician_labels),
        }

        grader_label = grading_response_dict['criteria_met']
        assert grader_label is True or grader_label is False
        metrics['model_predicted_positive'] = grader_label
        explanation = grading_response_dict.get('explanation',
                                                'No explanation provided')

        category_metrics = {f'{category}: {k}': v for k, v in metrics.items()}
        metrics = {**metrics, **category_metrics}
        return metrics, grader_label, explanation

    def __call__(self, sampler: SamplerBase) -> EvalResult:

        def fn(row: dict) -> tuple[SingleEvalResult, bool | None]:
            convo_with_response = row['prompt'] + [
                dict(content=row['completion'], role='assistant')
            ]
            prompt_str = '\n\n'.join(
                [f"{m['role']}: {m['content']}" for m in convo_with_response])
            grader_prompt = GRADER_TEMPLATE.replace('<<conversation>>',
                                                    prompt_str)
            grader_prompt = grader_prompt.replace('<<rubric_item>>',
                                                  row['rubric'])
            grader_convo = [dict(content=grader_prompt, role='user')]

            while True:
                sampler_response = sampler(grader_convo)
                response_text = sampler_response.response_text
                actual_queried_grader_convo = (
                    sampler_response.actual_queried_message_list)
                grading_response_dict = parse_json_to_dict(response_text)
                if 'criteria_met' in grading_response_dict:
                    label = grading_response_dict['criteria_met']
                    if label is True or label is False:
                        break
                print('Grading failed due to bad JSON output, retrying...')

            metrics, grader_label, explanation = self.grade_sample(
                grading_response_dict=grading_response_dict,
                physician_labels=row['binary_labels'],
                category=row['category'],
            )
            score = metrics['model_predicted_positive']

            # Create HTML for each sample result
            html = common.jinja_env.from_string(
                HEALTHBENCH_META_HTML_JINJA).render(
                    prompt_messages=actual_queried_grader_convo,
                    next_message=dict(content=response_text, role='assistant'),
                    score=metrics['model_predicted_positive'],
                    extracted_answer=response_text,
                    explanation=explanation,
                )
            convo = actual_queried_grader_convo + [
                dict(content=response_text, role='assistant')
            ]
            return (
                SingleEvalResult(html=html,
                                 score=score,
                                 convo=convo,
                                 metrics=metrics),
                grader_label,
            )

        # Run evaluation and collect results
        all_outputs = common.map_with_progress(fn, self.examples,
                                               self.n_threads)
        results: list[SingleEvalResult]
        grader_labels: list[bool]
        results, grader_labels = zip(*all_outputs)

        # model pairwise agreement metrics
        model_agreement_metrics = compute_metrics_for_rater_by_class(
            self_pred_list=grader_labels,
            other_preds_list=[x['binary_labels'] for x in self.examples],
            cluster_list=[x['category'] for x in self.examples],
            model_or_physician='model',
        )

        # physicians:
        physician_rating_lists = defaultdict(lambda: ([], [], []))
        for example in self.examples:
            for i in range(len(example['binary_labels'])):
                physician_id = example['anonymized_physician_ids'][i]
                self_pred = example['binary_labels'][i]
                other_preds = (example['binary_labels'][:i] +
                               example['binary_labels'][i + 1:])
                cluster = example['category']
                physician_rating_lists[physician_id][0].append(self_pred)
                physician_rating_lists[physician_id][1].append(other_preds)
                physician_rating_lists[physician_id][2].append(cluster)

        physician_agreement_metric_lists = defaultdict(dict)
        for physician_id, (
                physician_rating_list,
                other_preds_list,
                cluster_list,
        ) in physician_rating_lists.items():
            physician_agreement_metrics = compute_metrics_for_rater_by_class(
                self_pred_list=physician_rating_list,
                other_preds_list=other_preds_list,
                cluster_list=cluster_list,
                model_or_physician='physician',
            )
            for k, v in physician_agreement_metrics.items():
                physician_agreement_metric_lists[k][physician_id] = v

        # consolidate final metrics and add agreement metrics
        final_metrics = common.aggregate_results(
            results, default_stats=('mean', 'n_samples', 'bootstrap_std'))
        model_agreement_metrics_condensed: dict[str, float] = {
            k: v['value']
            for k, v in model_agreement_metrics.items()
            if v['value'] is not None
        }
        assert final_metrics.metrics is not None
        final_metrics.metrics.update(model_agreement_metrics_condensed)
        final_metrics.score = final_metrics.metrics[
            'pairwise_model_f1_balanced']

        final_metrics.metadata = {
            'model_agreement_metrics': model_agreement_metrics,
            'physician_agreement_metric_lists':
            physician_agreement_metric_lists,
        }
        return final_metrics


def compute_metrics_for_rater_by_class(
    self_pred_list: list[bool],
    other_preds_list: list[list[bool]],
    cluster_list: list[str],
    model_or_physician: Literal['model', 'physician'],
) -> dict[str, dict[str, float | None]]:
    # get all the metrics for each cluster
    metric_lists = defaultdict(list)
    for self_pred, other_preds, cluster in zip(self_pred_list,
                                               other_preds_list,
                                               cluster_list,
                                               strict=True):
        self_pred_str = 'pos' if self_pred else 'neg'
        for other_pred in other_preds:
            # precision. based on the grader's labels -
            # i.e., calculated as TP / (TP + FP)
            # so a prediction should be recorded whenever self_pred is True
            precision_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='precision',
                pred_str=self_pred_str,
            )
            metric_lists[precision_index_str].append(self_pred == other_pred)
            precision_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=precision_index_str)
            metric_lists[precision_cluster_str].append(self_pred == other_pred)

            # recall. based on the ground truth labels -
            # i.e., calculated as TP / (TP + FN)
            # so a prediction should be recorded whenever other_pred is True
            other_pred_str = 'pos' if other_pred else 'neg'
            recall_index_str = INDEX_STR_TEMPLATE.format(
                model_or_physician=model_or_physician,
                metric='recall',
                pred_str=other_pred_str,
            )
            metric_lists[recall_index_str].append(self_pred == other_pred)
            recall_cluster_str = CLUSTER_STR_TEMPLATE.format(
                cluster=cluster, index_str=recall_index_str)
            metric_lists[recall_cluster_str].append(self_pred == other_pred)

    metrics: dict[str, dict[str, float | None]] = {}
    for index_str, metric_list in metric_lists.items():
        n = len(metric_list)
        metric = sum(metric_list) / n if n > 0 else None
        metrics[index_str] = {
            'n': n,
            'value': metric,
        }

    f1_metrics = get_f1_metrics(metrics)
    metrics.update(f1_metrics)

    balanced_metrics = get_balanced_metrics(metrics)
    metrics.update(balanced_metrics)

    return metrics


def get_f1_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    f1_metrics: dict[str, dict[str, float | None]] = {}
    for precision_key_name in metrics:
        if 'precision' in precision_key_name:
            recall_key_name = precision_key_name.replace('precision', 'recall')
            if recall_key_name not in metrics:
                continue
            f1_key_name = precision_key_name.replace('precision', 'f1')
            assert f1_key_name not in metrics
            f1_metrics[f1_key_name] = compute_f1_metric(
                precision=metrics[precision_key_name],
                recall=metrics[recall_key_name],
            )

    return f1_metrics


def compute_f1_metric(
    precision: dict[str, float | None],
    recall: dict[str, float | None],
) -> dict[str, float | None]:
    precision_n = precision['n']
    recall_n = recall['n']
    assert precision_n is not None and recall_n is not None, 'n_pos or n_neg is None'

    precision_metric = precision['value']
    recall_metric = recall['value']
    if precision_metric is None or recall_metric is None:
        f1_metric = None
        n_f1 = (
            precision_n + recall_n
        )  # precision_metric is None iff precision_n = 0 and recall_metric is None iff recall_n = 0, so if either is zero this gives TP + FN + FP without double counting
    elif precision_metric == 0 and recall_metric == 0:
        f1_metric = 0.0
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP
    else:
        f1_metric = (2 * (precision_metric * recall_metric) /
                     (precision_metric + recall_metric))
        tp = precision_metric * precision_n  # because precision = TP / (TP+FP)
        n_f1 = precision_n + recall_n - tp  # TP+FP + TP+FN − TP

    return {
        'n': n_f1,
        'value': f1_metric,
    }


def get_balanced_metrics(
    metrics: dict[str, dict[str, float | None]],
) -> dict[str, dict[str, float | None]]:
    balanced_metrics: dict[str, dict[str, float | None]] = {}
    for pos_key_name in metrics:
        if 'pos' in pos_key_name:
            neg_key_name = pos_key_name.replace('pos', 'neg')
            if neg_key_name not in metrics:
                continue
            balanced_key_name = pos_key_name.replace('pos', 'balanced')
            assert balanced_key_name not in metrics
            balanced_metrics[balanced_key_name] = compute_balanced_metric(
                metric_pos=metrics[pos_key_name],
                metric_neg=metrics[neg_key_name],
            )

    return balanced_metrics


def compute_balanced_metric(
    metric_pos: dict[str, float | None],
    metric_neg: dict[str, float | None],
) -> dict[str, float | None]:
    n_pos = metric_pos['n']
    n_neg = metric_neg['n']
    assert n_pos is not None and n_neg is not None, 'n_pos or n_neg is None'

    pos_metric = metric_pos['value']
    neg_metric = metric_neg['value']
    if pos_metric is None or neg_metric is None:
        metric = None
    else:
        metric = (pos_metric + neg_metric) / 2

    return {
        'n': n_pos + n_neg,
        # note: this overcounts samples going towards the balanced F1
        'value': metric,
    }
@@ -1,165 +0,0 @@
from . import healthbench_meta_eval


def test_compute_agreement_for_rater_by_class():
    self_pred_list = [True, False, True]
    other_preds_list = [[True, True, False], [True, False], [False]]
    cluster_list = ['a', 'a', 'b']
    model_or_physician = 'model'
    metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class(
        self_pred_list, other_preds_list, cluster_list, model_or_physician
    )

    # precision overall
    index_str_pos_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='pos'
    )
    index_str_neg_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='precision', pred_str='neg'
    )
    overall_pos_precision = metrics[index_str_pos_precision]
    overall_neg_precision = metrics[index_str_neg_precision]
    expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1)
    expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0)
    assert overall_pos_precision['value'] == expected_overall_pos_precision
    assert overall_neg_precision['value'] == expected_overall_neg_precision
    assert overall_pos_precision['n'] == 4
    assert overall_neg_precision['n'] == 2

    # recall overall
    index_str_pos_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='pos'
    )
    index_str_neg_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='recall', pred_str='neg'
    )
    overall_pos_recall = metrics[index_str_pos_recall]
    overall_neg_recall = metrics[index_str_neg_recall]
    expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0)
    expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1)
    assert overall_pos_recall['value'] == expected_overall_pos_recall
    assert overall_neg_recall['value'] == expected_overall_neg_recall
    assert overall_pos_recall['n'] == 3
    assert overall_neg_recall['n'] == 3

    # f1 overall
    index_str_pos_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='pos'
    )
    index_str_neg_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='neg'
    )
    overall_pos_f1 = metrics[index_str_pos_f1]
    overall_neg_f1 = metrics[index_str_neg_f1]
    expected_overall_pos_f1 = (
        2
        * expected_overall_pos_precision
        * expected_overall_pos_recall
        / (expected_overall_pos_precision + expected_overall_pos_recall)
    )
    expected_overall_neg_f1 = (
        2
        * expected_overall_neg_precision
        * expected_overall_neg_recall
        / (expected_overall_neg_precision + expected_overall_neg_recall)
    )
    assert overall_pos_f1['value'] == expected_overall_pos_f1
    assert overall_neg_f1['value'] == expected_overall_neg_f1

    # balanced f1
    index_str_balanced_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
        model_or_physician=model_or_physician, metric='f1', pred_str='balanced'
    )
    balanced_f1 = metrics[index_str_balanced_f1]
    expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2
    assert balanced_f1['value'] == expected_balanced_f1

    # by cluster
    # precision
    cluster_a_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_precision
    )
    cluster_a_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_precision
    )
    cluster_a_pos_precision = metrics[cluster_a_str_pos_precision]
    cluster_a_neg_precision = metrics[cluster_a_str_neg_precision]
    assert cluster_a_pos_precision['value'] == (
        # example 1, 2 in order
        (2 + 0) / (3 + 0)
    )
    assert cluster_a_neg_precision['value'] == (
        # example 1, 2 in order
        (0 + 1) / (0 + 2)
    )
    assert cluster_a_pos_precision['n'] == 3
    assert cluster_a_neg_precision['n'] == 2

    # recall
    cluster_a_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_pos_recall
    )
    cluster_a_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='a', index_str=index_str_neg_recall
    )
    cluster_a_pos_recall = metrics[cluster_a_str_pos_recall]
    cluster_a_neg_recall = metrics[cluster_a_str_neg_recall]
    assert cluster_a_pos_recall['value'] == (
        # example 1, 2 in order
        (2 + 0) / (2 + 1)
    )
    assert cluster_a_neg_recall['value'] == (
        # example 1, 2 in order
        (0 + 1) / (1 + 1)
    )
    assert cluster_a_pos_recall['n'] == 3
    assert cluster_a_neg_recall['n'] == 2

    # cluster B
    # precision
    cluster_b_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_precision
    )
    cluster_b_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_precision
    )
    cluster_b_str_pos_precision = metrics[cluster_b_str_pos_precision]
    assert cluster_b_str_neg_precision not in metrics
    assert cluster_b_str_pos_precision['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_str_pos_precision['n'] == 1

    # recall
    cluster_b_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_recall
    )
    cluster_b_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_recall
    )
    assert cluster_b_str_pos_recall not in metrics
    cluster_b_neg_recall = metrics[cluster_b_str_neg_recall]
    assert cluster_b_neg_recall['value'] == (
        # example 3 only
        0 / 1
    )
    assert cluster_b_neg_recall['n'] == 1

    # f1
    index_str_pos_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_pos_f1
    )
    index_str_neg_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_neg_f1
    )
    index_str_balanced_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
        cluster='b', index_str=index_str_balanced_f1
    )
    assert index_str_pos_f1 not in metrics
    assert index_str_neg_f1 not in metrics
    assert index_str_balanced_f1 not in metrics


if __name__ == '__main__':
    test_compute_agreement_for_rater_by_class()
@@ -1,103 +0,0 @@
import os
import time

import anthropic

from .. import common
from ..types import MessageList, SamplerBase, SamplerResponse

CLAUDE_SYSTEM_MESSAGE_LMSYS = (
    'The assistant is Claude, created by Anthropic. The current date is '
    "{currentDateTime}. Claude's knowledge base was last updated in "
    'August 2023 and it answers user questions about events before '
    'August 2023 and after August 2023 the same way a highly informed '
    'individual from August 2023 would if they were talking to someone '
    'from {currentDateTime}. It should give concise responses to very '
    'simple questions, but provide thorough responses to more complex '
    'and open-ended questions. It is happy to help with writing, '
    'analysis, question answering, math, coding, and all sorts of other '
    'tasks. It uses markdown for coding. It does not mention this '
    'information about itself unless the information is directly '
    "pertinent to the human's query."
).format(currentDateTime='2024-04-01')
# reference: https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894


class ClaudeCompletionSampler(SamplerBase):

    def __init__(
        self,
        model: str,
        system_message: str | None = None,
        temperature: float = 0.0,  # default in Anthropic example
        max_tokens: int = 4096,
    ):
        self.client = anthropic.Anthropic()
        self.api_key = os.environ.get('ANTHROPIC_API_KEY')  # please set your API_KEY
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = 'base64'

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ):
        new_image = {
            'type': 'image',
            'source': {
                'type': encoding,
                'media_type': f'image/{format}',
                'data': image,
            },
        }
        return new_image

    def _handle_text(self, text):
        return {'type': 'text', 'text': text}

    def _pack_message(self, role, content):
        return {'role': str(role), 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        trial = 0
        while True:
            try:
                if not common.has_only_user_assistant_messages(message_list):
                    raise ValueError(f'Claude sampler only supports user and assistant messages, got {message_list}')
                if self.system_message:
                    response_message = self.client.messages.create(
                        model=self.model,
                        system=self.system_message,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        messages=message_list,
                    )
                    claude_input_messages: MessageList = [{'role': 'system', 'content': self.system_message}] + message_list
                else:
                    response_message = self.client.messages.create(
                        model=self.model,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        messages=message_list,
                    )
                    claude_input_messages = message_list
                response_text = response_message.content[0].text
                return SamplerResponse(
                    response_text=response_text,
                    response_metadata={},
                    actual_queried_message_list=claude_input_messages,
                )
            except anthropic.RateLimitError as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
@@ -1,78 +0,0 @@
import time
from typing import Any

import openai
from openai import OpenAI

from ..types import MessageList, SamplerBase, SamplerResponse


class OChatCompletionSampler(SamplerBase):
    """Sample from OpenAI's chat completion API for o series models."""

    def __init__(
        self,
        *,
        reasoning_effort: str | None = None,
        model: str = 'o1-mini',
    ):
        self.api_key_name = 'OPENAI_API_KEY'
        self.client = OpenAI()
        # using api_key=os.environ.get("OPENAI_API_KEY")  # please set your API_KEY
        self.model = model
        self.image_format = 'url'
        self.reasoning_effort = reasoning_effort

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ):
        new_image = {
            'type': 'image_url',
            'image_url': {
                'url': f'data:image/{format};{encoding},{image}',
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {'type': 'text', 'text': text}

    def _pack_message(self, role: str, content: Any):
        return {'role': str(role), 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    reasoning_effort=self.reasoning_effort,
                )
                content = response.choices[0].message.content
                return SamplerResponse(
                    response_text=content,
                    response_metadata={'usage': response.usage},
                    actual_queried_message_list=message_list,
                )
            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print('Bad Request Error', e)
                return SamplerResponse(
                    response_text='',
                    response_metadata={'usage': None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
@@ -1,97 +0,0 @@
import os
import time
from typing import Any

import openai
from openai import OpenAI

from ..types import MessageList, SamplerBase, SamplerResponse


class ResponsesSampler(SamplerBase):
    """Sample from OpenAI's responses API."""

    def __init__(
        self,
        model: str = 'gpt-4.1',
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        reasoning_model: bool = False,
        reasoning_effort: str | None = None,
    ):
        self.api_key_name = 'OPENAI_API_KEY'
        assert os.environ.get('OPENAI_API_KEY'), 'Please set OPENAI_API_KEY'
        self.client = OpenAI()
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = 'url'
        self.reasoning_model = reasoning_model
        self.reasoning_effort = reasoning_effort

    def _handle_image(
        self,
        image: str,
        encoding: str = 'base64',
        format: str = 'png',
        fovea: int = 768,
    ) -> dict[str, Any]:
        new_image = {
            'type': 'input_image',
            'image_url': f'data:image/{format};{encoding},{image}',
        }
        return new_image

    def _handle_text(self, text: str) -> dict[str, Any]:
        return {'type': 'input_text', 'text': text}

    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
        return {'role': role, 'content': content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        if self.system_message:
            message_list = [
                self._pack_message('developer', self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                if self.reasoning_model:
                    reasoning = ({
                        'effort': self.reasoning_effort
                    } if self.reasoning_effort else None)
                    response = self.client.responses.create(
                        model=self.model,
                        input=message_list,
                        reasoning=reasoning,
                    )
                else:
                    response = self.client.responses.create(
                        model=self.model,
                        input=message_list,
                        temperature=self.temperature,
                        max_output_tokens=self.max_tokens,
                    )
                return SamplerResponse(
                    response_text=response.output_text,
                    response_metadata={'usage': response.usage},
                    actual_queried_message_list=message_list,
                )
            except openai.BadRequestError as e:
                print('Bad Request Error', e)
                return SamplerResponse(
                    response_text='',
                    response_metadata={'usage': None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f'Rate limit exception so wait and retry {trial} after {exception_backoff} sec',
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception
opencompass/datasets/healthbench/types.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from dataclasses import dataclass, field
from typing import Any, Literal, overload

Message = dict[str, Any]  # keys role, content
MessageList = list[Message]


@dataclass
class SamplerResponse:
    """Response from a sampler."""
    response_text: str
    actual_queried_message_list: MessageList
    response_metadata: dict[str, Any]


class SamplerBase:
    """Base class for defining a sampling model, which can be evaluated, or
    used as part of the grading process."""

    def __call__(
        self,
        message_list: MessageList,
    ) -> SamplerResponse:
        raise NotImplementedError


@dataclass
class EvalResult:
    """Result of running an evaluation (usually consisting of many samples)"""

    score: float | None  # top-line metric
    metrics: dict[str, float] | None  # other metrics
    htmls: list[str]  # strings of valid HTML
    convos: list[MessageList]  # sampled conversations
    metadata: dict[str, Any] | None  # Extra data such as rubric scores or sollen


@dataclass
class SingleEvalResult:
    """Result of evaluating a single sample."""

    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None  # sampled conversation
    example_level_metadata: dict[str, Any] | None = (
        None  # Extra data such as rubric scores or sollen
    )


class Eval:
    """Base class for defining an evaluation."""

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError
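The new types.py above fixes the sampler contract used throughout these evals: a sampler is any callable that maps a MessageList to a SamplerResponse. A minimal illustrative subclass (the class name and canned reply are hypothetical, not part of the commit) shows what an implementer has to provide:

# Hypothetical sampler used only to illustrate the SamplerBase contract:
# it echoes a fixed answer and reports exactly the messages it was asked.
class FixedAnswerSampler(SamplerBase):

    def __init__(self, answer: str = 'Answer: A'):
        self.answer = answer

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        return SamplerResponse(
            response_text=self.answer,
            actual_queried_message_list=message_list,
            response_metadata={},
        )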