bio-mlhui 2025-05-23 07:26:21 +00:00
parent 95d8d2ba4d
commit 7c6d788dca
4 changed files with 85 additions and 71 deletions

View File

@@ -1,47 +1,33 @@
 from opencompass.datasets import HealthBenchDataset, HealthBenchEvaluator
-from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_prompt_template import HealthBenchTemplate
+from opencompass.openicl.icl_inferencer import ChatInferencer
+from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever

 # Reader configuration
 reader_cfg = dict(
     input_columns=[
-        'example_tags', 'ideal_completions_data', 'prompt', 'prompt_id', 'rubrics', 'canary'
+        'prompt_trans'
     ],
     output_column='prompt_id',  # useless
 )

-# Inference configuration
 infer_cfg = dict(
     prompt_template=dict(
-        type=HealthBenchTemplate,
-        key='prompt_trans',
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role='HUMAN',
+                    prompt='{prompt}',  # prompt mode: zero-shot
+                ),
+            ],
+        ),
     ),
     retriever=dict(type=ZeroRetriever),
-    inferencer=dict(type=GenInferencer),
+    inferencer=dict(type=ChatInferencer),
 )

-# infer_cfg = dict(
-#     prompt_template=dict(
-#         type=PromptTemplate,
-#         template=dict(
-#             round=[
-#                 dict(
-#                     role='HUMAN',
-#                     prompt='{prompt_id}',  # prompt mode: zero-shot
-#                 ),
-#                 dict(
-#                     role='BOT',
-#                     prompt='{prompt_id}',  # prompt mode: zero-shot
-#                 ),
-#             ],
-#         ),
-#     ),
-#     retriever=dict(type=ZeroRetriever),
-#     inferencer=dict(type=GenInferencer),
-# )

 # Evaluation configuration
 eval_cfg = dict(
     evaluator=dict(type=HealthBenchEvaluator),
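
Note: this hunk cuts off before the dataset list that actually consumes these configs. As a rough sketch of how such a config is typically wired into an OpenCompass dataset entry (the abbr and path values below are assumed placeholders, not values from this commit):

    # Sketch only: abbr/path are assumed placeholders; the real list sits
    # further down the file, outside this hunk.
    healthbench_datasets = [
        dict(
            abbr='healthbench',
            type=HealthBenchDataset,
            path='./data/healthbench',
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg,
        )
    ]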

View File

@@ -24,22 +24,8 @@ grading_sampler = ChatCompletionSampler(
     max_tokens=2048,
 )

 def _parse(item):
-    prompt = item['prompt']
-    new_prompts = []
-    for idx in range(len(prompt)):
-        foo = {}
-        content = prompt[idx]['content']
-        foo['prompt'] = content
-        role = prompt[idx]['role']
-        if role == 'user':
-            foo['role'] = 'HUMAN'
-        elif role == 'assistant':
-            foo['role'] = 'BOT'
-        else:
-            raise ValueError()
-        new_prompts.append(foo)
-    item['prompt_trans'] = new_prompts
-    # item["rubrics"] = [RubricItem.from_dict(d) for d in item["rubrics"]]
+    prompt = item['prompt'] + [dict(role='assistant', content='')]
+    item['prompt_trans'] = prompt
     return item

 HEALTHBENCH_HTML_JINJA = (
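
For clarity, the effect of the simplified _parse above (an illustrative sketch; the conversation is made up, not benchmark data):

    # Illustrative only: a made-up item, not data from HealthBench.
    item = {'prompt': [{'role': 'user', 'content': 'I have a mild headache, what can I take?'}]}
    item = _parse(item)
    # item['prompt_trans'] is the original conversation plus one empty
    # assistant turn, presumably the slot where ChatInferencer generates:
    # [{'role': 'user', 'content': 'I have a mild headache, what can I take?'},
    #  {'role': 'assistant', 'content': ''}]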
@@ -84,6 +70,7 @@ class HealthBenchDataset(BaseDataset):
         dataset = dataset.map(lambda item: _parse(item))
         return dataset

+
 from collections import defaultdict

 from .types import MessageList
@@ -195,6 +182,62 @@ def get_usage_dict(response_usage) -> dict[str, int | None]:
     }

 import hashlib
+import numpy as np
+
+from .types import EvalResult, MessageList, SingleEvalResult
+
+
+def _compute_clipped_stats(
+    values: list,
+    stat: str,
+):
+    """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and
+    n_samples for final HealthBench scoring."""
+    if stat == 'mean':
+        return np.clip(np.mean(values), 0, 1)
+    elif stat == 'n_samples':
+        return len(values)
+    elif stat == 'bootstrap_std':
+        bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
+        bootstrap_means = [
+            _compute_clipped_stats(list(s), 'mean') for s in bootstrap_samples
+        ]
+        return np.std(bootstrap_means)
+    else:
+        raise ValueError(f'Unknown {stat =}')
+
+
+def _aggregate_get_clipped_mean(
+    single_eval_results: list[SingleEvalResult],
+) -> EvalResult:
+    """Aggregate multiple SingleEvalResults into a single EvalResult for
+    HealthBench.
+
+    For each metric, returns the stats in _compute_clipped_stats.
+    """
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    metadata = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values['score'].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+        metadata.append(single_eval_result.example_level_metadata)
+    final_metrics = {}
+    for name, values in name2values.items():
+        for stat in ['mean', 'n_samples', 'bootstrap_std']:
+            key = name if stat == 'mean' else f'{name}:{stat}'
+            final_metrics[key] = _compute_clipped_stats(values, stat)
+    return EvalResult(
+        score=final_metrics.pop('score', None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
+        metadata={'example_level_metadata': metadata},
+    )

 class HealthBenchEvaluator(BaseEvaluator):
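
A quick illustration of the clipped-stats helper added above (the values are made up; per-example HealthBench scores can fall below 0 when negatively weighted rubrics apply, which is presumably why only the mean is clipped to [0, 1]):

    # Illustrative only: made-up per-example scores.
    values = [0.9, 0.75, -0.2, 1.0]
    _compute_clipped_stats(values, 'mean')           # 0.6125, clipped into [0, 1]
    _compute_clipped_stats(values, 'n_samples')      # 4
    _compute_clipped_stats(values, 'bootstrap_std')  # std of 1000 bootstrap means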
@@ -290,10 +333,8 @@ class HealthBenchEvaluator(BaseEvaluator):

     def score(self, predictions, references, test_set):
         results = []
-        ret = []
         if len(predictions) != len(references):
             return {'error': 'preds and refrs have different length'}
-        all_score = 0
         for idx, (i, j) in enumerate(zip(predictions, references)):
             row = test_set[idx]
             prompt_messages = row['prompt']
@@ -328,7 +369,7 @@ class HealthBenchEvaluator(BaseEvaluator):
             convo = actual_queried_prompt_messages + [
                 dict(content=response_text, role='assistant')
             ]
-            ret.append(SingleEvalResult(
+            results.append(SingleEvalResult(
                 html=html,
                 score=score,
                 convo=convo,
@@ -345,13 +386,18 @@ class HealthBenchEvaluator(BaseEvaluator):
                     ).hexdigest(),
                 },
             ))
-            all_score += score
-        avg_score = all_score / float(idx+1)
-        return {
-            'score': avg_score
+        results = _aggregate_get_clipped_mean(results)
+        assert results.metrics is not None
+        metrics = results.metrics | {'score': results.score}
+        metrics = dict(sorted(metrics.items()))
+        result_dict = {
+            'score': results.score,
+            'metrics': results.metrics,
+            'htmls': results.htmls,
+            'convos': results.convos,
+            'metadata': results.metadata,
         }
+        return {'accuracy': result_dict['score'],}
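
For reference, the shape of the result_dict built above (a sketch; the per-tag metric names are assumed examples, since they depend on the rubric tags present in the data):

    # Illustrative only: keys follow the f'{name}:{stat}' pattern from
    # _aggregate_get_clipped_mean; the 'tag:...' names are assumed examples.
    result_dict = {
        'score': 0.61,
        'metrics': {
            'score:bootstrap_std': 0.02,
            'score:n_samples': 1000,
            'tag:communication_quality': 0.58,
            'tag:communication_quality:bootstrap_std': 0.03,
            'tag:communication_quality:n_samples': 1000,
        },
        'htmls': [...],
        'convos': [...],
        'metadata': {'example_level_metadata': [...]},
    }
    # Only the overall clipped mean is surfaced to OpenCompass:
    # {'accuracy': result_dict['score']}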

View File

@@ -257,22 +257,3 @@ class PromptTemplate:
             prompt.append(dict(section='end', pos='end'))
         return prompt
-
-
-class HealthBenchTemplate:
-
-    def __init__(
-        self,
-        key: Union[Dict, str],
-    ) -> None:
-        self.key = key
-
-    def generate_item(self, entry: Dict, **kwargs):
-        template = [{'section': 'round', 'pos': 'begin'}]
-        end_template = [{'section': 'round', 'pos': 'end'}]
-        mid = entry[self.key]
-        template = template + mid + end_template
-        ret = PromptList()
-        for item in template:
-            ret.append(item)
-        return ret

root (Symbolic link)

View File

@@ -0,0 +1 @@
+/root