# flake8: noqa
# yapf: disable
from typing import Dict, List, Optional, Union

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.models.base_api import APITemplateParser
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


def _get_stopping_criteria(stop_words, tokenizer, batch_size):
    from transformers import (PreTrainedTokenizer, StoppingCriteria,
                              StoppingCriteriaList)

    class MultiTokenEOSCriteria(StoppingCriteria):
        """Criteria to stop on the specified multi-token sequence."""

        def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, batch_size: int):
            self.done_tracker = [False] * batch_size
            self.sequence = sequence
            self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
            self.sequence_id_len = len(self.sequence_ids)
            self.tokenizer = tokenizer

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # compare the last len(stop) tokens
            lookback_ids_batch = input_ids[:, -self.sequence_id_len:]
            lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
            for i, done in enumerate(self.done_tracker):
                if done:
                    continue
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
            return False not in self.done_tracker

    criteria = []
    for stop_word in stop_words:
        c = MultiTokenEOSCriteria(stop_word, tokenizer, batch_size)
        criteria.append(c)
    criteria = StoppingCriteriaList(criteria)
    return criteria


def _get_possible_max_seq_len(max_seq_len, path):
    if max_seq_len is not None:
        return max_seq_len

    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    possible_keys = [
        'max_position_embeddings',
        'seq_length',
        'model_max_length',
    ]
    for k in possible_keys:
        if hasattr(config, k):
            return getattr(config, k)
    raise ValueError('max_seq_len is not provided and cannot be inferred from the model config.')


def _convert_chat_messages(inputs):
    outputs = []
    for _input in inputs:
        messages = []
        if isinstance(_input, str):
            messages.append({'role': 'user', 'content': _input})
        else:
            for item in _input:
                role = {
                    'HUMAN': 'user',
                    'BOT': 'assistant',
                    'SYSTEM': 'system',
                }[item['role']]
                messages.append({'role': role, 'content': item['prompt']})
        outputs.append(messages)
    return outputs
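
# A minimal sketch of what the conversion above produces (inputs and outputs are
# illustrative): OpenCompass roles HUMAN/BOT/SYSTEM map to the OpenAI-style
# user/assistant/system roles expected by chat templates, and a bare string
# becomes a single user turn.
#
#   _convert_chat_messages(['Hi'])
#   -> [[{'role': 'user', 'content': 'Hi'}]]
#
#   _convert_chat_messages([[{'role': 'HUMAN', 'prompt': 'Hi'},
#                            {'role': 'BOT', 'prompt': 'Hello!'}]])
#   -> [[{'role': 'user', 'content': 'Hi'},
#        {'role': 'assistant', 'content': 'Hello!'}]]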


def _format_with_fast_chat_template(inputs: List[str], name: str = 'vicuna'):
    try:
        from fastchat.model import get_conversation_template
    except ImportError:
        raise ModuleNotFoundError('fastchat not found. Please install with\npip install "fschat[model_worker,webui]"')

    outputs = []
    for _input in inputs:
        template = get_conversation_template(name)
        for item in _input:
            if item['role'] == 'user':
                template.append_message(template.roles[0], item['content'])
            elif item['role'] == 'assistant':
                template.append_message(template.roles[1], item['content'])
            elif item['role'] == 'system':
                continue
            else:
                raise ValueError(f'Unknown role {item["role"]}')
        template.append_message(template.roles[1], None)
        outputs.append(template.get_prompt())
    return outputs


def _get_meta_template(meta_template):
    default_meta_template = dict(
        round=[
            dict(role='HUMAN', api_role='HUMAN'),
            dict(role='BOT', api_role='BOT', generate=True),
        ]
    )
    return APITemplateParser(meta_template or default_meta_template)


def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch
    if 'torch_dtype' not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',
            'None': None,
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype
    return model_kwargs
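
# A minimal sketch of the dtype handling above (inputs are illustrative): string
# dtype names from a model config map to real torch dtypes, a missing key falls
# back to torch.float16, and unrecognised values are left untouched.
#
#   _set_model_kwargs_torch_dtype({'torch_dtype': 'torch.bfloat16'})
#   -> {'torch_dtype': torch.bfloat16}
#
#   _set_model_kwargs_torch_dtype({})
#   -> {'torch_dtype': torch.float16}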


@MODELS.register_module()
class HuggingFacewithChatTemplate(BaseModel):

    def __init__(self,
                 path: str,
                 model_kwargs: dict = dict(),
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 peft_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 max_seq_len: Optional[int] = None,
                 meta_template: Optional[Dict] = None,
                 pad_token_id: Optional[int] = None,
                 fastchat_template: Optional[str] = None,
                 stop_words: List[str] = [],
                 **other_kwargs):
        self.logger = get_logger()
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = _get_meta_template(meta_template)
        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
        if not tokenizer_only:
            self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
        self.generation_kwargs = generation_kwargs
        self.fastchat_template = fastchat_template
        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))

        for k, v in other_kwargs.items():
            if v is not None:
                self.logger.warning(f'Unused argument {k}={v}')

    def _load_tokenizer(self, path: Optional[str], kwargs: dict, pad_token_id: Optional[int] = None):
        from transformers import AutoTokenizer, GenerationConfig

        DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True)
        tokenizer_kwargs = DEFAULT_TOKENIZER_KWARGS
        tokenizer_kwargs.update(kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_kwargs)

        # A patch for some models without pad_token_id
        if pad_token_id is not None:
            if self.tokenizer.pad_token_id is None:
                self.logger.debug(f'Using {pad_token_id} as pad_token_id')
            elif self.tokenizer.pad_token_id != pad_token_id:
                self.logger.warning(f'pad_token_id is not consistent. Using {pad_token_id} as pad_token_id')
            self.tokenizer.pad_token_id = pad_token_id
            return
        if self.tokenizer.pad_token_id is not None:
            return
        self.logger.warning('pad_token_id is not set for the tokenizer.')
        generation_config = GenerationConfig.from_pretrained(path)
        if generation_config.pad_token_id is not None:
            self.logger.warning(f'Using {generation_config.pad_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = generation_config.pad_token_id
            return
        if self.tokenizer.eos_token_id is not None:
            self.logger.warning(f'Using eos_token_id {self.tokenizer.eos_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            return
        raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f'using model_kwargs: {model_kwargs}')

        try:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)

        if peft_path is not None:
            from peft import PeftModel
            peft_kwargs['is_trainable'] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        self.model.eval()
        self.model.generation_config.do_sample = False

    def _get_potential_stop_words(self, path: Optional[str]):
        from transformers import GenerationConfig
        potential_stop_words = []
        try:
            generation_config = GenerationConfig.from_pretrained(path)
            for token_id in generation_config.eos_token_id:
                potential_stop_words.append(self.tokenizer.decode(token_id))
        except:
            pass
        potential_stop_words.append(self.tokenizer.eos_token)
        potential_stop_words = list(set(potential_stop_words))
        # drop None/empty entries (e.g. tokenizers without an eos_token), which
        # would otherwise break the stopping criteria
        potential_stop_words = [s for s in potential_stop_words if s]
        return potential_stop_words

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = [],
                 **kwargs) -> List[str]:
        messages = _convert_chat_messages(inputs)
        batch_size = len(messages)

        # step-1: tokenize the inputs, with either the fastchat template or the
        # tokenizer's own chat template
        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        if self.fastchat_template:
            messages = _format_with_fast_chat_template(messages, self.fastchat_template)
            tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        else:
            messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
            tokenize_kwargs['add_special_tokens'] = False
            tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

        generation_kwargs = self.generation_kwargs.copy()
        generation_kwargs.update(kwargs)
        stopping_criteria = list(set(stopping_criteria + self.stop_words))
        if stopping_criteria:
            generation_kwargs['stopping_criteria'] = _get_stopping_criteria(stopping_criteria, self.tokenizer, batch_size)
        if max_out_len is not None:
            generation_kwargs['max_new_tokens'] = max_out_len
        if min_out_len is not None:
            generation_kwargs['min_new_tokens'] = min_out_len
        generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id

        # step-2: conduct model forward to generate output
        outputs = self.model.generate(**tokens, **generation_kwargs)
        outputs = outputs[:, tokens['input_ids'].shape[1]:]

        # step-3: decode the output and cut at the first stop word
        decodeds = self.tokenizer.batch_decode(outputs)
        for stop in stopping_criteria:
            decodeds = [t.split(stop)[0] for t in decodeds]

        return decodeds

    def get_token_len(self, prompt: str) -> int:
        m = _convert_chat_messages([prompt])[0]
        t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
        return len(t['input_ids'])
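
# Illustrative usage sketch (not part of the library; the model path and the
# kwargs below are assumptions; any HF model whose tokenizer defines a chat
# template works the same way):
#
#   model = HuggingFacewithChatTemplate(
#       path='internlm/internlm2-chat-7b',
#       model_kwargs=dict(device_map='auto'),
#       max_seq_len=2048,
#   )
#   print(model.generate(['Say hello.'], max_out_len=32))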


def _convert_base_messages(inputs):
    outputs = []
    for _input in inputs:
        if isinstance(_input, str):
            outputs.append(_input)
        else:
            messages = []
            for item in _input:
                messages.append(item['prompt'])
            outputs.append(''.join(messages))
    return outputs


class HuggingFaceBaseModel(HuggingFacewithChatTemplate):

    def __init__(self,
                 path: str,
                 model_kwargs: dict = dict(),
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 peft_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 max_seq_len: Optional[int] = None,
                 pad_token_id: Optional[int] = None,
                 stop_words: List[str] = [],
                 **other_kwargs):
        self.logger = get_logger()
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = LMTemplateParser()
        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
        if not tokenizer_only:
            self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
        self.generation_kwargs = generation_kwargs
        self.stop_words = stop_words

        for k, v in other_kwargs.items():
            if v is not None:
                self.logger.warning(f'Unused argument {k}={v}')

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = [],
                 **kwargs) -> List[str]:
        messages = _convert_base_messages(inputs)
        batch_size = len(messages)

        # step-1: tokenize the inputs
        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

        generation_kwargs = self.generation_kwargs.copy()
        generation_kwargs.update(kwargs)
        stopping_criteria = list(set(stopping_criteria + self.stop_words))
        if stopping_criteria:
            generation_kwargs['stopping_criteria'] = _get_stopping_criteria(stopping_criteria, self.tokenizer, batch_size)
        if max_out_len is not None:
            generation_kwargs['max_new_tokens'] = max_out_len
        if min_out_len is not None:
            generation_kwargs['min_new_tokens'] = min_out_len
        generation_kwargs['pad_token_id'] = self.tokenizer.pad_token_id

        # step-2: conduct model forward to generate output
        outputs = self.model.generate(**tokens, **generation_kwargs)
        outputs = outputs[:, tokens['input_ids'].shape[1]:]

        # step-3: decode the output and cut at the first stop word
        decodeds = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for stop in stopping_criteria:
            decodeds = [t.split(stop)[0] for t in decodeds]

        return decodeds

    def get_ppl(self, inputs: List[str], mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """
        assert self.tokenizer.pad_token
        import torch
        import torch.nn.functional as F
        pad_token_id = self.tokenizer.pad_token_id
        messages = _convert_base_messages(inputs)

        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len
        )
        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
        outputs = self.model(**tokens)[0]

        batch_size, seq_len, vocab_size = outputs.shape
        shift_logits = outputs[:, :-1, :].contiguous().float()
        shift_labels = tokens['input_ids'][:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
            reduction='none').view(batch_size, seq_len - 1)
        lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()

        if mask_length is not None:
            import numpy as np
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask
            lens -= np.array(mask_length)

        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_loglikelihood(self, inputs: List[str], conts: List[str]) -> List[float]:
        mask_length = [self.get_token_len(c, add_special_tokens=False) for c in conts]
        return -self.get_ppl(inputs, mask_length)

    def get_token_len(self, prompt: str, add_special_tokens: bool = True) -> int:
        m = _convert_base_messages([prompt])[0]
        t = self.tokenizer(m, add_special_tokens=add_special_tokens)
        return len(t['input_ids'])