[Feature] Enable Truncation of Mid-Section for Long Prompts in huggingface_above_v4_33.py (#1373)

* Retain the first and last halves of the tokens from the prompt, discarding the middle, to avoid exceeding the model's maximum length (see the minimal sketch after this list).

* Add default parameter: mode

* Modified a comment.

* Modified variable names.

* fix yapf lint
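
Below is a minimal, standalone sketch of the mid-truncation idea described in the first bullet; the function name mid_truncate and the example lengths are illustrative and not part of the patch:

import torch

def mid_truncate(input_ids: torch.Tensor, max_prompt_len: int) -> torch.Tensor:
    # Keep the first and last max_prompt_len // 2 tokens of a (batch, seq_len)
    # tensor, dropping the middle so the prompt fits within the length budget.
    half = max_prompt_len // 2
    if half > 0 and input_ids.shape[1] > max_prompt_len:
        return torch.cat((input_ids[:, :half], input_ids[:, -half:]), dim=1)
    return input_ids

ids = torch.arange(12).unsqueeze(0)   # shape (1, 12): token ids 0..11
print(mid_truncate(ids, 6))           # tensor([[ 0,  1,  2,  9, 10, 11]])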
changyeyu 2024-08-09 11:36:30 +08:00 committed by GitHub
parent 88eb91219b
commit 59586a8b4a

@@ -2,6 +2,8 @@
# yapf: disable
from typing import Dict, List, Optional, Union
import torch
from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.models.base_api import APITemplateParser
from opencompass.registry import MODELS
@@ -140,6 +142,13 @@ def _set_model_kwargs_torch_dtype(model_kwargs):
@MODELS.register_module()
class HuggingFacewithChatTemplate(BaseModel):
"""Model wrapper for HuggingFace models designed for chat.
Args:
mode (str, optional): The method of input truncation when input length
exceeds max_seq_len. 'mid' represents the part of input to
truncate. Defaults to 'none'.
"""
def __init__(self,
path: str,
@@ -155,6 +164,7 @@ class HuggingFacewithChatTemplate(BaseModel):
                 pad_token_id: Optional[int] = None,
                 fastchat_template: Optional[str] = None,
                 stop_words: Optional[str] = [],
                 mode: str = 'none',
                 **other_kwargs):
        self.logger = get_logger()
@@ -168,6 +178,8 @@ class HuggingFacewithChatTemplate(BaseModel):
        self.generation_kwargs = generation_kwargs
        self.fastchat_template = fastchat_template
        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
        assert mode in ['none', 'mid']
        self.mode = mode
        self.logger.info(f'using stop words: {self.stop_words}')

        for k, v in other_kwargs.items():
@@ -431,6 +443,24 @@ class HuggingFacewithChatTemplate(BaseModel):
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

        if self.mode == 'mid':
            # Reserve space for the tokens to be generated in the future.
            max_prompt_len = self.max_seq_len - max_out_len
            # Retain the first 0.5 * max_prompt_len tokens and the last
            # 0.5 * max_prompt_len tokens, discarding the middle ones,
            # because the prompts' questions are usually at the beginning or the end.
            # To avoid the warning:
            # This is a friendly reminder - the current text generation call will exceed
            # the model's predefined maximum length. Depending on the model, you may
            # observe exceptions, performance degradation, or nothing at all.
            half_max_prompt_len = max_prompt_len // 2
            if half_max_prompt_len > 0 and tokens['input_ids'].shape[1] > max_prompt_len:
                for key in tokens.keys():
                    if tokens[key].shape[1] > max_prompt_len:
                        field_values = tokens[key]
                        tokens[key] = torch.cat(
                            (field_values[:, :half_max_prompt_len], field_values[:, -half_max_prompt_len:]), dim=1
                        )

        generation_kwargs = self.generation_kwargs.copy()
        generation_kwargs.update(kwargs)
        stopping_criteria = list(set(stopping_criteria + self.stop_words))
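
With this change in place, a model config can opt into mid-truncation by passing mode='mid'. The sketch below assumes the usual OpenCompass model-config layout; the model path, abbr, batch size, and length limits are illustrative values, not taken from this patch:

from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='internlm2-chat-7b-hf',        # illustrative abbreviation
        path='internlm/internlm2-chat-7b',  # illustrative HF checkpoint
        max_seq_len=4096,                   # model context budget
        max_out_len=1024,                   # reserved for generated tokens
        batch_size=8,
        mode='mid',                         # keep head and tail of over-long prompts
        run_cfg=dict(num_gpus=1),
    )
]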