2023-12-12 20:58:17 +08:00
# flake8: noqa: E501
2024-06-28 14:16:34 +08:00
# yapf: disable
2023-09-22 15:42:31 +08:00
import os . path as osp
2023-12-11 22:22:11 +08:00
import random
2024-06-28 14:16:34 +08:00
import re
2023-09-22 15:42:31 +08:00
from typing import Dict , List , Optional
import mmengine
2023-12-12 20:58:17 +08:00
from datasets import Dataset
2023-09-22 15:42:31 +08:00
from mmengine . config import ConfigDict
from opencompass . openicl . icl_inferencer import GenInferencer
from opencompass . openicl . icl_retriever import ZeroRetriever
2024-10-15 16:36:05 +08:00
from opencompass . registry import DICT_POSTPROCESSORS , ICL_PROMPT_TEMPLATES
2023-09-22 15:42:31 +08:00
from opencompass . utils import build_dataset_from_cfg , build_model_from_cfg
from opencompass . utils . logging import get_logger
2024-01-24 12:11:47 +08:00
def extract_dicts(data):
    """Transpose per-dialogue round dicts into per-round prediction lists.

    Args:
        data (List[List[dict]]): One sublist per dialogue; each dict is one
            chat round and carries the model reply under the 'assistant' key.

    Returns:
        List[List[Optional[str]]]: ``result[r][d]`` is the 'assistant' reply
        of dialogue ``d`` at round ``r``, or ``None`` when dialogue ``d``
        has fewer than ``r + 1`` rounds (or the round lacks the key).
    """
    # default=0 keeps an empty ``data`` from raising ValueError on max().
    max_round_num = max((len(sublist) for sublist in data), default=0)
    predictions = [[] for _ in range(max_round_num)]
    for sublist in data:
        for i, d in enumerate(sublist):
            predictions[i].append(d.get('assistant'))
        # Pad shorter dialogues with None so every round list keeps one slot
        # per dialogue.  Using len(sublist) avoids the stale/undefined loop
        # index the previous version relied on when a sublist was empty.
        for j in range(len(sublist), max_round_num):
            predictions[j].append(None)
    return predictions
2024-06-28 14:16:34 +08:00
def order_preds_and_record_references(predictions, references, infer_order, seed=666):
    """Order predictions based on args and recording regrading references.

    Args:
        predictions (List): List of multi model predictions.
        references (List): List of reference based on each problem.
        infer_order (str, optional): The mode of inference order.
        seed (int, optional): Random seed.
    """
    random.seed(seed)
    num_models = len(predictions)
    num_problems = len(predictions[0]['model_preds'])
    list_of_preds = [[] for _ in range(num_models)]
    for idx in range(num_problems):
        # One (answer, model_name) pair per model for this problem.
        candidates = [[model['model_preds'][idx], model['model_name']]
                      for model in predictions]
        if infer_order == 'random':
            random.shuffle(candidates)
        for slot, (answer, name) in enumerate(candidates):
            list_of_preds[slot].append(answer)
            references[idx][f'answer{slot + 1}'] = name
    if infer_order == 'double':
        assert len(predictions) == 2
        # Duplicate every problem with the two answers in swapped positions.
        list_of_preds = [a + b for a, b in zip(list_of_preds, reversed(list_of_preds))]
        reversed_references = []
        for item in references:
            swapped = item.copy()
            swapped['answer1'], swapped['answer2'] = swapped['answer2'], swapped['answer1']
            reversed_references.append(swapped)
        references += reversed_references
    return list_of_preds, references
2024-06-28 14:16:34 +08:00
def count_chinese_characters(text):
    """Return how many CJK Unified Ideographs (U+4E00..U+9FFF) appear in text."""
    return len(re.findall(r'[\u4e00-\u9fff]', text))
def count_english_words(text):
    """Return how many purely alphabetic word tokens appear in text."""
    return sum(1 for _ in re.finditer(r'\b[a-zA-Z]+\b', text))
