mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

commit 7c6d788dca
parent 95d8d2ba4d

    first
@@ -1,47 +1,33 @@
 from opencompass.datasets import HealthBenchDataset, HealthBenchEvaluator
-from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_prompt_template import HealthBenchTemplate
+from opencompass.openicl.icl_inferencer import ChatInferencer
+from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 
 
 # Reader configuration
 reader_cfg = dict(
     input_columns=[
-        'example_tags', 'ideal_completions_data', 'prompt', 'prompt_id', 'rubrics', 'canary'
+        'prompt_trans'
     ],
     output_column='prompt_id', # useless
 )
 
-# Inference configuration
 infer_cfg = dict(
     prompt_template=dict(
-        type=HealthBenchTemplate,
-        key='prompt_trans',
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(
+                    role='HUMAN',
+                    prompt='{prompt}', # prompt mode: zero-shot
+                ),
+            ],
+        ),
     ),
     retriever=dict(type=ZeroRetriever),
-    inferencer=dict(type=GenInferencer),
+    inferencer=dict(type=ChatInferencer),
 )
 
-# infer_cfg = dict(
-#     prompt_template=dict(
-#         type=PromptTemplate,
-#         template=dict(
-#             round=[
-#                 dict(
-#                     role='HUMAN',
-#                     prompt='{prompt_id}', # prompt mode: zero-shot
-#                 ),
-#                 dict(
-#                     role='BOT',
-#                     prompt='{prompt_id}', # prompt mode: zero-shot
-#                 ),
-#             ],
-#         ),
-#     ),
-#     retriever=dict(type=ZeroRetriever),
-#     inferencer=dict(type=GenInferencer),
-# )
 
 # Evaluation configuration
 eval_cfg = dict(
     evaluator=dict(type=HealthBenchEvaluator),
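Note (not part of the diff): the config now reads a single pre-built `prompt_trans` column and swaps the custom HealthBenchTemplate/GenInferencer pair for the stock PromptTemplate/ChatInferencer pair, which handles multi-turn message lists. A minimal sketch of what one dataset row might carry after the `_parse` change in the next hunk; the sample values are hypothetical:

    # Hypothetical row, assuming ChatInferencer consumes OpenAI-style message lists.
    row = {
        'prompt': [
            {'role': 'user', 'content': 'I have a headache, what should I do?'},
        ],
        'prompt_id': 'abc123',
    }
    # _parse (next hunk) appends an empty assistant turn to mark where generation starts:
    row['prompt_trans'] = row['prompt'] + [dict(role='assistant', content='')]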
@@ -24,22 +24,8 @@ grading_sampler = ChatCompletionSampler(
     max_tokens=2048,
 )
 def _parse(item):
-    prompt = item['prompt']
-    new_prompts = []
-    for idx in range(len(prompt)):
-        foo = {}
-        content = prompt[idx]['content']
-        foo['prompt'] = content
-        role = prompt[idx]['role']
-        if role == 'user':
-            foo['role'] = 'HUMAN'
-        elif role == 'assistant':
-            foo['role'] = 'BOT'
-        else:
-            raise ValueError()
-        new_prompts.append(foo)
-    item['prompt_trans'] = new_prompts
-    # item["rubrics"] = [RubricItem.from_dict(d) for d in item["rubrics"]]
+    prompt = item['prompt'] + [dict(role='assistant', content='')]
+    item['prompt_trans'] = prompt
     return item
 
 HEALTHBENCH_HTML_JINJA = (
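Note (not part of the diff): `_parse` no longer flattens each message into a {'prompt': ..., 'role': 'HUMAN'/'BOT'} dict; it keeps the original OpenAI-style messages and appends an empty assistant turn. A before/after sketch on a hypothetical two-turn item:

    item = {'prompt': [
        {'role': 'user', 'content': 'Hi'},
        {'role': 'assistant', 'content': 'Hello! How can I help?'},
        {'role': 'user', 'content': 'My child has a fever.'},
    ]}
    parsed = _parse(item)
    # Old behaviour: prompt_trans == [{'prompt': 'Hi', 'role': 'HUMAN'}, ...]
    # New behaviour: the same messages plus a trailing empty assistant turn.
    assert parsed['prompt_trans'][-1] == dict(role='assistant', content='')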
@@ -84,6 +70,7 @@ class HealthBenchDataset(BaseDataset):
         dataset = dataset.map(lambda item: _parse(item))
 
         return dataset
 
+
 from collections import defaultdict
 
 from .types import MessageList
@@ -195,6 +182,62 @@ def get_usage_dict(response_usage) -> dict[str, int | None]:
     }
 import hashlib
 
+import numpy as np
+
+from .types import EvalResult, MessageList, SingleEvalResult
+
+
+def _compute_clipped_stats(
+    values: list,
+    stat: str,
+):
+    """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and
+    n_samples for final HealthBench scoring."""
+    if stat == 'mean':
+        return np.clip(np.mean(values), 0, 1)
+    elif stat == 'n_samples':
+        return len(values)
+    elif stat == 'bootstrap_std':
+        bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
+        bootstrap_means = [
+            _compute_clipped_stats(list(s), 'mean') for s in bootstrap_samples
+        ]
+        return np.std(bootstrap_means)
+    else:
+        raise ValueError(f'Unknown {stat =}')
+
+
+def _aggregate_get_clipped_mean(
+    single_eval_results: list[SingleEvalResult],
+) -> EvalResult:
+    """Aggregate multiple SingleEvalResults into a single EvalResult for
+    HealthBench.
+
+    For each metric, returns the stats in _compute_clipped_stats.
+    """
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    metadata = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values['score'].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+        metadata.append(single_eval_result.example_level_metadata)
+    final_metrics = {}
+    for name, values in name2values.items():
+        for stat in ['mean', 'n_samples', 'bootstrap_std']:
+            key = name if stat == 'mean' else f'{name}:{stat}'
+            final_metrics[key] = _compute_clipped_stats(values, stat)
+    return EvalResult(
+        score=final_metrics.pop('score', None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
+        metadata={'example_level_metadata': metadata},
+    )
+
+
 
 class HealthBenchEvaluator(BaseEvaluator):
 
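Note (not part of the diff): `_compute_clipped_stats` mirrors the HealthBench scoring stats: the mean is clipped to [0, 1], and `bootstrap_std` resamples the values with replacement 1000 times (`np.random.choice` samples with replacement by default) and reports the std of the resulting clipped means. A quick usage sketch with made-up scores:

    scores = [0.2, 0.9, 0.7, 1.0, 0.4]
    _compute_clipped_stats(scores, 'mean')           # 0.64, clipped into [0, 1]
    _compute_clipped_stats(scores, 'n_samples')      # 5
    _compute_clipped_stats(scores, 'bootstrap_std')  # ~0.13, varies run to run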
@@ -290,10 +333,8 @@ class HealthBenchEvaluator(BaseEvaluator):
 
     def score(self, predictions, references, test_set):
         results = []
-        ret = []
         if len(predictions) != len(references):
             return {'error': 'preds and refrs have different length'}
-        all_score = 0
         for idx, (i, j) in enumerate(zip(predictions, references)):
             row = test_set[idx]
             prompt_messages = row['prompt']
@@ -328,7 +369,7 @@ class HealthBenchEvaluator(BaseEvaluator):
                 convo = actual_queried_prompt_messages + [
                     dict(content=response_text, role='assistant')
                 ]
-                ret.append(SingleEvalResult(
+                results.append(SingleEvalResult(
                     html=html,
                     score=score,
                     convo=convo,
@@ -345,13 +386,18 @@ class HealthBenchEvaluator(BaseEvaluator):
                     ).hexdigest(),
                 },
             ))
-            all_score += score
-        avg_score = all_score / float(idx+1)
-        return {
-            'score': avg_score
+        results = _aggregate_get_clipped_mean(results)
+        assert results.metrics is not None
+        metrics = results.metrics | {'score': results.score}
+        metrics = dict(sorted(metrics.items()))
+        result_dict = {
+            'score': results.score,
+            'metrics': results.metrics,
+            'htmls': results.htmls,
+            'convos': results.convos,
+            'metadata': results.metadata,
         }
+        return {'accuracy': result_dict['score'],}
 
 
 
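Note (not part of the diff): after aggregation, `results.metrics` holds each metric plus its `:n_samples` and `:bootstrap_std` variants, while the overall clipped mean is popped into `results.score`; the method then exposes only that mean, renamed to 'accuracy'. (The sorted `metrics` dict is built but does not appear in the returned value.) A sketch of the hypothetical shape:

    # results.score   -> 0.61   (clipped mean of per-example scores)
    # results.metrics -> {'score:bootstrap_std': 0.03, 'score:n_samples': 100, ...}
    # score() returns -> {'accuracy': 0.61}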
@@ -257,22 +257,3 @@ class PromptTemplate:
         prompt.append(dict(section='end', pos='end'))
 
         return prompt
-
-
-class HealthBenchTemplate:
-
-    def __init__(
-        self,
-        key: Union[Dict, str],
-    ) -> None:
-        self.key = key
-
-    def generate_item(self, entry: Dict, **kwargs):
-        template = [{'section': 'round', 'pos': 'begin'}]
-        end_template = [{'section': 'round', 'pos': 'end'}]
-        mid = entry[self.key]
-        template = template + mid + end_template
-        ret = PromptList()
-        for item in template:
-            ret.append(item)
-        return ret
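Note (not part of the diff): HealthBenchTemplate existed only to wrap the stored message list in round begin/end section markers; with `prompt_trans` now holding plain chat messages, the stock PromptTemplate round format covers the same ground, so the class is dropped. For reference, a sketch of what generate_item used to emit for a hypothetical one-turn entry:

    # [{'section': 'round', 'pos': 'begin'},
    #  {'prompt': 'My child has a fever.', 'role': 'HUMAN'},
    #  {'section': 'round', 'pos': 'end'}]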