[Feature] Add Interntrain model support (#1548)
Co-authored-by: x54-729 <xingshuhao.dispatch@pjlab.org.cn>
commit 335667183a
parent 24915aeb3f
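The diff below adds a new `InternTrain` model wrapper built on the InternTrain/InternEvo `internlm` package and registers it in `opencompass.models`. As a rough sketch of how the new class might be wired into an OpenCompass model config (all paths are placeholders, and `abbr`, `max_out_len`, `batch_size`, `run_cfg` are the usual OpenCompass config fields rather than values taken from this commit):

from opencompass.models import InternTrain

models = [
    dict(
        type=InternTrain,
        abbr='interntrain-demo',                    # hypothetical run name
        path='/path/to/checkpoint/dir',             # placeholder checkpoint folder
        module_path='/path/to/InternEvo',           # placeholder clone of the training framework
        tokenizer_path='/path/to/tokenizer.model',  # placeholder tokenizer file
        tokenizer_type='INTERNLM',
        model_type='INTERNLM2',
        max_seq_len=2048,
        max_out_len=100,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]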
opencompass/models/__init__.py
@@ -20,6 +20,7 @@ from .huggingface_above_v4_33 import HuggingFaceBaseModel  # noqa: F401
 from .huggingface_above_v4_33 import HuggingFacewithChatTemplate  # noqa: F401
 from .hunyuan_api import Hunyuan  # noqa: F401
 from .intern_model import InternLM  # noqa: F401
+from .interntrain import InternTrain  # noqa: F401
 from .krgpt_api import KrGPT  # noqa: F401
 from .lightllm_api import LightllmAPI, LightllmChatAPI  # noqa: F401
 from .llama2 import Llama2, Llama2Chat  # noqa: F401
419  opencompass/models/interntrain.py  Normal file
@@ -0,0 +1,419 @@
import os
import random
import sys
import time
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger


class InternTrainManager:

    def __init__(self, module_path):
        self.module_path = module_path

    @staticmethod
    def build(module_path):
        sys.path.insert(0, module_path)
        try:
            from internlm.core.context.registry import \
                register_model_initializer  # noqa: F401
            return CurrentInternTrainManager(module_path)
        except ImportError:
            return LegacyInternTrainManager(module_path)


class CurrentInternTrainManager(InternTrainManager):

    def load_config(self, path, model_config=None):
        from internlm.config import Config
        if model_config is None:
            model_config = torch.load(os.path.join(path, 'model_config.pt'))
        elif isinstance(model_config, dict):
            model_config = Config(model_config)
        elif isinstance(model_config, str):
            model_config = Config.fromfile(model_config).model
        else:
            raise NotImplementedError(
                'model_config should be None, dict or filename.')

        return model_config

    def initialize_model(self):
        from internlm.train.pipeline import (initialize_model,
                                             initialize_parallel_communicator)
        model = initialize_model().model
        initialize_parallel_communicator(model)

        return model


class LegacyInternTrainManager(InternTrainManager):

    def load_config(self, path, model_config=None):
        from internlm.core.context import Config
        if model_config is None:
            model_config = torch.load(os.path.join(path, 'model_config.pt'))
        elif isinstance(model_config, dict):
            model_config = Config(model_config)
        elif isinstance(model_config, str):
            model_config = Config.from_file(model_config).model
        else:
            raise NotImplementedError(
                'model_config should be None, dict or filename.')

        return model_config

    def initialize_model(self):
        from internlm.train.pipeline import initialize_model
        model = initialize_model().model

        return model

@MODELS.register_module()
class InternTrain(BaseModel):

    def __init__(self,
                 path: str,
                 module_path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_type: str = 'INTERNLM',
                 model_config: Optional[str] = None,
                 model_type: str = 'INTERNLM2',
                 ckpt_type: Optional[str] = None,
                 meta_template: Optional[Dict] = None,
                 model_dtype: Optional[str] = None,
                 generation_kwargs={},
                 sync_rank: bool = False,
                 mode='none'):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         tokenizer_only=tokenizer_only,
                         meta_template=meta_template,
                         sync_rank=sync_rank)
        self.logger = get_logger()
        # insert the interntrain module onto sys.path
        self.manager = InternTrainManager.build(module_path)

        # TODO: mode is not a good name, change it both here and huggingface.py
        # mode = 'mid' is used only in long-text eval, where it cuts off tokens
        # in the middle of the prompt
        # https://github.com/THUDM/LongBench
        assert mode in ['none', 'mid']
        self.mode = mode

        self._load_tokenizer(tokenizer_path=tokenizer_path,
                             tokenizer_type=tokenizer_type)

        if not tokenizer_only:
            self._load_model(path=path,
                             model_config=model_config,
                             model_type=model_type,
                             model_dtype=model_dtype,
                             ckpt_type=ckpt_type)

        # default generation_kwargs
        assert generation_kwargs.pop('num_return_sequences', 1) == 1  # TODO
        self.generation_kwargs = {
            'temperature': 1.0,
            'top_p': 1.0,
            'top_k': 50,
            'do_sample': False,
            'repetition_penalty': 1.0,
        }
        self.generation_kwargs.update(generation_kwargs)
        self.logger.info(f'generation_kwargs: {self.generation_kwargs}')

        # generator
        from internlm.apis.inference import SequenceGenerator
        eos_token_ids = self.generation_kwargs.get('eos_token_id', [])
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]
        eos_token_ids.append(self.tokenizer.eos_id)
        if self.eos_token_id is not None:
            eos_token_ids.append(self.eos_token_id)
        eos_token_ids = list(set(eos_token_ids))
        self.generator = SequenceGenerator(self.model,
                                           bos_token_id=self.tokenizer.bos_id,
                                           pad_token_id=self.tokenizer.bos_id,
                                           eos_token_id=eos_token_ids)

    def _load_model(self,
                    path: str,
                    model_config: Optional[str] = None,
                    model_type: str = 'INTERNLM2',
                    model_dtype: Optional[str] = None,
                    ckpt_type: Optional[str] = None):
        # funcs
        from internlm.checkpoint.load_funcs import (LOAD_FUNC_DICT,
                                                    merge_pp_within_tp)
        from internlm.core.context import global_context as gpc
        from internlm.initialize.launch import launch
        from internlm.utils.storage_manager import (get_storage_manager,
                                                    init_storage_manager)

        # config
        model_config = self.manager.load_config(path, model_config)
        model_config['parallel_output'] = False
        model_config['dtype'] = self._convert_dtype(model_config['dtype'],
                                                    model_dtype=model_dtype)

        world_size = int(os.getenv('WORLD_SIZE', '1'))
        tp_size = world_size  # TODO
        self.logger.info(f'world size: {world_size} tp: {tp_size}')
        parallel_config = dict(zero1=dict(size=1, fsdp=False),
                               pipeline=dict(size=1),
                               tensor=dict(size=tp_size, mode='mtp'),
                               sequence_parallel=False)
        config = dict(model=model_config,
                      parallel=parallel_config,
                      data=dict(use_packed_dataset=False),
                      model_type=model_type,
                      use_cuda_flash_attn=model_config.get(
                          'use_flash_attn', True))
        launch(
            config=config,
            seed=42,
            local_rank=int(os.getenv('RANK', '0')),
            rank=int(os.getenv('LOCAL_RANK', '0')),
            world_size=int(os.getenv('WORLD_SIZE', '1')),
            host=os.getenv('MASTER_ADDR', '127.0.0.1'),
            port=int(os.getenv('MASTER_PORT', random.randint(12000, 32000))),
        )
        self.logger.info(f'Config: {gpc.config}')

        self.model = self.manager.initialize_model()

        # load state dict
        try:
            get_storage_manager()
        except AssertionError:
            init_storage_manager(False, None, None)
            get_storage_manager()
        if ckpt_type is None or ckpt_type == 'internevo':
            state_dict = merge_pp_within_tp(path, del_model_prefix=True)
            load_info = self.model.load_state_dict(state_dict, strict=False)
            self.logger.info(load_info)
        else:
            load_func = LOAD_FUNC_DICT[ckpt_type]
            load_func(path, self.model)

        self.model.to(model_config['dtype']).eval().cuda()

    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str):
        from internlm.core.context.registry import TOKENIZER_INITIALIZER
        tokenizer_cls = TOKENIZER_INITIALIZER.get_module(tokenizer_type)
        self.tokenizer = tokenizer_cls(
            model_path=tokenizer_path,
            use_bos=True,
            use_eos=False,
        )

        # TODO: use bos as pad temporarily
        if self.tokenizer.pad_id == -1:
            self.pad_id = self.tokenizer.bos_id
        else:
            self.pad_id = self.tokenizer.pad_id

    def _convert_dtype(self, default_dtype, model_dtype=None):
        if model_dtype is None:
            return default_dtype
        elif isinstance(model_dtype, torch.dtype):
            return model_dtype
        elif model_dtype == 'torch.bfloat16':
            return torch.bfloat16
        elif model_dtype in ('torch.float16', 'torch.half'):
            return torch.float16
        elif model_dtype in ('torch.float32', 'torch.float'):
            return torch.float32
        elif model_dtype == 'torch.tf32':
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cuda.matmul.allow_tf32 = True
            return torch.float32
        else:
            raise NotImplementedError(f'Unknown model dtype {model_dtype}')

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
        tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
        return len(tokens)

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = []) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.
            min_out_len (Optional[int]): The minimum length of the output;
                defaults to 1 (InternTrain's default) when unset.
            stopping_criteria (List[str]): Strings at which the decoded
                outputs are truncated.

        Returns:
            List[str]: A list of generated strings.
        """
        if min_out_len is None:
            # keep the same as InternTrain's default value
            min_out_len = 1

        tokens = self.batch_encode(inputs,
                                   self.max_seq_len - max_out_len,
                                   left_padding=True)

        # random seed for pass@k
        seed = torch.tensor(time.time(), dtype=torch.int64).cuda()

        dist.broadcast(seed, src=0)
        torch.cuda.manual_seed(seed.item())
        dist.barrier()
        outputs = self.generator.generate(
            tokens,
            max_length=tokens.shape[1] + max_out_len,
            **self.generation_kwargs)  # bsz, num_return_sequences, max_length
        outputs = outputs[:, 0, tokens.shape[1]:]
        output_text = self.batch_decode(outputs,
                                        stopping_criteria=stopping_criteria)

        return output_text

    def get_ppl(self,
                input_texts: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            input_texts (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out.

        Returns:
            List[float]: A list of perplexity scores.
        """
        outputs, inputs = self.get_logits(input_texts)

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())

        if mask_length is not None:
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= np.array(mask_length)
        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
        return ce_loss

    def get_loglikelihood(self, input_texts: List[str],
                          conts: List[str]) -> List[float]:
        outputs, inputs = self.get_logits(input_texts)
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())
        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
        replaced_texts = [
            input_text.replace(cont, '')
            for input_text, cont in zip(input_texts, conts)
        ]
        replaced_lens = [
            len(self.encode(input_text)[0]) for input_text in replaced_texts
        ]
        loglikelihoods = []
        for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
            nlen, rlen = int(nlen), int(rlen)
            nloss = nloss[:nlen]
            nloss = nloss[rlen:].float().sum().cpu().detach().numpy()
            loglikelihoods.append(-nloss)
        return np.array(loglikelihoods)

    def get_mink_percent(self,
                         input_texts: List[str],
                         k: int = 20) -> List[float]:
        """Min-K% prob: https://swj0419.github.io/detect-pretrain.github.io/"""
        outputs, inputs = self.get_logits(input_texts)
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none',
                                             ignore_index=self.pad_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)).view(shift_labels.size())
        lens = (inputs != self.pad_id).sum(-1).cpu().numpy()
        mink_percent = []
        for nloss, nlen in zip(loss, lens):
            nlen = int(nlen)
            minklen = max(nlen * k // 100, 1)
            # average the k% largest token losses of this sample
            nloss = torch.topk(nloss[:nlen], minklen, dim=-1)[0]
            nloss = -nloss.float().mean().cpu().detach().numpy()
            mink_percent.append(nloss)
        return np.array(mink_percent)

    def get_logits(self, input_texts: Union[str, List[str]]):
        tokens = self.batch_encode(input_texts, max_seq_len=self.max_seq_len)
        outputs = self.model(input_ids=tokens)
        if isinstance(outputs, tuple):
            # moe returns (hidden_states, moe_losses)
            outputs = outputs[0]
        return outputs, tokens

    def batch_encode(self,
                     input_texts: Union[str, List[str]],
                     max_seq_len: int,
                     left_padding=False):
        if isinstance(input_texts, str):
            input_texts = [input_texts]
        tokens = [self.tokenizer(text) for text in input_texts]
        max_len = min(max_seq_len, max([len(t) for t in tokens]))
        for i in range(len(tokens)):
            cur_input = tokens[i]
            padding_len = max_len - len(cur_input)
            if self.mode == 'none':
                cur_input = cur_input[:max_len]
            elif self.mode == 'mid' and len(cur_input) > max_len:
                mid_cut_len = max_len // 2
                cur_input = cur_input[:mid_cut_len] + cur_input[-mid_cut_len:]

            if left_padding:
                # left padding with bos
                tokens[i] = [self.tokenizer.bos_id] * padding_len + cur_input
            else:
                tokens[i] = cur_input + [self.pad_id] * padding_len

        return torch.LongTensor(tokens).cuda()

    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
        # outputs: bsz, seq_len
        output_text = []
        for output in outputs:
            text = self.tokenizer.decode(output.tolist())
            for stop_word in stopping_criteria:
                text = text.split(stop_word)[0]
            output_text.append(text)

        return output_text
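For reference, a minimal standalone sketch (not part of the commit) of the head-and-tail truncation that `batch_encode` applies when the wrapper is constructed with `mode='mid'` for long-text evaluation: a prompt that exceeds the token budget keeps its first and last halves and drops the middle.

token_ids = list(range(10))  # stand-in for a tokenized prompt
max_len = 6                  # budget left after reserving room for generation
mid_cut_len = max_len // 2
truncated = token_ids[:mid_cut_len] + token_ids[-mid_cut_len:]
print(truncated)             # [0, 1, 2, 7, 8, 9]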