Commit 7c6d788dca by bio-mlhui, 2025-05-23 07:26:21 +00:00
Parent: 95d8d2ba4d
4 changed files with 85 additions and 71 deletions

View File

@@ -1,47 +1,33 @@
 from opencompass.datasets import HealthBenchDataset, HealthBenchEvaluator
-from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_prompt_template import HealthBenchTemplate
+from opencompass.openicl.icl_inferencer import ChatInferencer
+from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 
 # Reader configuration
 reader_cfg = dict(
     input_columns=[
-        'example_tags', 'ideal_completions_data', 'prompt', 'prompt_id', 'rubrics', 'canary'
+        'prompt_trans'
     ],
     output_column='prompt_id', # useless
 )
 
 # Inference configuration
 infer_cfg = dict(
     prompt_template=dict(
-        type=HealthBenchTemplate,
-        key='prompt_trans',
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role='HUMAN',
+                    prompt='{prompt}', # prompt mode: zero-shot
+                ),
+            ],
+        ),
     ),
     retriever=dict(type=ZeroRetriever),
-    inferencer=dict(type=GenInferencer),
+    inferencer=dict(type=ChatInferencer),
 )
 
-# infer_cfg = dict(
-#     prompt_template=dict(
-#         type=PromptTemplate,
-#         template=dict(
-#             round=[
-#                 dict(
-#                     role='HUMAN',
-#                     prompt='{prompt_id}', # prompt mode: zero-shot
-#                 ),
-#                 dict(
-#                     role='BOT',
-#                     prompt='{prompt_id}', # prompt mode: zero-shot
-#                 ),
-#             ],
-#         ),
-#     ),
-#     retriever=dict(type=ZeroRetriever),
-#     inferencer=dict(type=GenInferencer),
-# )
-
 # Evaluation configuration
 eval_cfg = dict(
     evaluator=dict(type=HealthBenchEvaluator),
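For orientation only, a minimal sketch of how reader_cfg, infer_cfg and eval_cfg are usually wired into an OpenCompass dataset list; the abbr and path values below are hypothetical placeholders, not part of this commit:

healthbench_datasets = [
    dict(
        abbr='healthbench',              # hypothetical dataset abbreviation
        type=HealthBenchDataset,
        path='data/healthbench/',        # hypothetical data path
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
]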

View File

@@ -24,22 +24,8 @@ grading_sampler = ChatCompletionSampler(
     max_tokens=2048,
 )
 
 def _parse(item):
-    prompt = item['prompt']
-    new_prompts = []
-    for idx in range(len(prompt)):
-        foo = {}
-        content = prompt[idx]['content']
-        foo['prompt'] = content
-        role = prompt[idx]['role']
-        if role == 'user':
-            foo['role'] = 'HUMAN'
-        elif role == 'assistant':
-            foo['role'] = 'BOT'
-        else:
-            raise ValueError()
-        new_prompts.append(foo)
-    item['prompt_trans'] = new_prompts
-    # item["rubrics"] = [RubricItem.from_dict(d) for d in item["rubrics"]]
+    prompt = item['prompt'] + [dict(role='assistant', content='')]
+    item['prompt_trans'] = prompt
     return item
 
 HEALTHBENCH_HTML_JINJA = (
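A quick sketch of what the simplified _parse produces for a single row (the prompt text is made up): the conversation is kept as-is and an empty assistant turn is appended to mark where the model's completion goes.

item = {'prompt': [{'role': 'user', 'content': 'I have a mild headache. What should I do?'}]}
item = _parse(item)
# item['prompt_trans'] ==
# [{'role': 'user', 'content': 'I have a mild headache. What should I do?'},
#  {'role': 'assistant', 'content': ''}]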
@@ -84,6 +70,7 @@ class HealthBenchDataset(BaseDataset):
         dataset = dataset.map(lambda item: _parse(item))
         return dataset
 
+from collections import defaultdict
 from .types import MessageList
@@ -195,6 +182,62 @@ def get_usage_dict(response_usage) -> dict[str, int | None]:
     }
 
 import hashlib
+
+import numpy as np
+
+from .types import EvalResult, MessageList, SingleEvalResult
+
+
+def _compute_clipped_stats(
+    values: list,
+    stat: str,
+):
+    """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and
+    n_samples for final HealthBench scoring."""
+    if stat == 'mean':
+        return np.clip(np.mean(values), 0, 1)
+    elif stat == 'n_samples':
+        return len(values)
+    elif stat == 'bootstrap_std':
+        bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
+        bootstrap_means = [
+            _compute_clipped_stats(list(s), 'mean') for s in bootstrap_samples
+        ]
+        return np.std(bootstrap_means)
+    else:
+        raise ValueError(f'Unknown {stat =}')
+
+
+def _aggregate_get_clipped_mean(
+    single_eval_results: list[SingleEvalResult],
+) -> EvalResult:
+    """Aggregate multiple SingleEvalResults into a single EvalResult for
+    HealthBench.
+
+    For each metric, returns the stats in _compute_clipped_stats.
+    """
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    metadata = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values['score'].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+        metadata.append(single_eval_result.example_level_metadata)
+    final_metrics = {}
+    for name, values in name2values.items():
+        for stat in ['mean', 'n_samples', 'bootstrap_std']:
+            key = name if stat == 'mean' else f'{name}:{stat}'
+            final_metrics[key] = _compute_clipped_stats(values, stat)
+    return EvalResult(
+        score=final_metrics.pop('score', None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
+        metadata={'example_level_metadata': metadata},
+    )
+
+
 class HealthBenchEvaluator(BaseEvaluator):
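As a rough illustration of the new aggregation path, here is a small sketch with made-up scores; it assumes SingleEvalResult can be constructed with the same fields that _aggregate_get_clipped_mean reads (score, metrics, html, convo, example_level_metadata):

_compute_clipped_stats([0.4, 0.9, 1.2, -0.1], 'mean')        # 0.6: the raw mean already lies in [0, 1]
_compute_clipped_stats([0.4, 0.9, 1.2, -0.1], 'n_samples')   # 4

# NOTE: the keyword names below are assumed from the attributes accessed above;
# the real SingleEvalResult signature may differ.
r1 = SingleEvalResult(score=0.8, metrics={'communication': 0.9},
                      html='', convo=[], example_level_metadata={})
r2 = SingleEvalResult(score=0.4, metrics={'communication': 0.2},
                      html='', convo=[], example_level_metadata={})
agg = _aggregate_get_clipped_mean([r1, r2])
# agg.score   -> 0.6, the clipped mean of the per-example scores
# agg.metrics -> {'communication': 0.55, 'communication:n_samples': 2,
#                 'communication:bootstrap_std': ..., 'score:n_samples': 2,
#                 'score:bootstrap_std': ...}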
@@ -290,10 +333,8 @@ class HealthBenchEvaluator(BaseEvaluator):
 
     def score(self, predictions, references, test_set):
         results = []
-        ret = []
         if len(predictions) != len(references):
             return {'error': 'preds and refrs have different length'}
-        all_score = 0
         for idx, (i, j) in enumerate(zip(predictions, references)):
             row = test_set[idx]
             prompt_messages = row['prompt']
@@ -328,7 +369,7 @@ class HealthBenchEvaluator(BaseEvaluator):
             convo = actual_queried_prompt_messages + [
                 dict(content=response_text, role='assistant')
             ]
-            ret.append(SingleEvalResult(
+            results.append(SingleEvalResult(
                 html=html,
                 score=score,
                 convo=convo,
@@ -345,13 +386,18 @@
                     ).hexdigest(),
                 },
             ))
-            all_score += score
-        avg_score = all_score / float(idx+1)
-        return {
-            'score': avg_score
+        results = _aggregate_get_clipped_mean(results)
+        assert results.metrics is not None
+        metrics = results.metrics | {'score': results.score}
+        metrics = dict(sorted(metrics.items()))
+        result_dict = {
+            'score': results.score,
+            'metrics': results.metrics,
+            'htmls': results.htmls,
+            'convos': results.convos,
+            'metadata': results.metadata,
         }
+        return {'accuracy': result_dict['score'],}
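In short, the score method now aggregates the per-example results with _aggregate_get_clipped_mean and reports only the clipped-mean overall score under the 'accuracy' key; a hypothetical call shape:

# evaluator.score(predictions, references, test_set)
# -> {'accuracy': 0.62}   # made-up value: the clipped mean of the per-example scores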

View File

@@ -257,22 +257,3 @@ class PromptTemplate:
             prompt.append(dict(section='end', pos='end'))
         return prompt
-
-
-class HealthBenchTemplate:
-
-    def __init__(
-        self,
-        key: Union[Dict, str],
-    ) -> None:
-        self.key = key
-
-    def generate_item(self, entry: Dict, **kwargs):
-        template = [{'section': 'round', 'pos': 'begin'}]
-        end_template = [{'section': 'round', 'pos': 'end'}]
-        mid = entry[self.key]
-        template = template + mid + end_template
-        ret = PromptList()
-        for item in template:
-            ret.append(item)
-        return ret
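For reference, a sketch of what the removed HealthBenchTemplate.generate_item used to emit for an entry whose 'prompt_trans' column held a converted dialogue (PromptList is OpenCompass's prompt container; the entry contents are made up):

entry = {'prompt_trans': [{'role': 'HUMAN', 'prompt': 'I have a mild headache.'},
                          {'role': 'BOT', 'prompt': ''}]}
HealthBenchTemplate(key='prompt_trans').generate_item(entry)
# PromptList containing, in order:
#   {'section': 'round', 'pos': 'begin'},
#   {'role': 'HUMAN', 'prompt': 'I have a mild headache.'},
#   {'role': 'BOT', 'prompt': ''},
#   {'section': 'round', 'pos': 'end'}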

root (Symbolic link)

View File

@@ -0,0 +1 @@
+/root