[BUG] Fix eos token handling and add comments for InternTrain (#1569)
Co-authored-by: x54-729 <xingshuhao.dispatch@pjlab.org.cn>
This commit is contained in: parent 763d7755b6, commit bbdca5eb4c
@@ -79,6 +79,50 @@ class LegacyInternTrainManager(InternTrainManager):
 
 @MODELS.register_module()
 class InternTrain(BaseModel):
+    """Model wrapper for InternTrain.
+
+    Args:
+        path (str): The name or path to HuggingFace's model.
+        module_path (str): Path of InternTrain repository.
+        max_seq_len (int): The maximum length of the input sequence. Defaults
+            to 2048.
+        tokenizer_only (bool): If True, only the tokenizer will be initialized.
+            Defaults to False.
+        tokenizer_path (str): The path to the tokenizer. Defaults to None.
+        tokenizer_type: InternTrain's tokenizer type. Defaults to 'INTERNLM'.
+        model_config (str, dict, optional): Config of model. There are several
+            options for this parameter:
+
+            - filename (str): The config items are defined in a python file
+              so the model will load configs from this file.
+            - config (dict): The configuration items are defined in a dict
+              and the model will be initialized from ``model_config``.
+            - None: The config is loaded from ``path``. In this case,
+              please make sure that ``path`` contains a config file named
+              ``model_config.pt``.
+
+            Defaults to None.
+        model_type: Type of model. Defaults to 'INTERNLM2'.
+        ckpt_type: The type of load function in InternTrain when checkpoints
+            are loaded. Defaults to None, which means the checkpoint is loaded
+            directly with pipeline merged.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case there is a requirement to inject or
+            wrap any meta instructions.
+        model_dtype: The model's dtype. If None, will use the dtype defined in
+            ``model_config``. Defaults to None.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
+        sync_rank (bool): Whether to sync inputs between ranks. Do not use this
+            if you are not familiar with this behavior. Check the `sync_inputs`
+            function for more details. Defaults to False.
+        mode (str, optional): The method of input truncation when the input
+            length exceeds max_seq_len. 'mid' truncates the middle of the
+            input. Defaults to 'none'.
+        end_str (str, optional): If given, generated strings are trimmed at
+            end_str; useful when the model has special ending strings that are
+            not handled well. Defaults to None.
+    """
 
     def __init__(self,
                  path: str,
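To make the documented arguments concrete, here is a minimal sketch of what a model entry using this wrapper might look like in an OpenCompass config. Every path, plus the tokenizer file and the end_str value, is an illustrative placeholder rather than a value taken from this commit:

# Hypothetical OpenCompass config entry exercising the documented arguments.
# All paths and end_str below are placeholders.
from opencompass.models import InternTrain

models = [
    dict(
        type=InternTrain,
        path='./ckpts/my_model',                  # placeholder: dir holding model_config.pt
        module_path='./InternTrain',              # placeholder: local InternTrain checkout
        max_seq_len=2048,
        tokenizer_path='./tokenizers/v13.model',  # placeholder tokenizer file
        tokenizer_type='INTERNLM',
        model_config=None,                        # fall back to `path`/model_config.pt
        model_type='INTERNLM2',
        model_dtype=None,                         # use the dtype from model_config
        generation_kwargs=dict(),
        mode='none',
        end_str='<eoa>',                          # placeholder: trim generations here
    )
]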
@@ -87,14 +131,15 @@ class InternTrain(BaseModel):
                  tokenizer_only: bool = False,
                  tokenizer_path: Optional[str] = None,
                  tokenizer_type: str = 'INTERNLM',
-                 model_config: Optional[str] = None,
+                 model_config: Optional[Union[str, Dict]] = None,
                  model_type: str = 'INTERNLM2',
                  ckpt_type: Optional[str] = None,
                  meta_template: Optional[Dict] = None,
                  model_dtype: Optional[str] = None,
                  generation_kwargs={},
                  sync_rank: bool = False,
-                 mode='none'):
+                 mode='none',
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
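Since model_config is now typed as Union[str, Dict], both forms below would be accepted. This is a sketch only; the dict keys shown are illustrative placeholders, not InternTrain's actual config schema:

# str form: parsed as a python config file by InternTrain (hypothetical path)
model_config = './configs/internlm2_7b.py'

# dict form: the model is initialized directly from these items
# (keys are illustrative placeholders, not InternTrain's real schema)
model_config = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
)

# None: per the docstring, the wrapper then expects `model_config.pt` inside `path`
model_config = None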
@@ -146,6 +191,7 @@ class InternTrain(BaseModel):
             bos_token_id=self.tokenizer.bos_id,
             pad_token_id=self.tokenizer.bos_id,
             eos_token_id=eos_token_ids)
+        self.end_str = end_str
 
     def _load_model(self,
                     path: str,
@@ -287,7 +333,9 @@ class InternTrain(BaseModel):
             max_length=tokens.shape[1] + max_out_len,
             **self.generation_kwargs)  # bsz, num_return_sequences, max_length
         outputs = outputs[:, 0, tokens.shape[1]:]
-        output_text = self.batch_decode(outputs,
-                                        stopping_criteria=stopping_criteria)
+        output_text = self.batch_decode(
+            outputs,
+            eos_token_ids=self.generator.eos_token_id,
+            stopping_criteria=stopping_criteria)
 
         return output_text
@@ -407,11 +455,22 @@ class InternTrain(BaseModel):
 
         return torch.LongTensor(tokens).cuda()
 
-    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
+    def batch_decode(self,
+                     outputs,
+                     eos_token_ids: List[int],
+                     stopping_criteria: List[str] = []):
         # outputs: bsz, seq_len
         output_text = []
+        outputs = outputs.tolist()
         for output in outputs:
-            text = self.tokenizer.decode(output.tolist())
+            # cut off by eos_token_ids
+            eos_idx = len(output)
+            for eos_id in eos_token_ids:
+                if eos_id in output:
+                    eos_idx = min(output.index(eos_id), eos_idx)
+            text = self.tokenizer.decode(output[:eos_idx])
+            if self.end_str is not None:
+                text = text.split(self.end_str)[0]
             for stop_word in stopping_criteria:
                 text = text.split(stop_word)[0]
             output_text.append(text)
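The substance of the fix is visible in this last hunk: each sequence is now cut at the earliest EOS id before detokenizing, instead of decoding the whole buffer, and end_str plus the stop words are applied afterwards. Below is a self-contained sketch of that trimming logic, with a toy decoder standing in for InternTrain's tokenizer:

from typing import List, Optional

def trim_and_decode(output: List[int],
                    eos_token_ids: List[int],
                    decode,                      # callable: token ids -> str
                    end_str: Optional[str] = None,
                    stopping_criteria: List[str] = []) -> str:
    # Cut at the earliest EOS token, mirroring the loop added in batch_decode.
    eos_idx = len(output)
    for eos_id in eos_token_ids:
        if eos_id in output:
            eos_idx = min(output.index(eos_id), eos_idx)
    text = decode(output[:eos_idx])
    # Then trim at end_str and any stop words, as batch_decode does.
    if end_str is not None:
        text = text.split(end_str)[0]
    for stop_word in stopping_criteria:
        text = text.split(stop_word)[0]
    return text

# Toy usage: ids 1 and 2 are EOS; everything from the first one onward is dropped.
fake_decode = lambda ids: ' '.join(str(i) for i in ids)
print(trim_and_decode([5, 7, 2, 9, 1], eos_token_ids=[1, 2], decode=fake_decode))
# -> "5 7"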