From bbdca5eb4cb08c24a386c22bf677d1856485f5f4 Mon Sep 17 00:00:00 2001
From: x54-729 <45304952+x54-729@users.noreply.github.com>
Date: Mon, 30 Sep 2024 15:46:06 +0800
Subject: [PATCH] [BUG] Fix eos token handling and add comments for InternTrain (#1569)

Co-authored-by: x54-729
---
 opencompass/models/interntrain.py | 71 ++++++++++++++++++++++++++++---
 1 file changed, 65 insertions(+), 6 deletions(-)

diff --git a/opencompass/models/interntrain.py b/opencompass/models/interntrain.py
index d6c233cd..6d904acf 100644
--- a/opencompass/models/interntrain.py
+++ b/opencompass/models/interntrain.py
@@ -79,6 +79,50 @@ class LegacyInternTrainManager(InternTrainManager):
 
 @MODELS.register_module()
 class InternTrain(BaseModel):
+    """Model wrapper for InternTrain.
+
+    Args:
+        path (str): The name or path to HuggingFace's model.
+        module_path (str): Path of InternTrain repository.
+        max_seq_len (int): The maximum length of the input sequence. Defaults
+            to 2048.
+        tokenizer_only (bool): If True, only the tokenizer will be initialized.
+            Defaults to False.
+        tokenizer_path (str): The path to the tokenizer. Defaults to None.
+        tokenizer_type: InternTrain's tokenizer type. Defaults to 'INTERNLM'.
+        model_config (str, dict, optional): Config of model. There are several
+            options for this parameter:
+
+            - filename (str): The config items are defined in a python file
+              so the model will load configs from this file.
+            - config (dict): The configuration items are defined in a dict
+              and the model will be initialized from ``model_config``.
+            - None: The config is loaded from ``path``. In this case,
+              please make sure that ``path`` contains a config file named
+              ``model_config.pt``.
+
+            Defaults to None.
+        model_type: Type of model. Defaults to 'INTERNLM2'.
+        ckpt_type: The type of load function in InternTrain when checkpoints
+            are loaded. Defaults to None, which means load the checkpoint
+            directly with pipeline merged.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        model_dtype: The model's dtype. If None, will use dtype defined in
+            ``model_config``. Defaults to None.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
+        sync_rank (bool): Whether to sync inputs between ranks. Do not use this
+            if you are not familiar with this behavior. Check `sync_inputs`
+            function for more details. Defaults to False.
+        mode (str, optional): The method of input truncation when input length
+            exceeds max_seq_len. 'mid' represents the part of input to
+            truncate. Defaults to 'none'.
+        end_str (str, optional): The end string used to trim generated strings
+            if the model has special ending strings that are not handled well.
+            Defaults to None.
+    """
 
     def __init__(self,
                  path: str,
@@ -87,14 +131,15 @@ class InternTrain(BaseModel):
                  tokenizer_only: bool = False,
                  tokenizer_path: Optional[str] = None,
                  tokenizer_type: str = 'INTERNLM',
-                 model_config: Optional[str] = None,
+                 model_config: Optional[Union[str, Dict]] = None,
                  model_type: str = 'INTERNLM2',
                  ckpt_type: Optional[str] = None,
                  meta_template: Optional[Dict] = None,
                  model_dtype: Optional[str] = None,
                  generation_kwargs={},
                  sync_rank: bool = False,
-                 mode='none'):
+                 mode='none',
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
@@ -146,6 +191,7 @@ class InternTrain(BaseModel):
                                            bos_token_id=self.tokenizer.bos_id,
                                            pad_token_id=self.tokenizer.bos_id,
                                            eos_token_id=eos_token_ids)
+        self.end_str = end_str
 
     def _load_model(self,
                     path: str,
@@ -287,8 +333,10 @@ class InternTrain(BaseModel):
             max_length=tokens.shape[1] + max_out_len,
             **self.generation_kwargs)  # bsz, num_return_sequences, max_length
         outputs = outputs[:, 0, tokens.shape[1]:]
-        output_text = self.batch_decode(outputs,
-                                        stopping_criteria=stopping_criteria)
+        output_text = self.batch_decode(
+            outputs,
+            eos_token_ids=self.generator.eos_token_id,
+            stopping_criteria=stopping_criteria)
 
         return output_text
 
@@ -407,11 +455,22 @@ class InternTrain(BaseModel):
 
         return torch.LongTensor(tokens).cuda()
 
-    def batch_decode(self, outputs, stopping_criteria: List[str] = []):
+    def batch_decode(self,
+                     outputs,
+                     eos_token_ids: List[int],
+                     stopping_criteria: List[str] = []):
         # outputs: bsz, seq_len
         output_text = []
+        outputs = outputs.tolist()
        for output in outputs:
-            text = self.tokenizer.decode(output.tolist())
+            # cut off by eos_token_ids
+            eos_idx = len(output)
+            for eos_id in eos_token_ids:
+                if eos_id in output:
+                    eos_idx = min(output.index(eos_id), eos_idx)
+            text = self.tokenizer.decode(output[:eos_idx])
+            if self.end_str is not None:
+                text = text.split(self.end_str)[0]
             for stop_word in stopping_criteria:
                 text = text.split(stop_word)[0]
             output_text.append(text)
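
For reference, the following is a minimal standalone sketch (not part of the patch) of the decoding behavior that the new batch_decode introduces: token ids are cut at the first matching EOS id before decoding, the decoded text is then trimmed at end_str, and finally at any stop words. The function name, the decode_fn callable, and the toy vocabulary below are hypothetical stand-ins for self.tokenizer.decode and the real configuration.

from typing import Callable, List, Optional


def decode_with_eos_cutoff(output_ids: List[int],
                           decode_fn: Callable[[List[int]], str],
                           eos_token_ids: List[int],
                           end_str: Optional[str] = None,
                           stopping_criteria: List[str] = []) -> str:
    # Cut the sequence at the earliest EOS token so that padding or tokens
    # generated after EOS are never decoded.
    eos_idx = len(output_ids)
    for eos_id in eos_token_ids:
        if eos_id in output_ids:
            eos_idx = min(output_ids.index(eos_id), eos_idx)
    text = decode_fn(output_ids[:eos_idx])
    # Trim at the model-specific end string, if one is configured.
    if end_str is not None:
        text = text.split(end_str)[0]
    # Finally trim at any dataset-specific stop words.
    for stop_word in stopping_criteria:
        text = text.split(stop_word)[0]
    return text


# Hypothetical usage: token id 2 plays the role of EOS, decode_fn is a
# stand-in for the tokenizer.
if __name__ == '__main__':
    fake_vocab = {5: 'Hello', 6: ' world', 7: '<eoa>', 2: '</s>'}
    decode = lambda ids: ''.join(fake_vocab.get(i, '') for i in ids)
    print(decode_with_eos_cutoff([5, 6, 7, 2, 5], decode,
                                 eos_token_ids=[2], end_str='<eoa>'))
    # -> 'Hello world'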