import os
import random
import sys
import time
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger


class InternTrainManager:

    def __init__(self, module_path):
        self.module_path = module_path

    @staticmethod
    def build(module_path):
        sys.path.insert(0, module_path)
        try:
            from internlm.core.context.registry import \
                register_model_initializer  # noqa: F401
            return CurrentInternTrainManager(module_path)
        except ImportError:
            return LegacyInternTrainManager(module_path)


class CurrentInternTrainManager(InternTrainManager):

    def load_config(self, path, model_config=None):
        from internlm.config import Config
        if model_config is None:
            model_config = torch.load(os.path.join(path, 'model_config.pt'))
        elif isinstance(model_config, dict):
            model_config = Config(model_config)
        elif isinstance(model_config, str):
            model_config = Config.fromfile(model_config).model
        else:
            raise NotImplementedError(
                'model_config should be None, dict or filename.')
        return model_config

    def initialize_model(self):
        from internlm.train.pipeline import (initialize_model,
                                             initialize_parallel_communicator)
        model = initialize_model().model
        initialize_parallel_communicator(model)
        return model


class LegacyInternTrainManager(InternTrainManager):

    def load_config(self, path, model_config=None):
        from internlm.core.context import Config
        if model_config is None:
            model_config = torch.load(os.path.join(path, 'model_config.pt'))
        elif isinstance(model_config, dict):
            model_config = Config(model_config)
        elif isinstance(model_config, str):
            model_config = Config.from_file(model_config).model
        else:
            raise NotImplementedError(
                'model_config should be None, dict or filename.')
        return model_config

    def initialize_model(self):
        from internlm.train.pipeline import initialize_model
        model = initialize_model().model
        return model


@MODELS.register_module()
class InternTrain(BaseModel):

    def __init__(self,
                 path: str,
                 module_path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_type: str = 'INTERNLM',
                 model_config: Optional[str] = None,
                 model_type: str = 'INTERNLM2',
                 ckpt_type: Optional[str] = None,
                 meta_template: Optional[Dict] = None,
                 model_dtype: Optional[str] = None,
                 generation_kwargs={},
                 sync_rank: bool = False,
                 mode='none'):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         tokenizer_only=tokenizer_only,
                         meta_template=meta_template,
                         sync_rank=sync_rank)
        self.logger = get_logger()

        # insert interntrain module
        self.manager = InternTrainManager.build(module_path)

        # TODO: mode is not a good name, change it both here and huggingface.py
        # mode = 'mid' is used only in longtext eval, which cuts off tokens in
        # the middle
        # https://github.com/THUDM/LongBench
        assert mode in ['none', 'mid']
        self.mode = mode

        self._load_tokenizer(tokenizer_path=tokenizer_path,
                             tokenizer_type=tokenizer_type)
        if not tokenizer_only:
            self._load_model(path=path,
                             model_config=model_config,
                             model_type=model_type,
                             model_dtype=model_dtype,
                             ckpt_type=ckpt_type)

        # default generation_kwargs
        assert generation_kwargs.pop('num_return_sequences', 1) == 1  # TODO
        self.generation_kwargs = {
            'temperature': 1.0,
            'top_p': 1.0,
            'top_k': 50,
            'do_sample': False,
            'repetition_penalty': 1.0,
        }
        self.generation_kwargs.update(generation_kwargs)
        self.logger.info(f'generation_kwargs: {self.generation_kwargs}')

        # generator
        from internlm.apis.inference import SequenceGenerator
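        # Collect EOS ids from three sources: generation_kwargs, the tokenizer,
        # and self.eos_token_id (set by BaseModel, typically from the meta
        # template), then de-duplicate so generation stops on any of them.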
        eos_token_ids = self.generation_kwargs.get('eos_token_id', [])
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]
        eos_token_ids.append(self.tokenizer.eos_id)
        if self.eos_token_id is not None:
            eos_token_ids.append(self.eos_token_id)
        eos_token_ids = list(set(eos_token_ids))
        self.generator = SequenceGenerator(self.model,
                                           bos_token_id=self.tokenizer.bos_id,
                                           pad_token_id=self.tokenizer.bos_id,
                                           eos_token_id=eos_token_ids)

    def _load_model(self,
                    path: str,
                    model_config: Optional[str] = None,
                    model_type: str = 'INTERNLM2',
                    model_dtype: Optional[str] = None,
                    ckpt_type: Optional[str] = None):
        # funcs
        from internlm.checkpoint.load_funcs import (LOAD_FUNC_DICT,
                                                    merge_pp_within_tp)
        from internlm.core.context import global_context as gpc
        from internlm.initialize.launch import launch
        from internlm.utils.storage_manager import (get_storage_manager,
                                                    init_storage_manager)

        # config
        model_config = self.manager.load_config(path, model_config)
        model_config['parallel_output'] = False
        model_config['dtype'] = self._convert_dtype(model_config['dtype'],
                                                    model_dtype=model_dtype)

        world_size = int(os.getenv('WORLD_SIZE', '1'))
        tp_size = world_size  # TODO
        self.logger.info(f'world size: {world_size} tp: {tp_size}')
        parallel_config = dict(zero1=dict(size=1, fsdp=False),
                               pipeline=dict(size=1),
                               tensor=dict(size=tp_size, mode='mtp'),
                               sequence_parallel=False)
        config = dict(model=model_config,
                      parallel=parallel_config,
                      data=dict(use_packed_dataset=False),
                      model_type=model_type,
                      use_cuda_flash_attn=model_config.get(
                          'use_flash_attn', True))
        launch(
            config=config,
            seed=42,
            local_rank=int(os.getenv('LOCAL_RANK', '0')),
            rank=int(os.getenv('RANK', '0')),
            world_size=int(os.getenv('WORLD_SIZE', '1')),
            host=os.getenv('MASTER_ADDR', '127.0.0.1'),
            port=int(os.getenv('MASTER_PORT', random.randint(12000, 32000))),
        )
        self.logger.info(f'Config: {gpc.config}')

        self.model = self.manager.initialize_model()

        # load state dict
        try:
            get_storage_manager()
        except AssertionError:
            init_storage_manager(False, None, None)
            get_storage_manager()
        if ckpt_type is None or ckpt_type == 'internevo':
            state_dict = merge_pp_within_tp(path, del_model_prefix=True)
            load_info = self.model.load_state_dict(state_dict, strict=False)
            self.logger.info(load_info)
        else:
            load_func = LOAD_FUNC_DICT[ckpt_type]
            load_func(path, self.model)

        self.model.to(model_config['dtype']).eval().cuda()

    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str):
        from internlm.core.context.registry import TOKENIZER_INITIALIZER
        tokenizer_cls = TOKENIZER_INITIALIZER.get_module(tokenizer_type)
        self.tokenizer = tokenizer_cls(
            model_path=tokenizer_path,
            use_bos=True,
            use_eos=False,
        )

        # TODO use bos as pad temporarily
        if self.tokenizer.pad_id == -1:
            self.pad_id = self.tokenizer.bos_id
        else:
            self.pad_id = self.tokenizer.pad_id

    def _convert_dtype(self, default_dtype, model_dtype=None):
        if model_dtype is None:
            return default_dtype
        elif isinstance(model_dtype, torch.dtype):
            return model_dtype
        elif model_dtype == 'torch.bfloat16':
            return torch.bfloat16
        elif model_dtype in ('torch.float16', 'torch.half'):
            return torch.float16
        elif model_dtype in ('torch.float32', 'torch.float'):
            return torch.float32
        elif model_dtype == 'torch.tf32':
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cuda.matmul.allow_tf32 = True
            return torch.float32
        else:
            raise NotImplementedError(f'Unknown model dtype {model_dtype}')
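    # Note: get_token_len below tokenizes with use_bos=True and use_eos=True, so
    # the reported length counts both special tokens, whereas batch_encode calls
    # the tokenizer with no explicit flags (so the construction-time defaults,
    # use_bos=True / use_eos=False, apply).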
    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
        return len(tokens)

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = []) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if min_out_len is None:
            # keep same with InternTrain's default value
            min_out_len = 1
        tokens = self.batch_encode(inputs,
                                   self.max_seq_len - max_out_len,
                                   left_padding=True)
        # random seed for pass@k
        seed = torch.tensor(time.time(), dtype=torch.int64).cuda()
        dist.broadcast(seed, src=0)
        torch.cuda.manual_seed(seed.item())
        dist.barrier()
        outputs = self.generator.generate(tokens,
                                          max_length=tokens.shape[1] +
                                          max_out_len,
                                          **self.generation_kwargs)
        # bsz, num_return_sequences, max_length
        outputs = outputs[:, 0, tokens.shape[1]:]
        output_text = self.batch_decode(outputs,
                                        stopping_criteria=stopping_criteria)

        return output_text

    def get_ppl(self,
                input_texts: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            input_texts (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out.

        Returns:
            List[float]: A list of perplexity scores.
        """
        outputs, inputs = self.get_logits(input_texts)

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_loglikelihood(self, input_texts: List[str],
                          conts: List[str]) -> List[float]:
        outputs, inputs = self.get_logits(input_texts)

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())
        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()

        replaced_texts = [
            input_text.replace(cont, '')
            for input_text, cont in zip(input_texts, conts)
        ]
        replaced_lens = [
            len(self.encode(input_text)[0]) for input_text in replaced_texts
        ]
        loglikelihoods = []
        for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
            nlen, rlen = int(nlen), int(rlen)
            nloss = nloss[:nlen]
            nloss = nloss[rlen:].float().sum().cpu().detach().numpy()
            loglikelihoods.append(-nloss)
        return np.array(loglikelihoods)
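    # get_mink_percent below implements the Min-K% probe (see the URL in its
    # docstring): it keeps the k% largest per-token losses (the least likely
    # tokens) and reports their negated mean, reusing the same shifted
    # cross-entropy computation as get_ppl and get_loglikelihood.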
    def get_mink_percent(self, input_texts: List[str],
                         k: int = 20) -> List[float]:
        """https://swj0419.github.io/detect-pretrain.github.io/"""
        outputs, inputs = self.get_logits(input_texts)

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())
        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
        mink_percent = []
        for nloss, nlen in zip(loss, lens):
            nlen = int(nlen)
            minklen = max(nlen * k // 100, 1)
            # keep only the valid (non-padded) positions of this sample before
            # selecting the top-k% losses
            nloss = torch.topk(nloss[:nlen], minklen, dim=-1)[0]
            nloss = -nloss.float().mean().cpu().detach().numpy()
            mink_percent.append(nloss)
        return np.array(mink_percent)

    def get_logits(self, input_texts: Union[str, List[str]]):
        tokens = self.batch_encode(input_texts, max_seq_len=self.max_seq_len)
        outputs = self.model(input_ids=tokens)
        if isinstance(outputs, tuple):
            # moe returns (hidden_states, moe_losses)
            outputs = outputs[0]
        return outputs, tokens

    def batch_encode(self,
                     input_texts: Union[str, List[str]],
                     max_seq_len: int,
                     left_padding=False):
        if isinstance(input_texts, str):
            input_texts = [input_texts]
        tokens = [self.tokenizer(text) for text in input_texts]
        max_len = min(max_seq_len, max([len(t) for t in tokens]))
        for i in range(len(tokens)):
            cur_input = tokens[i]
            padding_len = max_len - len(cur_input)
            if self.mode == 'none':
                cur_input = cur_input[:max_len]
            elif self.mode == 'mid' and len(cur_input) > max_len:
                mid_cut_len = max_len // 2
                cur_input = cur_input[:mid_cut_len] + cur_input[-mid_cut_len:]

            if left_padding:
                # left padding with bos
                tokens[i] = [self.tokenizer.bos_id] * padding_len + cur_input
            else:
                tokens[i] = cur_input + [self.pad_id] * padding_len

        return torch.LongTensor(tokens).cuda()

    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
        # outputs: bsz, seq_len
        output_text = []
        for output in outputs:
            text = self.tokenizer.decode(output.tolist())
            for stop_word in stopping_criteria:
                text = text.split(stop_word)[0]
            output_text.append(text)

        return output_text
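# A minimal usage sketch (illustration only, not part of this module). All paths
# below are hypothetical: they must point to a local InternEvo/InternTrain
# checkout, a checkpoint directory and a tokenizer model of your own.
#
#   model = InternTrain(path='/path/to/ckpt_dir',
#                       module_path='/path/to/InternEvo',
#                       tokenizer_path='/path/to/tokenizer.model',
#                       tokenizer_type='INTERNLM',
#                       model_config='/path/to/model_config.py',
#                       model_dtype='torch.bfloat16',
#                       max_seq_len=2048)
#   print(model.generate(['The capital of France is'], max_out_len=32))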