import os
import os.path as osp
from copy import deepcopy
from typing import Dict, List, Optional

import mmengine
from datasets import Dataset
from mmengine.config import ConfigDict

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.registry import (DICT_POSTPROCESSORS, ICL_PROMPT_TEMPLATES,
                                  TEXT_POSTPROCESSORS)
from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
from opencompass.utils.logging import get_logger

logger = get_logger(__name__)


class GenericLLMEvaluator(BaseEvaluator):
    """Generic LLM evaluator.

    Arguments:
        prompt_template (ConfigDict): The prompt template for evaluation.
        judge_cfg (ConfigDict): The config for the judge LLM.
        dataset_cfg (ConfigDict): The config for the dataset.
        pred_postprocessor (ConfigDict): The config for the postprocessor
            applied to the prediction results.
        dict_postprocessor (ConfigDict): The config for the postprocessor
            applied to the evaluation results dict.
    """

    def __init__(
        self,
        prompt_template: ConfigDict,
        judge_cfg: ConfigDict,
        dataset_cfg: Optional[ConfigDict] = None,
        pred_postprocessor: Optional[ConfigDict] = None,
        dict_postprocessor: Optional[ConfigDict] = None,
        keep_predictions: bool = False,
    ) -> None:
        super().__init__(pred_postprocessor=pred_postprocessor)
        # If judge_cfg is not provided, fall back to the default configuration
        if not judge_cfg:
            self.judge_cfg = self.default_judge_cfg
        else:
            self.judge_cfg = judge_cfg
        self.output_path = ''

        self.prompt_template = ICL_PROMPT_TEMPLATES.build(prompt_template)

        # Build Dataset. dataset_cfg may be None; in that case `score` builds
        # an LMEvalDataset from the predictions (and test_set, if provided).
        self.dataset_cfg = dataset_cfg

        self.dict_postprocessor = dict_postprocessor
        self.pred_postprocessor = pred_postprocessor

    def build_inferencer(self):
        """Build the LLM inferencer used for judging."""
        self.output_path = f'{self._out_dir}_replica{self.dataset_replica_idx}.json'  # noqa
        logger.info(f'LLM judge details will be saved at: {self.output_path}')
        out_dir, out_name = osp.split(self.output_path)

        logger.info(
            f'Set self.output_path to {self.output_path} for current task')

        # Build LLM inference
        max_out_len = self.judge_cfg.get('max_out_len', None)
        batch_size = self.judge_cfg.get('batch_size', None)

        model = build_model_from_cfg(model_cfg=self.judge_cfg)

        self.inferencer = GenInferencer(
            model,
            max_out_len=max_out_len,
            batch_size=batch_size,
            output_json_filepath=out_dir,
            output_json_filename=out_name,
        )

    def score(
        self,
        predictions,
        references: Optional[List] = None,
        test_set: Optional[Dataset] = None,
    ) -> Dict:
        """Apply to single-model scoring.

        Args:
            predictions: List of model predictions.
            references: List of reference answers.
            test_set: Optional dataset containing additional context for
                evaluation.
        """
        if references is not None:
            assert len(predictions) == len(references), (
                'predictions and references must have the same length')

        # -------------- Build Inferencer ----------------
        self.build_inferencer()

        # ---------------- Process Predictions ------------------
        predictions = self.pred_postprocess(predictions)

        # For Single Round Dialogue
        prediction_dict = {'prediction': predictions, 'obj_gold': references}

        # ---------------- Build Dataset for LLM Judge -----------------
        if self.dataset_cfg:
            dataset = build_dataset_from_cfg(self.dataset_cfg)
            for k, v in prediction_dict.items():
                dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
                dataset.reader.input_columns.append(k)

            if references:
                dataset.reader.input_columns.append('reference')
                dataset.reader.dataset['test'] = dataset.test.add_column(
                    'reference', references)
        else:
            # Handle test_set in the else branch
            from opencompass.datasets.lmeval import LMEvalDataset

            if test_set is not None:
                # If test_set is provided, use it as the base
                # Ensure necessary columns exist
                if 'prediction' not in test_set.column_names:
                    test_set = test_set.add_column('prediction', predictions)
                if 'reference' not in test_set.column_names:
                    test_set = test_set.add_column('reference', references)

                # Prepare input_columns and data dictionary
                input_columns = test_set.column_names
                data_dict = {
                    column: test_set[column]
                    for column in test_set.column_names
                }
            else:
                # Original default dataset building logic
                input_columns = list(prediction_dict.keys())
                if references:
                    input_columns.append('reference')
                data_dict = prediction_dict.copy()
                if references:
                    data_dict['reference'] = references

            # Create LMEvalDataset
            dataset = LMEvalDataset(
                reader_cfg=dict(
                    input_columns=input_columns,
                    output_column=None,
                    train_split='test',
                ),
                **data_dict,
            )

        dataset.reader.output_column = 'reference'
        retriever = ZeroRetriever(dataset)

        # ----------------- LLM Judge ----------------
        self.inferencer.inference(retriever=retriever,
                                  prompt_template=self.prompt_template)

        output = mmengine.load(self.output_path)
        return self.output_postprocess(output, dataset)

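    # Sketch of a direct `score` call. In normal use the evaluator is driven
    # by OpenCompass's eval task, which builds it from config and sets
    # attributes such as `_out_dir` and `dataset_replica_idx` before calling
    # `score`; the values below are purely illustrative.
    #
    #   results = evaluator.score(
    #       predictions=['Paris', 'Berlin'],
    #       references=['Paris', 'Bonn'],
    #   )
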
    def pred_postprocess(self, predictions: List) -> List:
        if self.pred_postprocessor is None:
            return predictions
        else:
            # Copy the config so popping 'type' does not mutate it across calls
            kwargs = deepcopy(self.pred_postprocessor)
            proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
            return [proc(pred, **kwargs) for pred in predictions]

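    # A pred_postprocessor config is a dict whose 'type' key names a
    # postprocessor registered in TEXT_POSTPROCESSORS; the remaining keys are
    # passed through as keyword arguments. The registry name below is a
    # hypothetical example, not a guaranteed entry:
    #
    #   pred_postprocessor = dict(type='my_registered_postprocess', option='A')
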
    def output_postprocess(self, output: Dict, dataset=None) -> Dict:
        """Postprocess output by adding necessary statistics or data into
        it."""
        import inspect

        if self.dict_postprocessor is None:
            return output
        else:
            kwargs = deepcopy(self.dict_postprocessor)
            proc = DICT_POSTPROCESSORS.get(kwargs.pop('type'))
            # Only pass the dataset through if the postprocessor accepts it
            sig = inspect.signature(proc)
            if 'dataset' in sig.parameters:
                return proc(output,
                            self.output_path,
                            dataset=dataset,
                            **kwargs)
            else:
                return proc(output, self.output_path, **kwargs)

    @property
    def default_judge_cfg(self):
        from opencompass.models import OpenAISDK
        logger.info('Please set your judge model via the `OC_JUDGE_MODEL`, '
                    '`OC_JUDGE_API_KEY` and `OC_JUDGE_API_BASE` '
                    'environment variables.')
        DEFAULT_JUDGE_CFG = dict(
            type=OpenAISDK,
            path=os.environ['OC_JUDGE_MODEL'],
            key=os.environ['OC_JUDGE_API_KEY'],
            openai_api_base=[
                os.environ.get('OC_JUDGE_API_BASE',
                               'https://api.openai.com/v1/')
            ],
            meta_template=dict(round=[
                dict(role='HUMAN', api_role='HUMAN'),
                dict(role='BOT', api_role='BOT', generate=True),
            ]),
            query_per_second=16,
            batch_size=1024,
            temperature=0.001,
            tokenizer_path='gpt-4o-2024-05-13',
            verbose=True,
            max_out_len=16384,
            max_seq_len=49152,
        )

        return DEFAULT_JUDGE_CFG
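
# Environment setup sketch for the default judge config above. The values are
# placeholders only (any OpenAI-compatible endpoint and model name can be
# used), not real credentials:
#
#   export OC_JUDGE_MODEL=gpt-4o-2024-05-13
#   export OC_JUDGE_API_KEY=sk-xxxx
#   export OC_JUDGE_API_BASE=https://api.openai.com/v1/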