2023-09-22 15:42:31 +08:00
class LMEvaluator:
    """Evaluate output with language model.

    Args:
        prompt_template (ConfigDict): Prompt template configuration. Used to
            prompt the language model for scores. User can use two reserved
            keywords, ``{prediction}`` and ``{reference}``, referring to
            the prediction and optionally the reference answer.
        judge_cfg (ConfigDict): The config of language model as a judge.
        meta_review_prompt_template (ConfigDict, optional): Prompt template
            for meta judge model.
        output_path (str): The path to prediction output.
        dataset_cfg (ConfigDict, optional): The config of the dataset to be
            evaluated.
        pack_all_predictions (bool, optional): For multiround evaluation,
            judge all round or judge every single round.
        pred_postprocessor (ConfigDict): The model prediction's postprocessor
            config.
        dict_postprocessor (ConfigDict, optional): Config of the registered
            postprocessor applied to the loaded judge output in
            :meth:`postprocess`.
    """

    def __init__(
        self,
        prompt_template: ConfigDict,
        judge_cfg: ConfigDict,
        output_path: str,
        meta_review_prompt_template: Optional[ConfigDict] = None,
        pack_all_predictions: Optional[bool] = False,
        dataset_cfg: Optional[ConfigDict] = None,
        pred_postprocessor: Optional[ConfigDict] = None,
        dict_postprocessor: Optional[ConfigDict] = None,
    ) -> None:
        self.output_path = output_path
        out_dir, out_name = osp.split(output_path)
        if not out_dir:
            # A bare filename was given; write next to the working directory.
            out_dir = './'
        self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
        if meta_review_prompt_template is not None:
            # Only built on demand; self.meta_review_prompt_tmpl is undefined
            # otherwise, so score(meta=True) requires this template.
            self.meta_review_prompt_tmpl = ICL_PROMPT_TEMPLATES.build(meta_review_prompt_template)

        max_out_len = judge_cfg.get('max_out_len', None)
        batch_size = judge_cfg.get('batch_size', None)
        model = build_model_from_cfg(model_cfg=judge_cfg)
        # The judge model's generations are written to out_dir/out_name and
        # read back at the end of score().
        self.inferencer = GenInferencer(model,
                                        max_out_len=max_out_len,
                                        batch_size=batch_size,
                                        output_json_filepath=out_dir,
                                        output_json_filename=out_name)
        self.logger = get_logger()
        self.dataset_cfg = dataset_cfg
        self.pack_all_predictions = pack_all_predictions
        self.pred_postprocessor = pred_postprocessor
        self.dict_postprocessor = dict_postprocessor

    def score(self,
              predictions,
              judgements: Optional[List] = None,
              references: Optional[List] = None,
              meta: Optional[bool] = False,
              infer_order: Optional[str] = 'random') -> Dict:
        """Run the judge model over predictions and return its (postprocessed)
        output dict.

        ``predictions`` is either a list of per-model dicts (multi-model
        comparison) or a single dict (single-model scoring), each carrying
        'model_preds' and 'model_name'.  Mutates ``references`` in place.
        """
        dup_indices = []
        if isinstance(predictions, list):
            """Apply to multi-model comparison."""
            if references is None:
                references = [{} for _ in range(len(predictions[0]['model_preds']))]
            predictions, references = order_preds_and_record_references(predictions, references, infer_order)

            # calculate dupicated predictions numbers
            total_predictions_num = len(predictions[0])

            # since there is impossible that two models response same pattern in multi-round chat, so we just check dup for single chat
            if isinstance(predictions[0][0], str):
                for i in range(len(predictions[0])):
                    check = [sub[i] for sub in predictions]
                    if len(set(check)) == 1:
                        dup_indices.append(i)

        elif isinstance(predictions, dict):
            """Apply to single-model scoring."""
            if references is None:
                # NOTE(review): predictions is a dict here, so
                # predictions[0]['model_preds'] looks like it would raise
                # KeyError(0); the list branch above uses the same
                # expression on a list — confirm against callers.
                references = [{} for _ in range(len(predictions[0]['model_preds']))]
            predictions = [predictions['model_preds']]

        # Due to the rarity of identical predictions, we have temporarily disabled the plagiarism detection feature.
        dup_indices = []

        # Dead branch while dup_indices is forcibly cleared above; kept for
        # when duplicate detection is re-enabled.
        if len(dup_indices) != 0:
            # remove dupicated predictions
            for index in sorted(dup_indices, reverse=True):
                for sublist in predictions:
                    del sublist[index]
                del references[index]

        # pred_dict maps dataset column name -> per-problem value list; these
        # columns are later attached to the dataset fed to the judge.
        pred_dict = {}
        if isinstance(predictions[0][0], str):
            # single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
            for i in range(len(predictions)):
                key = 'prediction' if i == 0 else f'prediction{i + 1}'
                gold_key = 'obj_gold'
                pred_dict[key] = predictions[i]
                pred_dict[gold_key] = references
                # Word counts are exposed as extra columns so prompt
                # templates / postprocessors can penalize length.
                pred_dict[key + '_en_word_count'] = [count_english_words(j) for j in predictions[i]]
                pred_dict[key + '_cn_word_count'] = [count_chinese_characters(j) for j in predictions[i]]
            if judgements:
                for i in range(len(judgements)):
                    key = 'judgement' if i == 0 else f'judgement{i + 1}'
                    pred_dict[key] = judgements[i]['model_preds']
                    for j in range(len(references)):
                        references[j]['judge_model' + str(i + 1)] = judgements[i]['model_name']
        elif isinstance(predictions[0][0], list):
            # multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
            if self.pack_all_predictions:
                for i in range(len(predictions)):
                    key = 'prediction' if i == 0 else f'prediction{i + 1}'
                    predictions[i] = [str(_) for _ in predictions[i]]  # Fix the dictionary order to prevent the following situations: {'assistant':'', 'round':2, 'user':''}
                    pred_dict[key] = predictions[i]
            else:
                for i in range(len(predictions)):
                    multiround_predictions = extract_dicts(predictions[i])
                    for j in range(len(multiround_predictions)):
                        # NOTE(review): uses f'prediction{i}' here but
                        # f'prediction{i + 1}' in the other branches — so
                        # model 1 gets 'prediction1_r*' instead of
                        # 'prediction2_r*'; confirm downstream consumers
                        # expect this.
                        key = 'prediction' if i == 0 else f'prediction{i}'
                        key += '_r' + str(j + 1)
                        pred_dict[key] = multiround_predictions[j]
            if judgements:
                raise NotImplementedError(
                    'Not applied meta-reivew judge on multi-round dataset')
        else:
            raise NotImplementedError(f'{predictions[0][0]} with type {type(predictions[0][0])}, please check the postprocess you add to the prediction string is right or not, we suggest to return an empty string but not None')

        if self.dataset_cfg:
            dataset = build_dataset_from_cfg(self.dataset_cfg)

            if infer_order == 'double':
                # Each problem appears twice (answers swapped), so duplicate
                # every dataset column to keep rows aligned with predictions.
                new_ds = {k: dataset.test[k] * 2 for k in dataset.test.column_names}
                dataset.reader.dataset['test'] = Dataset.from_dict(new_ds)

            if len(dup_indices) != 0:
                # Drop the dataset rows whose predictions were removed above.
                remaining_indices = [idx for idx in range(len(dataset.test)) if idx not in dup_indices]
                dataset.reader.dataset['test'] = dataset.test.select(remaining_indices)
                print(f'Among total {total_predictions_num} predictions, there are {len(dup_indices)} predictions totally same, which are removed!')
            for k, v in pred_dict.items():
                dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
                dataset.reader.input_columns.append(k)

            if references:
                dataset.reader.input_columns.append('reference')
                dataset.reader.dataset['test'] = dataset.test.add_column('reference', references)
        else:
            # build a default dataset just for comparison
            from opencompass.datasets.lmeval import LMEvalDataset
            input_columns = list(pred_dict.keys())
            if references:
                input_columns.append('reference')
            dataset = LMEvalDataset(
                reader_cfg=dict(input_columns=input_columns, output_column=None, train_split='test'),
                reference=references,
                **pred_dict
            )
        dataset.reader.output_column = 'reference'
        retriever = ZeroRetriever(dataset)

        if meta:
            # Requires meta_review_prompt_template to have been given to
            # __init__; otherwise this attribute does not exist.
            self.inferencer.inference(retriever=retriever, prompt_template=self.meta_review_prompt_tmpl)
        else:
            self.inferencer.inference(retriever=retriever, prompt_template=self.prompt_tmpl)

        # The inferencer wrote its results to self.output_path; load them
        # back for postprocessing.
        output = mmengine.load(self.output_path)
        return self.postprocess(output)

    def postprocess(self, output: Dict) -> Dict:
        """Postprocess output by adding necessary statistics or data into
        it."""
        if self.dict_postprocessor is None:
            return output
        else:
            # NOTE(review): pop('type') mutates the stored config dict, so a
            # second call on the same evaluator would KeyError — confirm
            # score() is only invoked once per instance.
            kwargs = self.dict_postprocessor
            proc = DICT_POSTPROCESSORS.get(kwargs.pop('type'))
            return proc(output, self.output_path, **kwargs)