diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 0d384fed..403eb5d6 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -20,6 +20,7 @@ from .huggingface_above_v4_33 import HuggingFaceBaseModel  # noqa: F401
 from .huggingface_above_v4_33 import HuggingFacewithChatTemplate  # noqa: F401
 from .hunyuan_api import Hunyuan  # noqa: F401
 from .intern_model import InternLM  # noqa: F401
+from .interntrain import InternTrain  # noqa: F401
 from .krgpt_api import KrGPT  # noqa: F401
 from .lightllm_api import LightllmAPI, LightllmChatAPI  # noqa: F401
 from .llama2 import Llama2, Llama2Chat  # noqa: F401
diff --git a/opencompass/models/interntrain.py b/opencompass/models/interntrain.py
new file mode 100644
index 00000000..d6c233cd
--- /dev/null
+++ b/opencompass/models/interntrain.py
@@ -0,0 +1,419 @@
+import os
+import random
+import sys
+import time
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from opencompass.models.base import BaseModel
+from opencompass.registry import MODELS
+from opencompass.utils.logging import get_logger
+
+
+class InternTrainManager:
+
+    def __init__(self, module_path):
+        self.module_path = module_path
+
+    @staticmethod
+    def build(module_path):
+        sys.path.insert(0, module_path)
+        try:
+            from internlm.core.context.registry import \
+                register_model_initializer  # noqa: F401
+            return CurrentInternTrainManager(module_path)
+        except ImportError:
+            return LegacyInternTrainManager(module_path)
+
+
+class CurrentInternTrainManager(InternTrainManager):
+
+    def load_config(self, path, model_config=None):
+        from internlm.config import Config
+        if model_config is None:
+            model_config = torch.load(os.path.join(path, 'model_config.pt'))
+        elif isinstance(model_config, dict):
+            model_config = Config(model_config)
+        elif isinstance(model_config, str):
+            model_config = Config.fromfile(model_config).model
+        else:
+            raise NotImplementedError(
+                'model_config should be None, dict or filename.')
+
+        return model_config
+
+    def initialize_model(self):
+        from internlm.train.pipeline import (initialize_model,
+                                             initialize_parallel_communicator)
+        model = initialize_model().model
+        initialize_parallel_communicator(model)
+
+        return model
+
+
+class LegacyInternTrainManager(InternTrainManager):
+
+    def load_config(self, path, model_config=None):
+        from internlm.core.context import Config
+        if model_config is None:
+            model_config = torch.load(os.path.join(path, 'model_config.pt'))
+        elif isinstance(model_config, dict):
+            model_config = Config(model_config)
+        elif isinstance(model_config, str):
+            model_config = Config.from_file(model_config).model
+        else:
+            raise NotImplementedError(
+                'model_config should be None, dict or filename.')
+
+        return model_config
+
+    def initialize_model(self):
+        from internlm.train.pipeline import initialize_model
+        model = initialize_model().model
+
+        return model
+
+
+@MODELS.register_module()
+class InternTrain(BaseModel):
+
+    def __init__(self,
+                 path: str,
+                 module_path: str,
+                 max_seq_len: int = 2048,
+                 tokenizer_only: bool = False,
+                 tokenizer_path: Optional[str] = None,
+                 tokenizer_type: str = 'INTERNLM',
+                 model_config: Optional[str] = None,
+                 model_type: str = 'INTERNLM2',
+                 ckpt_type: Optional[str] = None,
+                 meta_template: Optional[Dict] = None,
+                 model_dtype: Optional[str] = None,
+                 generation_kwargs={},
+                 sync_rank: bool = False,
+                 mode='none'):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         tokenizer_only=tokenizer_only,
+                         meta_template=meta_template,
+                         sync_rank=sync_rank)
+        self.logger = get_logger()
+        # insert interntrain module
+        self.manager = InternTrainManager.build(module_path)
+
+        # TODO: mode is not a good name; change it here and in huggingface.py
+        # mode = 'mid' is used only in longtext eval, which cuts off tokens
+        # in the middle
+        # https://github.com/THUDM/LongBench
+        assert mode in ['none', 'mid']
+        self.mode = mode
+
+        self._load_tokenizer(tokenizer_path=tokenizer_path,
+                             tokenizer_type=tokenizer_type)
+
+        if not tokenizer_only:
+            self._load_model(path=path,
+                             model_config=model_config,
+                             model_type=model_type,
+                             model_dtype=model_dtype,
+                             ckpt_type=ckpt_type)
+
+        # default generation_kwargs
+        assert generation_kwargs.pop('num_return_sequences', 1) == 1  # TODO
+        self.generation_kwargs = {
+            'temperature': 1.0,
+            'top_p': 1.0,
+            'top_k': 50,
+            'do_sample': False,
+            'repetition_penalty': 1.0,
+        }
+        self.generation_kwargs.update(generation_kwargs)
+        self.logger.info(f'generation_kwargs: {self.generation_kwargs}')
+
+        # generator
+        from internlm.apis.inference import SequenceGenerator
+        eos_token_ids = self.generation_kwargs.get('eos_token_id', [])
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+        eos_token_ids.append(self.tokenizer.eos_id)
+        if self.eos_token_id is not None:
+            eos_token_ids.append(self.eos_token_id)
+        eos_token_ids = list(set(eos_token_ids))
+        self.generator = SequenceGenerator(self.model,
+                                           bos_token_id=self.tokenizer.bos_id,
+                                           pad_token_id=self.tokenizer.bos_id,
+                                           eos_token_id=eos_token_ids)
+
+    def _load_model(self,
+                    path: str,
+                    model_config: Optional[str] = None,
+                    model_type: str = 'INTERNLM2',
+                    model_dtype: Optional[str] = None,
+                    ckpt_type: Optional[str] = None):
+        # funcs
+        from internlm.checkpoint.load_funcs import (LOAD_FUNC_DICT,
+                                                    merge_pp_within_tp)
+        from internlm.core.context import global_context as gpc
+        from internlm.initialize.launch import launch
+        from internlm.utils.storage_manager import (get_storage_manager,
+                                                    init_storage_manager)
+
+        # config
+        model_config = self.manager.load_config(path, model_config)
+        model_config['parallel_output'] = False
+        model_config['dtype'] = self._convert_dtype(model_config['dtype'],
+                                                    model_dtype=model_dtype)
+
+        world_size = int(os.getenv('WORLD_SIZE', '1'))
+        tp_size = world_size  # TODO
+        self.logger.info(f'world size: {world_size} tp: {tp_size}')
+        parallel_config = dict(zero1=dict(size=1, fsdp=False),
+                               pipeline=dict(size=1),
+                               tensor=dict(size=tp_size, mode='mtp'),
+                               sequence_parallel=False)
+        config = dict(model=model_config,
+                      parallel=parallel_config,
+                      data=dict(use_packed_dataset=False),
+                      model_type=model_type,
+                      use_cuda_flash_attn=model_config.get(
+                          'use_flash_attn', True))
+        launch(
+            config=config,
+            seed=42,
+            local_rank=int(os.getenv('LOCAL_RANK', '0')),
+            rank=int(os.getenv('RANK', '0')),
+            world_size=int(os.getenv('WORLD_SIZE', '1')),
+            host=os.getenv('MASTER_ADDR', '127.0.0.1'),
+            port=int(os.getenv('MASTER_PORT', random.randint(12000, 32000))),
+        )
+        self.logger.info(f'Config: {gpc.config}')
+
+        self.model = self.manager.initialize_model()
+
+        # load state dict
+        try:
+            get_storage_manager()
+        except AssertionError:
+            init_storage_manager(False, None, None)
+            get_storage_manager()
+        if ckpt_type is None or ckpt_type == 'internevo':
+            state_dict = merge_pp_within_tp(path, del_model_prefix=True)
+            load_info = self.model.load_state_dict(state_dict, strict=False)
+            self.logger.info(load_info)
+        else:
+            load_func = LOAD_FUNC_DICT[ckpt_type]
+            load_func(path, self.model)
+
+        self.model.to(model_config['dtype']).eval().cuda()
+
+    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str):
+        from internlm.core.context.registry import TOKENIZER_INITIALIZER
+        tokenizer_cls = TOKENIZER_INITIALIZER.get_module(tokenizer_type)
+        self.tokenizer = tokenizer_cls(
+            model_path=tokenizer_path,
+            use_bos=True,
+            use_eos=False,
+        )
+
+        # TODO use bos as pad temporarily
+        if self.tokenizer.pad_id == -1:
+            self.pad_id = self.tokenizer.bos_id
+        else:
+            self.pad_id = self.tokenizer.pad_id
+
+    def _convert_dtype(self, default_dtype, model_dtype=None):
+        if model_dtype is None:
+            return default_dtype
+        elif isinstance(model_dtype, torch.dtype):
+            return model_dtype
+        elif model_dtype == 'torch.bfloat16':
+            return torch.bfloat16
+        elif model_dtype in ('torch.float16', 'torch.half'):
+            return torch.float16
+        elif model_dtype in ('torch.float32', 'torch.float'):
+            return torch.float32
+        elif model_dtype in ('torch.tf32', ):
+            torch.backends.cudnn.allow_tf32 = True
+            torch.backends.cuda.matmul.allow_tf32 = True
+            return torch.float32
+        else:
+            raise NotImplementedError(f'Unknown model dtype {model_dtype}')
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get the length of the tokenized string.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens.
+        """
+        tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
+        return len(tokens)
+
+    def generate(self,
+                 inputs: List[str],
+                 max_out_len: int,
+                 min_out_len: Optional[int] = None,
+                 stopping_criteria: List[str] = []) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        if min_out_len is None:
+            # keep the same as InternTrain's default value
+            min_out_len = 1
+
+        tokens = self.batch_encode(inputs,
+                                   self.max_seq_len - max_out_len,
+                                   left_padding=True)
+
+        # random seed for pass@k
+        seed = torch.tensor(time.time(), dtype=torch.int64).cuda()
+
+        dist.broadcast(seed, src=0)
+        torch.cuda.manual_seed(seed.item())
+        dist.barrier()
+        outputs = self.generator.generate(
+            tokens,
+            max_length=tokens.shape[1] + max_out_len,
+            **self.generation_kwargs)  # bsz, num_return_sequences, max_length
+        outputs = outputs[:, 0, tokens.shape[1]:]
+        output_text = self.batch_decode(outputs,
+                                        stopping_criteria=stopping_criteria)
+
+        return output_text
+
+    def get_ppl(self,
+                input_texts: List[str],
+                mask_length: Optional[List[int]] = None) -> List[float]:
+        """Get perplexity scores given a list of inputs.
+
+        Args:
+            input_texts (List[str]): A list of strings.
+            mask_length (Optional[List[int]]): A list of mask lengths. If
+                provided, the perplexity scores will be calculated with the
+                first mask_length[i] tokens masked out.
+
+        Returns:
+            List[float]: A list of perplexity scores.
+        """
+        outputs, inputs = self.get_logits(input_texts)
+
+        shift_logits = outputs[..., :-1, :].contiguous()
+        shift_labels = inputs[..., 1:].contiguous()
+        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
+                                             ignore_index=self.pad_id)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                        shift_labels.view(-1)).view(shift_labels.size())
+
+        if mask_length is not None:
+            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
+            for i in range(len(mask)):
+                for j in range(mask_length[i] - 1, len(mask[i])):
+                    mask[i][j] = 1
+            loss = loss * mask
+
+        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
+        if mask_length is not None:
+            lens -= np.array(mask_length)
+        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
+        return ce_loss
+
+    def get_loglikelihood(self, input_texts: List[str],
+                          conts: List[str]) -> List[float]:
+        outputs, inputs = self.get_logits(input_texts)
+        shift_logits = outputs[..., :-1, :].contiguous()
+        shift_labels = inputs[..., 1:].contiguous()
+        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
+                                             ignore_index=self.pad_id)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                        shift_labels.view(-1)).view(shift_labels.size())
+        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
+        replaced_texts = [
+            input_text.replace(cont, '')
+            for input_text, cont in zip(input_texts, conts)
+        ]
+        replaced_lens = [
+            len(self.encode(input_text)[0]) for input_text in replaced_texts
+        ]
+        loglikelihoods = []
+        for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
+            nlen, rlen = int(nlen), int(rlen)
+            nloss = nloss[:nlen]
+            nloss = nloss[rlen:].float().sum().cpu().detach().numpy()
+            loglikelihoods.append(-nloss)
+        return np.array(loglikelihoods)
+
+    def get_mink_percent(self,
+                         input_texts: List[str],
+                         k: int = 20) -> List[float]:
+        """https://swj0419.github.io/detect-pretrain.github.io/"""
+        outputs, inputs = self.get_logits(input_texts)
+        shift_logits = outputs[..., :-1, :].contiguous()
+        shift_labels = inputs[..., 1:].contiguous()
+        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
+                                             ignore_index=self.pad_id)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                        shift_labels.view(-1)).view(shift_labels.size())
+        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
+        mink_percent = []
+        for nloss, nlen in zip(loss, lens):
+            nlen = int(nlen)
+            minklen = max(nlen * k // 100, 1)
+            nloss = torch.topk(nloss[:nlen], minklen, dim=-1)[0]
+            nloss = -nloss.float().mean().cpu().detach().numpy()
+            mink_percent.append(nloss)
+        return np.array(mink_percent)
+
+    def get_logits(self, input_texts: Union[str, List[str]]):
+        tokens = self.batch_encode(input_texts, max_seq_len=self.max_seq_len)
+        outputs = self.model(input_ids=tokens)
+        if isinstance(outputs, tuple):
+            # moe returns (hidden_states, moe_losses)
+            outputs = outputs[0]
+        return outputs, tokens
+
+    def batch_encode(self,
+                     input_texts: Union[str, List[str]],
+                     max_seq_len: int,
+                     left_padding=False):
+        if isinstance(input_texts, str):
+            input_texts = [input_texts]
+        tokens = [self.tokenizer(text) for text in input_texts]
+        max_len = min(max_seq_len, max([len(t) for t in tokens]))
+        for i in range(len(tokens)):
+            cur_input = tokens[i]
+            padding_len = max_len - len(cur_input)
+            if self.mode == 'none':
+                cur_input = cur_input[:max_len]
+            elif self.mode == 'mid' and len(cur_input) > max_len:
+                mid_cut_len = max_len // 2
+                cur_input = cur_input[:mid_cut_len] + cur_input[-mid_cut_len:]
+
+            if left_padding:
+                # left padding with bos
+                tokens[i] = [self.tokenizer.bos_id] * padding_len + cur_input
+            else:
+                tokens[i] = cur_input + [self.pad_id] * padding_len
+
+        return torch.LongTensor(tokens).cuda()
+
+    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
+        # outputs: bsz, seq_len
+        output_text = []
+        for output in outputs:
+            text = self.tokenizer.decode(output.tolist())
+            for stop_word in stopping_criteria:
+                text = text.split(stop_word)[0]
+            output_text.append(text)
+
+        return output_text
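Usage note (not part of the patch above): a minimal OpenCompass model-config sketch showing how the new InternTrain wrapper might be referenced, assuming the usual OpenCompass config conventions. The keyword arguments mirror the constructor added in opencompass/models/interntrain.py; the abbr, all paths, batch_size and run_cfg values are placeholders for illustration, not values taken from this diff.

from opencompass.models import InternTrain

models = [
    dict(
        type=InternTrain,
        abbr='interntrain-example',                 # placeholder name
        path='/path/to/internevo/checkpoint',       # dir with model_config.pt and weights
        module_path='/path/to/InternEvo',           # repo root, inserted into sys.path
        tokenizer_path='/path/to/tokenizer.model',
        tokenizer_type='INTERNLM',
        model_type='INTERNLM2',
        ckpt_type='internevo',                      # or a key registered in LOAD_FUNC_DICT
        model_dtype='torch.bfloat16',
        max_seq_len=2048,
        max_out_len=100,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]

With ckpt_type left as None or set to 'internevo', the weights are merged via merge_pp_within_tp; any other value dispatches to the matching loader in LOAD_FUNC_DICT.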