[BUG] Fix eos token handling and add comments for InternTrain (#1569)

Co-authored-by: x54-729 <xingshuhao.dispatch@pjlab.org.cn>
x54-729 2024-09-30 15:46:06 +08:00 committed by GitHub
parent 763d7755b6
commit bbdca5eb4c


@@ -79,6 +79,50 @@ class LegacyInternTrainManager(InternTrainManager):
 
 @MODELS.register_module()
 class InternTrain(BaseModel):
+    """Model wrapper for InternTrain.
+
+    Args:
+        path (str): The name or path to HuggingFace's model.
+        module_path (str): Path of the InternTrain repository.
+        max_seq_len (int): The maximum length of the input sequence. Defaults
+            to 2048.
+        tokenizer_only (bool): If True, only the tokenizer will be
+            initialized. Defaults to False.
+        tokenizer_path (str): The path to the tokenizer. Defaults to None.
+        tokenizer_type: InternTrain's tokenizer type. Defaults to 'INTERNLM'.
+        model_config (str, dict, optional): Config of the model. There are
+            several options for this parameter:
+
+            - filename (str): The config items are defined in a python file,
+              so the model will load the config from this file.
+            - config (dict): The configuration items are defined in a dict,
+              and the model will be initialized from ``model_config``.
+            - None: The config is loaded from ``path``. In this case, please
+              make sure that ``path`` contains a config file named
+              ``model_config.pt``.
+
+            Defaults to None.
+        model_type: Type of the model. Defaults to 'INTERNLM2'.
+        ckpt_type: The type of load function used in InternTrain when
+            checkpoints are loaded. Defaults to None, which means the
+            checkpoint is loaded directly with pipeline merged.
+        meta_template (Dict, optional): The model's meta prompt template if
+            needed, in case meta instructions have to be injected into or
+            wrapped around the input.
+        model_dtype: The model's dtype. If None, the dtype defined in
+            ``model_config`` is used. Defaults to None.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
+        sync_rank (bool): Whether to sync inputs between ranks. Do not use
+            this if you are not familiar with this behavior. Check the
+            `sync_inputs` function for more details. Defaults to False.
+        mode (str, optional): The method of input truncation when the input
+            length exceeds max_seq_len. 'mid' means the middle part of the
+            input is truncated. Defaults to 'none'.
+        end_str (str, optional): If set, generated strings are trimmed at
+            end_str, for models whose special ending strings are not handled
+            well otherwise. Defaults to None.
+    """
+
     def __init__(self,
                  path: str,
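
Note: the three documented forms of ``model_config`` can be illustrated with a small, self-contained sketch. This is not the repository's actual loader; the helper name ``resolve_model_config`` and the file handling are assumptions made purely for illustration.

    # Illustrative sketch only -- not the actual InternTrain loader.  It shows
    # how the three documented ``model_config`` forms could be resolved.
    import os
    from typing import Dict, Optional, Union

    import torch

    def resolve_model_config(path: str,
                             model_config: Optional[Union[str, Dict]] = None
                             ) -> Dict:
        if isinstance(model_config, dict):
            # Already a dict of config items: use it as-is.
            return dict(model_config)
        if isinstance(model_config, str):
            # A python file defining the config items: execute it and keep
            # the public names it defines.
            namespace: Dict = {}
            with open(model_config) as f:
                exec(f.read(), namespace)
            return {k: v for k, v in namespace.items()
                    if not k.startswith('__')}
        # None: fall back to ``model_config.pt`` stored alongside the weights.
        return torch.load(os.path.join(path, 'model_config.pt'))
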
@@ -87,14 +131,15 @@ class InternTrain(BaseModel):
                  tokenizer_only: bool = False,
                  tokenizer_path: Optional[str] = None,
                  tokenizer_type: str = 'INTERNLM',
-                 model_config: Optional[str] = None,
+                 model_config: Optional[Union[str, Dict]] = None,
                  model_type: str = 'INTERNLM2',
                  ckpt_type: Optional[str] = None,
                  meta_template: Optional[Dict] = None,
                  model_dtype: Optional[str] = None,
                  generation_kwargs={},
                  sync_rank: bool = False,
-                 mode='none'):
+                 mode='none',
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
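
Note: with the widened signature above, ``model_config`` may now be a dict and ``end_str`` can be set per model. A hypothetical OpenCompass-style config entry might look like the sketch below; the import path is assumed, all paths and values are placeholders, and fields unrelated to this commit are omitted for brevity.

    # Hypothetical config entry; values are placeholders, not shipped defaults.
    from opencompass.models import InternTrain  # import path assumed

    models = [
        dict(
            type=InternTrain,
            abbr='interntrain-7b',
            path='ckpts/interntrain-7b',         # checkpoint dir (placeholder)
            module_path='/path/to/InternTrain',  # local InternTrain repo (placeholder)
            model_config=dict(num_layers=32, hidden_size=4096),  # or a filename, or None
            model_type='INTERNLM2',
            max_seq_len=2048,
            end_str='<eoa>',  # trim generations at this marker (example value)
        )
    ]
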
@@ -146,6 +191,7 @@ class InternTrain(BaseModel):
             bos_token_id=self.tokenizer.bos_id,
             pad_token_id=self.tokenizer.bos_id,
             eos_token_id=eos_token_ids)
+        self.end_str = end_str
 
     def _load_model(self,
                     path: str,
@@ -287,8 +333,10 @@ class InternTrain(BaseModel):
             max_length=tokens.shape[1] + max_out_len,
             **self.generation_kwargs)  # bsz, num_return_sequences, max_length
         outputs = outputs[:, 0, tokens.shape[1]:]
-        output_text = self.batch_decode(outputs,
-                                        stopping_criteria=stopping_criteria)
+        output_text = self.batch_decode(
+            outputs,
+            eos_token_ids=self.generator.eos_token_id,
+            stopping_criteria=stopping_criteria)
 
         return output_text
@@ -407,11 +455,22 @@ class InternTrain(BaseModel):
         return torch.LongTensor(tokens).cuda()
 
-    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
+    def batch_decode(self,
+                     outputs,
+                     eos_token_ids: List[int],
+                     stopping_criteria: List[str] = []):
         # outputs: bsz, seq_len
         output_text = []
+        outputs = outputs.tolist()
         for output in outputs:
-            text = self.tokenizer.decode(output.tolist())
+            # cut off by eos_token_ids
+            eos_idx = len(output)
+            for eos_id in eos_token_ids:
+                if eos_id in output:
+                    eos_idx = min(output.index(eos_id), eos_idx)
+            text = self.tokenizer.decode(output[:eos_idx])
+            if self.end_str is not None:
+                text = text.split(self.end_str)[0]
             for stop_word in stopping_criteria:
                 text = text.split(stop_word)[0]
             output_text.append(text)
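
Note: the new decode path can be exercised in isolation: cut each sequence at the earliest EOS id, then trim at ``end_str`` and the stopping criteria. The ``decode`` stub below stands in for the real tokenizer; the rest mirrors the logic added in this commit.

    from typing import List, Optional

    def decode(ids: List[int]) -> str:
        # Stand-in for self.tokenizer.decode in this sketch.
        return ' '.join(str(i) for i in ids)

    def batch_decode(outputs: List[List[int]],
                     eos_token_ids: List[int],
                     end_str: Optional[str] = None,
                     stopping_criteria: List[str] = []) -> List[str]:
        output_text = []
        for output in outputs:
            # Keep everything before the earliest EOS id, if any appears.
            eos_idx = len(output)
            for eos_id in eos_token_ids:
                if eos_id in output:
                    eos_idx = min(output.index(eos_id), eos_idx)
            text = decode(output[:eos_idx])
            if end_str is not None:
                text = text.split(end_str)[0]
            for stop_word in stopping_criteria:
                text = text.split(stop_word)[0]
            output_text.append(text)
        return output_text

    print(batch_decode([[5, 6, 2, 9], [7, 8]], eos_token_ids=[2, 92542]))
    # -> ['5 6', '7 8']   (token 2 acts as EOS in the first sequence)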