From 59586a8b4a3e4dc2c24b6e55a3d1074e5fbe10ab Mon Sep 17 00:00:00 2001
From: changyeyu
Date: Fri, 9 Aug 2024 11:36:30 +0800
Subject: [PATCH] [Feature] Enable Truncation of Mid-Section for Long Prompts
 in `huggingface_above_v4_33.py` (#1373)

* Retain the first and last halves of the tokens from the prompt, discarding
  the middle, to avoid exceeding the model's maximum length.

* Add default parameter: mode

* Modify a comment.

* Modify variable names.

* Fix yapf lint.

---
 opencompass/models/huggingface_above_v4_33.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 23f3c830..0276ebba 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -2,6 +2,8 @@
 # yapf: disable
 from typing import Dict, List, Optional, Union
 
+import torch
+
 from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
@@ -140,6 +142,13 @@ def _set_model_kwargs_torch_dtype(model_kwargs):
 
 @MODELS.register_module()
 class HuggingFacewithChatTemplate(BaseModel):
+    """Model wrapper for HuggingFace models designed for chat.
+
+    Args:
+        mode (str, optional): The method of input truncation when the input
+            length exceeds max_seq_len. 'mid' truncates the middle of the
+            input, keeping only its beginning and end. Defaults to 'none'.
+    """
 
     def __init__(self,
                  path: str,
@@ -155,6 +164,7 @@ class HuggingFacewithChatTemplate(BaseModel):
                  pad_token_id: Optional[int] = None,
                  fastchat_template: Optional[str] = None,
                  stop_words: Optional[str] = [],
+                 mode: str = 'none',
                  **other_kwargs):
 
         self.logger = get_logger()
@@ -168,6 +178,8 @@ class HuggingFacewithChatTemplate(BaseModel):
         self.generation_kwargs = generation_kwargs
         self.fastchat_template = fastchat_template
         self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
+        assert mode in ['none', 'mid']
+        self.mode = mode
         self.logger.info(f'using stop words: {self.stop_words}')
 
         for k, v in other_kwargs.items():
@@ -431,6 +443,24 @@ class HuggingFacewithChatTemplate(BaseModel):
 
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
 
+        if self.mode == 'mid':
+            # Reserve space for the tokens to be generated in the future.
+            max_prompt_len = self.max_seq_len - max_out_len
+
+            # Retain the first 0.5 * max_prompt_len tokens and the last 0.5 * max_prompt_len tokens, discarding the middle,
+            # because a prompt's question is usually at its beginning or its end.
+            # This also avoids the transformers warning:
+            # "This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length.
+            # Depending on the model, you may observe exceptions, performance degradation, or nothing at all."
+            half_max_prompt_len = max_prompt_len // 2
+            if half_max_prompt_len > 0 and tokens['input_ids'].shape[1] > max_prompt_len:
+                for key in tokens.keys():
+                    if tokens[key].shape[1] > max_prompt_len:
+                        field_values = tokens[key]
+                        tokens[key] = torch.cat(
+                            (field_values[:, :half_max_prompt_len], field_values[:, -half_max_prompt_len:]), dim=1
+                        )
+
         generation_kwargs = self.generation_kwargs.copy()
         generation_kwargs.update(kwargs)
         stopping_criteria = list(set(stopping_criteria + self.stop_words))
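
Note for reviewers: below is a minimal standalone sketch of the 'mid'
truncation added above, so the slicing logic can be checked in isolation.
The 16/4 length budget and the 20-token prompt are made-up illustration
values, not part of the patch.

    import torch

    # Illustrative budget: a 16-token context with 4 tokens reserved for
    # generation leaves 12 slots for the prompt.
    max_seq_len, max_out_len = 16, 4
    max_prompt_len = max_seq_len - max_out_len
    half_max_prompt_len = max_prompt_len // 2  # keep 6 head + 6 tail tokens

    # A 20-token prompt (batch size 1); token ids double as positions here.
    input_ids = torch.arange(20).unsqueeze(0)
    if half_max_prompt_len > 0 and input_ids.shape[1] > max_prompt_len:
        input_ids = torch.cat((input_ids[:, :half_max_prompt_len],
                               input_ids[:, -half_max_prompt_len:]), dim=1)

    print(input_ids)
    # tensor([[ 0,  1,  2,  3,  4,  5, 14, 15, 16, 17, 18, 19]])
    # The middle 8 tokens (6..13) are discarded; the length is now 12 <= 12.

One consequence of `half_max_prompt_len = max_prompt_len // 2`: when
max_prompt_len is odd, the two kept halves sum to max_prompt_len - 1, so the
truncated prompt can be one token shorter than strictly necessary.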
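
To exercise the new option end to end, a model config entry along these lines
should work. This snippet is a sketch: the abbr, model path, lengths, and
batch size are placeholder values; only `mode='mid'` comes from this patch.

    from opencompass.models import HuggingFacewithChatTemplate

    models = [
        dict(
            type=HuggingFacewithChatTemplate,
            abbr='llama-3-8b-instruct-hf',
            path='meta-llama/Meta-Llama-3-8B-Instruct',
            max_seq_len=4096,
            max_out_len=1024,
            batch_size=8,
            run_cfg=dict(num_gpus=1),
            # Truncate the middle of prompts longer than
            # max_seq_len - max_out_len; 'none' keeps the old behavior.
            mode='mid',
        )
    ]

With the default `mode='none'` the behavior is unchanged, so existing configs
are unaffected.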