Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[BUG] Fix eos token handling and add comments for InternTrain (#1569)
Co-authored-by: x54-729 <xingshuhao.dispatch@pjlab.org.cn>
parent 763d7755b6
commit bbdca5eb4c
@@ -79,6 +79,50 @@ class LegacyInternTrainManager(InternTrainManager):
 
 @MODELS.register_module()
 class InternTrain(BaseModel):
+    """Model wrapper for InternTrain.
+
+    Args:
+        path (str): The name or path to HuggingFace's model.
+        module_path (str): Path of InternTrain repository.
+        max_seq_len (int): The maximum length of the input sequence. Defaults
+            to 2048.
+        tokenizer_only (bool): If True, only the tokenizer will be
+            initialized. Defaults to False.
+        tokenizer_path (str): The path to the tokenizer. Defaults to None.
+        tokenizer_type: InternTrain's tokenizer type. Defaults to 'INTERNLM'.
+        model_config (str, dict, optional): Config of the model. There are
+            several options for this parameter:
+
+            - filename (str): The config items are defined in a python file,
+              so the model will load configs from this file.
+            - config (dict): The config items are defined in a dict and the
+              model will be initialized from ``model_config``.
+            - None: The config is loaded from ``path``. In this case, please
+              make sure that ``path`` contains a config file named
+              ``model_config.pt``.
+
+            Defaults to None.
+        model_type: Type of the model. Defaults to 'INTERNLM2'.
+        ckpt_type: The type of load function used in InternTrain when
+            checkpoints are loaded. Defaults to None, which means the
+            checkpoint is loaded directly with pipeline merged.
+        meta_template (Dict, optional): The model's meta prompt template if
+            needed, in case there is a requirement to inject or wrap any
+            meta instructions.
+        model_dtype: The model's dtype. If None, the dtype defined in
+            ``model_config`` is used. Defaults to None.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
+        sync_rank (bool): Whether to sync inputs between ranks. Do not use
+            this if you are not familiar with this behavior. Check the
+            `sync_inputs` function for more details. Defaults to False.
+        mode (str, optional): The method of input truncation when the input
+            length exceeds max_seq_len. 'mid' truncates the middle of the
+            input. Defaults to 'none'.
+        end_str (str, optional): If set, generated strings are trimmed at
+            end_str; useful when the model has special ending strings that
+            are not handled well. Defaults to None.
+    """
+
     def __init__(self,
                  path: str,
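With the new docstring in place, a typical OpenCompass model config might look like the sketch below. This is purely illustrative: the paths, abbr, and end_str value are placeholders and not part of this commit.

from opencompass.models import InternTrain

models = [
    dict(
        type=InternTrain,
        abbr='interntrain-7b',               # hypothetical run name
        path='/path/to/checkpoint',          # placeholder checkpoint path
        module_path='/path/to/InternTrain',  # placeholder repo path
        max_seq_len=2048,
        model_type='INTERNLM2',
        tokenizer_type='INTERNLM',
        end_str='<eoa>',                     # assumed ending string; set per model
        max_out_len=100,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]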
@@ -87,14 +131,15 @@ class InternTrain(BaseModel):
                  tokenizer_only: bool = False,
                  tokenizer_path: Optional[str] = None,
                  tokenizer_type: str = 'INTERNLM',
-                 model_config: Optional[str] = None,
+                 model_config: Optional[Union[str, Dict]] = None,
                  model_type: str = 'INTERNLM2',
                  ckpt_type: Optional[str] = None,
                  meta_template: Optional[Dict] = None,
                  model_dtype: Optional[str] = None,
                  generation_kwargs={},
                  sync_rank: bool = False,
-                 mode='none'):
+                 mode='none',
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
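Since model_config now accepts a dict as well as a filename, a config can be supplied inline. A minimal sketch, with hypothetical config keys (the actual schema comes from the InternTrain repository):

# model_config may be a python filename, an inline dict, or None:
model = InternTrain(
    path='/path/to/checkpoint',          # placeholder
    module_path='/path/to/InternTrain',  # placeholder
    model_config=dict(                   # hypothetical config keys
        num_layers=32,
        hidden_size=4096,
    ),
    end_str='<eoa>',                     # assumed ending string
)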
@@ -146,6 +191,7 @@ class InternTrain(BaseModel):
             bos_token_id=self.tokenizer.bos_id,
             pad_token_id=self.tokenizer.bos_id,
             eos_token_id=eos_token_ids)
+        self.end_str = end_str
 
     def _load_model(self,
                     path: str,
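The generator is constructed with a list of eos ids (note that pad_token_id reuses the bos id). How such a list might be assembled is sketched below; collect_eos_token_ids is a hypothetical helper for illustration, not code from this file:

from typing import Dict, List, Optional

def collect_eos_token_ids(tokenizer_eos_id: int,
                          meta_template: Optional[Dict] = None) -> List[int]:
    # Hypothetical helper: gather every id that should stop generation,
    # e.g. the tokenizer's eos plus extra stop ids from a meta template.
    ids = [tokenizer_eos_id]
    if meta_template and 'eos_token_id' in meta_template:
        extra = meta_template['eos_token_id']
        ids += extra if isinstance(extra, list) else [extra]
    return sorted(set(ids))

print(collect_eos_token_ids(2, {'eos_token_id': [2, 92542]}))  # [2, 92542]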
@@ -287,8 +333,10 @@ class InternTrain(BaseModel):
             max_length=tokens.shape[1] + max_out_len,
             **self.generation_kwargs)  # bsz, num_return_sequences, max_length
         outputs = outputs[:, 0, tokens.shape[1]:]
-        output_text = self.batch_decode(outputs,
-                                        stopping_criteria=stopping_criteria)
+        output_text = self.batch_decode(
+            outputs,
+            eos_token_ids=self.generator.eos_token_id,
+            stopping_criteria=stopping_criteria)
 
         return output_text
 
@@ -407,11 +455,22 @@ class InternTrain(BaseModel):
 
         return torch.LongTensor(tokens).cuda()
 
-    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
+    def batch_decode(self,
+                     outputs,
+                     eos_token_ids: List[int],
+                     stopping_criteria: List[str] = []):
         # outputs: bsz, seq_len
         output_text = []
+        outputs = outputs.tolist()
         for output in outputs:
-            text = self.tokenizer.decode(output.tolist())
+            # cut off by eos_token_ids
+            eos_idx = len(output)
+            for eos_id in eos_token_ids:
+                if eos_id in output:
+                    eos_idx = min(output.index(eos_id), eos_idx)
+            text = self.tokenizer.decode(output[:eos_idx])
+            if self.end_str is not None:
+                text = text.split(self.end_str)[0]
             for stop_word in stopping_criteria:
                 text = text.split(stop_word)[0]
             output_text.append(text)
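This hunk is the heart of the fix: batch_decode previously decoded the whole sequence, so anything after the first eos token could leak into the output. The new loop cuts at the earliest eos before decoding, then trims at end_str and any stop words. A self-contained sketch of the same logic, with the tokenizer replaced by a stub for illustration:

from typing import Callable, List, Optional

def trim_at_eos(output: List[int],
                eos_token_ids: List[int],
                decode: Callable[[List[int]], str],
                end_str: Optional[str] = None,
                stopping_criteria: List[str] = []) -> str:
    # Mirrors the fixed batch_decode loop for a single sequence.
    eos_idx = len(output)               # default: keep everything
    for eos_id in eos_token_ids:
        if eos_id in output:
            eos_idx = min(output.index(eos_id), eos_idx)
    text = decode(output[:eos_idx])     # decode only up to the first eos
    if end_str is not None:
        text = text.split(end_str)[0]
    for stop_word in stopping_criteria:
        text = text.split(stop_word)[0]
    return text

# Toy check with a stub decoder: ids 1 and 2 are treated as eos tokens.
stub_decode = lambda ids: ' '.join(str(i) for i in ids)
print(trim_at_eos([5, 7, 2, 9, 1], eos_token_ids=[1, 2],
                  decode=stub_decode))  # -> '5 7'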