# flake8: noqa
# yapf: disable
from typing import Dict, List, Optional, Union

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.models.base_api import APITemplateParser
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


def _get_stopping_criteria(stop_words, tokenizer, batch_size):
    from transformers import (PreTrainedTokenizer, StoppingCriteria,
                              StoppingCriteriaList)

    class MultiTokenEOSCriteria(StoppingCriteria):
        """Criteria to stop on the specified multi-token sequence."""

        def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, batch_size: int):
            self.done_tracker = [False] * batch_size
            self.sequence = sequence
            self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
            self.sequence_id_len = len(self.sequence_ids)
            self.tokenizer = tokenizer

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # compare the last len(stop) tokens
            lookback_ids_batch = input_ids[:, -self.sequence_id_len:]
            lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
            for i, done in enumerate(self.done_tracker):
                if done:
                    continue
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
            return False not in self.done_tracker

    criteria = []
    for stop_word in stop_words:
        c = MultiTokenEOSCriteria(stop_word, tokenizer, batch_size)
        criteria.append(c)
    criteria = StoppingCriteriaList(criteria)
    return criteria


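# Illustrative sketch (not an additional API): the StoppingCriteriaList built
# above is meant to be passed straight to HF `generate`, which is exactly how
# the `generate()` methods below consume it. `tokenizer`, `model`, `tokens`
# and `prompts` here stand for any HF tokenizer/model pair, its encoded batch,
# and the raw prompt list.
#
#     criteria = _get_stopping_criteria(['</s>', '<|im_end|>'], tokenizer,
#                                       batch_size=len(prompts))
#     outputs = model.generate(**tokens, stopping_criteria=criteria)

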
def _get_possible_max_seq_len(max_seq_len, path):
    if max_seq_len is not None:
        return max_seq_len

    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    possible_keys = [
        'max_position_embeddings',
        'seq_length',
        'model_max_length',
    ]
    for k in possible_keys:
        if hasattr(config, k):
            return getattr(config, k)
    raise ValueError('max_seq_len is not provided and cannot be inferred from the model config.')


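# Illustrative fallback behaviour: when max_seq_len is left as None, the value
# is read from the HF config, trying `max_position_embeddings`, then
# `seq_length`, then `model_max_length` (the checkpoint id below is a
# placeholder).
#
#     _get_possible_max_seq_len(None, 'some-org/some-model')  # e.g. 4096

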
def _convert_chat_messages(inputs, merge_role=True):
    outputs = []
    for _input in inputs:
        messages = []
        if isinstance(_input, str):
            messages.append({'role': 'user', 'content': _input})
        else:
            for item in _input:
                role = {
                    'HUMAN': 'user',
                    'BOT': 'assistant',
                    'SYSTEM': 'system',
                }[item['role']]
                messages.append({'role': role, 'content': item['prompt']})

        if merge_role:
            merged_messages = []
            for item in messages:
                if merged_messages and merged_messages[-1]['role'] == item['role']:
                    merged_messages[-1]['content'] += '\n' + item['content']
                else:
                    merged_messages.append(item)
            messages = merged_messages

        outputs.append(messages)
    return outputs


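# Illustrative behaviour of the conversion above: a bare string becomes a
# single user turn, OpenCompass roles are mapped to chat-template roles, and
# consecutive turns with the same role are merged with '\n'.
#
#     _convert_chat_messages(['hi'])
#     # -> [[{'role': 'user', 'content': 'hi'}]]
#     _convert_chat_messages([[{'role': 'SYSTEM', 'prompt': 'Be brief.'},
#                              {'role': 'HUMAN', 'prompt': 'hi'}]])
#     # -> [[{'role': 'system', 'content': 'Be brief.'},
#     #      {'role': 'user', 'content': 'hi'}]]

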
def _format_with_fast_chat_template(inputs: List[str], name: str = 'vicuna'):
    try:
        from fastchat.model import get_conversation_template
    except ImportError:
        raise ModuleNotFoundError('fastchat not found. Please install with\npip install "fschat[model_worker,webui]"')

    outputs = []
    for _input in inputs:
        template = get_conversation_template(name)
        for item in _input:
            if item['role'] == 'user':
                template.append_message(template.roles[0], item['content'])
            elif item['role'] == 'assistant':
                template.append_message(template.roles[1], item['content'])
            elif item['role'] == 'system':
                continue
            else:
                raise ValueError(f'Unknown role {item["role"]}')
        template.append_message(template.roles[1], None)
        outputs.append(template.get_prompt())
    return outputs


def _get_meta_template(meta_template):
    default_meta_template = dict(
        round=[
            dict(role='HUMAN', api_role='HUMAN'),
            # XXX: all system roles are mapped to human on purpose
            dict(role='SYSTEM', api_role='HUMAN'),
            dict(role='BOT', api_role='BOT', generate=True),
        ]
    )
    return APITemplateParser(meta_template or default_meta_template)


def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch
    if 'torch_dtype' not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',
            'None': None,
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype
    return model_kwargs


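# Illustrative mapping performed above: string dtype names coming from configs
# are turned into real torch dtypes before reaching `from_pretrained`; an
# absent key defaults to float16, while 'auto'/'None'/unknown strings leave
# model_kwargs as provided.
#
#     _set_model_kwargs_torch_dtype({'torch_dtype': 'torch.bfloat16'})
#     # -> {'torch_dtype': torch.bfloat16}
#     _set_model_kwargs_torch_dtype({})
#     # -> {'torch_dtype': torch.float16}

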
@MODELS.register_module()
class HuggingFacewithChatTemplate(BaseModel):

    def __init__(self,
                 path: str,
                 model_kwargs: dict = dict(),
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 peft_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 max_seq_len: Optional[int] = None,
                 meta_template: Optional[Dict] = None,
                 pad_token_id: Optional[int] = None,
                 fastchat_template: Optional[str] = None,
                 stop_words: List[str] = [],
                 **other_kwargs):

        self.logger = get_logger()
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = _get_meta_template(meta_template)
        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
        if not tokenizer_only:
            self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
        self.generation_kwargs = generation_kwargs
        self.fastchat_template = fastchat_template
        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))

        for k, v in other_kwargs.items():
            if v is not None:
                self.logger.warning(f'Unused argument {k}={v}')

    def _load_tokenizer(self, path: Optional[str], kwargs: dict, pad_token_id: Optional[int] = None):
        from transformers import AutoTokenizer, GenerationConfig

        DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True)
        tokenizer_kwargs = DEFAULT_TOKENIZER_KWARGS
        tokenizer_kwargs.update(kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_kwargs)

        # A patch for some models without pad_token_id
        if pad_token_id is not None:
            if self.tokenizer.pad_token_id is None:
                self.logger.debug(f'Using {pad_token_id} as pad_token_id')
            elif self.tokenizer.pad_token_id != pad_token_id:
                self.logger.warning(f'pad_token_id is not consistent. Using {pad_token_id} as pad_token_id')
            self.tokenizer.pad_token_id = pad_token_id
            return
        if self.tokenizer.pad_token_id is not None:
            return
        self.logger.warning('pad_token_id is not set for the tokenizer.')
        generation_config = GenerationConfig.from_pretrained(path)
        if generation_config.pad_token_id is not None:
            self.logger.warning(f'Using {generation_config.pad_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = generation_config.pad_token_id
            return
        if self.tokenizer.eos_token_id is not None:
            self.logger.warning(f'Using eos_token_id {self.tokenizer.eos_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            return
        raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f'using model_kwargs: {model_kwargs}')

        try:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)

        if peft_path is not None:
            from peft import PeftModel
            peft_kwargs['is_trainable'] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        self.model.eval()
        self.model.generation_config.do_sample = False

    def _get_potential_stop_words(self, path: Optional[str]):
        from transformers import GenerationConfig
        potential_stop_words = []
        try:
            generation_config = GenerationConfig.from_pretrained(path)
            for token_id in generation_config.eos_token_id:
                potential_stop_words.append(self.tokenizer.decode(token_id))
        except Exception:
            pass
        potential_stop_words.append(self.tokenizer.eos_token)
        potential_stop_words = list(set(potential_stop_words))
        return potential_stop_words

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = [],
                 **kwargs) -> List[str]:
        messages = _convert_chat_messages(inputs)
        batch_size = len(messages)

        # step-1: format with the chat template and tokenize the input
        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        if self.fastchat_template:
            messages = _format_with_fast_chat_template(messages, self.fastchat_template)
            tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        else:
            messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
            tokenize_kwargs['add_special_tokens'] = False
            tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

        generation_kwargs = self.generation_kwargs.copy()
        generation_kwargs.update(kwargs)
        stopping_criteria = list(set(stopping_criteria + self.stop_words))
        if stopping_criteria:
            generation_kwargs['stopping_criteria'] = _get_stopping_criteria(stopping_criteria, self.tokenizer, batch_size)
        if max_out_len is not None:
            generation_kwargs['max_new_tokens'] = max_out_len
        if min_out_len is not None:
            generation_kwargs['min_new_tokens'] = min_out_len
        generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id

        # step-2: conduct model forward to generate output
        outputs = self.model.generate(**tokens, **generation_kwargs)
        outputs = outputs[:, tokens['input_ids'].shape[1]:]

        # step-3: decode the output
        decodeds = self.tokenizer.batch_decode(outputs)
        for stop in stopping_criteria:
            decodeds = [t.split(stop)[0] for t in decodeds]

        return decodeds

    def get_token_len(self, prompt: str) -> int:
        m = _convert_chat_messages([prompt])[0]
        t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
        return len(t['input_ids'])


def _convert_base_messages(inputs):
    outputs = []
    for _input in inputs:
        if isinstance(_input, str):
            outputs.append(_input)
        else:
            messages = []
            for item in _input:
                messages.append(item['prompt'])
            outputs.append(''.join(messages))
    return outputs


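# Illustrative behaviour of the base-model conversion: raw strings pass
# through unchanged and OpenCompass prompt dicts are concatenated in order.
#
#     _convert_base_messages(['2 + 2 ='])
#     # -> ['2 + 2 =']
#     _convert_base_messages([[{'role': 'HUMAN', 'prompt': 'Q: hi\n'},
#                              {'role': 'BOT', 'prompt': 'A:'}]])
#     # -> ['Q: hi\nA:']

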
class HuggingFaceBaseModel(HuggingFacewithChatTemplate):

    def __init__(self,
                 path: str,
                 model_kwargs: dict = dict(),
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 peft_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 max_seq_len: Optional[int] = None,
                 pad_token_id: Optional[int] = None,
                 stop_words: List[str] = [],
                 **other_kwargs):

        self.logger = get_logger()
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = LMTemplateParser()
        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
        if not tokenizer_only:
            self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
        self.generation_kwargs = generation_kwargs
        self.stop_words = stop_words

        for k, v in other_kwargs.items():
            if v is not None:
                self.logger.warning(f'Unused argument {k}={v}')

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = [],
                 **kwargs) -> List[str]:
        messages = _convert_base_messages(inputs)
        batch_size = len(messages)

        # step-1: tokenize the raw prompts
        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

        generation_kwargs = self.generation_kwargs.copy()
        generation_kwargs.update(kwargs)
        stopping_criteria = list(set(stopping_criteria + self.stop_words))
        if stopping_criteria:
            generation_kwargs['stopping_criteria'] = _get_stopping_criteria(stopping_criteria, self.tokenizer, batch_size)
        if max_out_len is not None:
            generation_kwargs['max_new_tokens'] = max_out_len
        if min_out_len is not None:
            generation_kwargs['min_new_tokens'] = min_out_len
        generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id

        # step-2: conduct model forward to generate output
        outputs = self.model.generate(**tokens, **generation_kwargs)
        outputs = outputs[:, tokens['input_ids'].shape[1]:]

        # step-3: decode the output
        decodeds = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for stop in stopping_criteria:
            decodeds = [token.split(stop)[0] for token in decodeds]

        return decodeds

    def get_ppl(self, inputs: List[str], mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        assert self.tokenizer.pad_token
        import torch
        import torch.nn.functional as F
        pad_token_id = self.tokenizer.pad_token_id
        messages = _convert_base_messages(inputs)

        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
        outputs = self.model(**tokens)[0]

        batch_size, seq_len, vocab_size = outputs.shape
        shift_logits = outputs[:, :-1, :].contiguous().float()
        shift_labels = tokens['input_ids'][:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
            reduction='none').view(batch_size, seq_len - 1)
        lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()

        if mask_length is not None:
            import numpy as np
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask
            lens -= np.array(mask_length)

        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_loglikelihood(self, inputs: List[str], conts: List[str]) -> List[float]:
        mask_length = [self.get_token_len(c, add_special_tokens=False) for c in conts]
        return -self.get_ppl(inputs, mask_length)

    def get_token_len(self, prompt: str, add_special_tokens: bool = True) -> int:
        m = _convert_base_messages([prompt])[0]
        t = self.tokenizer(m, add_special_tokens=add_special_tokens)
        return len(t['input_ids'])

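
# Minimal usage sketch, illustrative only: in OpenCompass these wrappers are
# normally instantiated from model configs rather than by hand, and the
# checkpoint id below is just an example of a small HF chat model, not
# something this module depends on.
if __name__ == '__main__':
    demo = HuggingFacewithChatTemplate(path='Qwen/Qwen1.5-0.5B-Chat',
                                       max_seq_len=2048)
    print(demo.generate(['Hello, who are you?'], max_out_len=32))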