import os
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from opencompass.models.base import BaseModel
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


@MODELS.register_module()
class HuggingFace(BaseModel):
    """Model wrapper around HuggingFace models.

    Args:
        path (str): The name or path to HuggingFace's model.
        hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will
            use the env variable HF_MODEL_HUB. Defaults to None.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        peft_path (str, optional): The name or path to the HuggingFace PEFT
            model. If None, the original model will not be converted to PEFT.
            Defaults to None.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in loader.
            Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt template, in
            case there is a need to inject or wrap any meta instructions.
        extract_pred_after_decode (bool): Whether to extract the prediction
            string from the decoded output string, instead of extracting the
            prediction tokens before decoding. Defaults to False.
        batch_padding (bool): If False, inference will be performed in a
            for-loop without batch padding.
        pad_token_id (int): The id of the padding token. Defaults to None. If
            a negative value is given, (#vocab + pad_token_id) will be used.
        mode (str, optional): The method of input truncation when the input
            length exceeds max_seq_len. 'mid' keeps the head and tail of the
            input and truncates the middle. Defaults to 'none'.

    Note:
        About ``extract_pred_after_decode``: Commonly, we should extract the
        prediction tokens before decoding. But for some tokenizers using
        ``sentencepiece``, like LLaMA, this behavior may change the number of
        whitespaces, which is harmful for Python programming tasks.
    """

    def __init__(self,
                 path: str,
                 hf_cache_dir: Optional[str] = None,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 extract_pred_after_decode: bool = False,
                 batch_padding: bool = False,
                 pad_token_id: Optional[int] = None,
                 mode: str = 'none'):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         tokenizer_only=tokenizer_only,
                         meta_template=meta_template)
        from opencompass.utils.fileio import patch_hf_auto_model
        if hf_cache_dir is None:
            hf_cache_dir = os.getenv('HF_MODEL_HUB', None)
        patch_hf_auto_model(hf_cache_dir)
        self.logger = get_logger()
        self.pad_token_id = pad_token_id
        assert mode in ['none', 'mid']
        self.mode = mode
        self._load_tokenizer(path=path,
                             tokenizer_path=tokenizer_path,
                             tokenizer_kwargs=tokenizer_kwargs)
        self.batch_padding = batch_padding
        self.extract_pred_after_decode = extract_pred_after_decode
        if not tokenizer_only:
            self._load_model(path=path,
                             model_kwargs=model_kwargs,
                             peft_path=peft_path)

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)

        # A patch for some models without pad_token_id
        if self.pad_token_id is not None:
            if self.pad_token_id < 0:
                # Negative values are interpreted relative to the vocab size,
                # i.e. the resolved id is vocab_size + pad_token_id.
                self.pad_token_id += self.tokenizer.vocab_size
            if self.tokenizer.pad_token_id is None:
                self.logger.warning(
                    f'Using {self.pad_token_id} as pad_token_id')
            elif self.tokenizer.pad_token_id != self.pad_token_id:
                self.logger.warning(
                    f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id'  # noqa
                )
            self.tokenizer.pad_token_id = self.pad_token_id
        elif self.tokenizer.pad_token_id is None:
            self.logger.warning('pad_token_id is not set for the tokenizer.')
            if self.tokenizer.eos_token is not None:
                self.logger.warning(
                    f'Using eos_token {self.tokenizer.eos_token!r} as '
                    'pad_token.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                raise ValueError(
                    'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation_config.json`.'  # noqa
                )

        # A patch for llama when batch_padding = True
        if 'decapoda-research/llama' in path or \
                (tokenizer_path and
                 'decapoda-research/llama' in tokenizer_path):
            self.logger.warning('We set new pad_token_id for LLaMA model')
            # keep consistent with official LLaMA repo
            # https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb  # noqa
            self.tokenizer.bos_token = '<s>'
            self.tokenizer.eos_token = '</s>'
            self.tokenizer.pad_token_id = 0

    def _load_model(self,
                    path: str,
                    model_kwargs: dict,
                    peft_path: Optional[str] = None):
        from transformers import AutoModel, AutoModelForCausalLM

        model_kwargs.setdefault('torch_dtype', torch.float16)
        # Prefer the causal-LM auto class; fall back to the bare AutoModel
        # for architectures that raise ValueError because they do not
        # register a causal-LM head.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)
        if peft_path is not None:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model,
                                                   peft_path,
                                                   is_trainable=False)
        self.model.eval()

        # A patch for llama when batch_padding = True
        if 'decapoda-research/llama' in path:
            self.model.config.bos_token_id = 1
            self.model.config.eos_token_id = 2
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.batch_padding and len(inputs) > 1:
            return self._batch_generate(inputs=inputs,
                                        max_out_len=max_out_len,
                                        **kwargs)
        else:
            return sum((self._single_generate(
                inputs=[input_], max_out_len=max_out_len, **kwargs)
                        for input_ in inputs), [])

    def _batch_generate(self, inputs: List[str], max_out_len: int,
                        **kwargs) -> List[str]:
        """Support for batch prompts inference.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.extract_pred_after_decode:
            prompt_lens = [len(input_) for input_ in inputs]

        # step-1: tokenize the input with batch_encode_plus
        tokens = self.tokenizer.batch_encode_plus(inputs,
                                                  padding=True,
                                                  truncation=True,
                                                  max_length=self.max_seq_len -
                                                  max_out_len)
        tokens = {
            k: torch.tensor(np.array(tokens[k]), device=self.model.device)
            for k in tokens if k in ['input_ids', 'attention_mask']
        }

        # step-2: conduct model forward to generate output
        outputs = self.model.generate(**tokens,
                                      max_new_tokens=max_out_len,
                                      **kwargs)

        if not self.extract_pred_after_decode:
            outputs = outputs[:, tokens['input_ids'].shape[1]:]

        decodeds = self.tokenizer.batch_decode(outputs,
                                               skip_special_tokens=True)

        if self.extract_pred_after_decode:
            decodeds = [
                token[len_:] for token, len_ in zip(decodeds, prompt_lens)
            ]

        return decodeds

    def _single_generate(self, inputs: List[str], max_out_len: int,
                         **kwargs) -> List[str]:
        """Support for single prompt inference.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.extract_pred_after_decode:
            prompt_lens = [len(input_) for input_ in inputs]

        if self.mode == 'mid':
            input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
            input_ids = torch.tensor(input_ids, device=self.model.device)
            if len(input_ids[0]) > self.max_seq_len - max_out_len:
                half = int((self.max_seq_len - max_out_len) / 2)
                inputs = [
                    self.tokenizer.decode(input_ids[0][:half],
                                          skip_special_tokens=True) +
                    self.tokenizer.decode(input_ids[0][-half:],
                                          skip_special_tokens=True)
                ]
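
        # Worked example (editorial note): with max_seq_len=2048 and
        # max_out_len=100, half = int((2048 - 100) / 2) = 974, so an over-long
        # prompt is rebuilt from its first 974 and last 974 tokens and the
        # middle is dropped before the final tokenization below.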

        input_ids = self.tokenizer(inputs,
                                   truncation=True,
                                   max_length=self.max_seq_len -
                                   max_out_len)['input_ids']
        input_ids = torch.tensor(input_ids, device=self.model.device)
        # To accommodate the PeftModel, parameters should be passed in
        # key-value format for generate.
        outputs = self.model.generate(input_ids=input_ids,
                                      max_new_tokens=max_out_len,
                                      **kwargs)

        if not self.extract_pred_after_decode:
            outputs = outputs[:, input_ids.shape[1]:]

        decodeds = self.tokenizer.batch_decode(outputs,
                                               skip_special_tokens=True)

        if self.extract_pred_after_decode:
            decodeds = [
                token[len_:] for token, len_ in zip(decodeds, prompt_lens)
            ]

        return decodeds

    def get_logits(self, inputs: List[str]):
        """Return model logits and tokenized inputs for the given prompts."""

        if self.batch_padding and len(inputs) > 1:
            # batch inference
            tokens = self.tokenizer(inputs,
                                    padding=True,
                                    truncation=True,
                                    max_length=self.max_seq_len)

            tokens = {
                k: torch.tensor(np.array(tokens[k]), device=self.model.device)
                for k in tokens if k in ['input_ids', 'attention_mask']
            }
            outputs = self.model(**tokens)
        else:
            input_ids = self.tokenizer(
                inputs,
                padding=False,
                truncation=True,
                max_length=self.max_seq_len)['input_ids']
            input_ids = torch.tensor(input_ids, device=self.model.device)
            tokens = {'input_ids': input_ids}

            outputs = self.model(input_ids)
        return outputs[0], {'tokens': tokens}

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        if self.batch_padding and len(inputs) > 1:
            assert self.tokenizer.pad_token
            return self._get_ppl(inputs, mask_length=mask_length)
        else:
            return np.concatenate([
                self._get_ppl(inputs=[text], mask_length=mask_length)
                for text in inputs
            ])

    def _get_ppl(self,
                 inputs: List[str],
                 mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        outputs, inputs = self.get_logits(inputs)
        shift_logits = outputs[..., :-1, :].contiguous().float()

        shift_labels = inputs['tokens']['input_ids'][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs['tokens']['input_ids'] !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
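        # Editorial note: the values returned below are length-normalised
        # cross-entropy losses (average negative log-likelihood per token).
        # A conventional perplexity would be their exponential, e.g.
        # np.exp(ce_loss); comparing the raw losses preserves the same
        # ranking because exp() is monotonic.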
        return ce_loss

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
        return len(self.tokenizer.encode(prompt))


@MODELS.register_module()
class HuggingFaceCausalLM(HuggingFace):
    """Model wrapper around HuggingFace CausalLM.

    Args:
        path (str): The name or path to HuggingFace's model.
        hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will
            use the env variable HF_MODEL_HUB. Defaults to None.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        peft_path (str, optional): The name or path to the HuggingFace PEFT
            model. If None, the original model will not be converted to PEFT.
            Defaults to None.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in loader.
            Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt template, in
            case there is a need to inject or wrap any meta instructions.
        batch_padding (bool): If False, inference will be performed in a
            for-loop without batch padding.
    """

    def _load_model(self,
                    path: str,
                    model_kwargs: dict,
                    peft_path: Optional[str] = None):
        from transformers import AutoModelForCausalLM

        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        if peft_path is not None:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model,
                                                   peft_path,
                                                   is_trainable=False)
        self.model.eval()
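

# ---------------------------------------------------------------------------
# Illustrative smoke test (editorial addition): a minimal sketch of how the
# wrappers above might be exercised directly, assuming a small causal LM such
# as 'gpt2' can be downloaded. The checkpoint name and pad_token_id below are
# placeholders for this example only, not part of the evaluated API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo = HuggingFaceCausalLM(
        path='gpt2',
        tokenizer_path='gpt2',
        max_seq_len=1024,
        tokenizer_kwargs=dict(padding_side='left', truncation_side='left'),
        model_kwargs=dict(device_map='auto'),
        pad_token_id=50256,  # gpt2 has no pad token; reuse its eos id
        batch_padding=False)
    print(demo.get_token_len('Hello, OpenCompass!'))
    print(demo.generate(['The capital of France is'], max_out_len=16))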