mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Enable Truncation of Mid-Section for Long Prompts in huggingface_above_v4_33.py
(#1373)
* Retain the first and last halves of the tokens from the prompt, discarding the middle, to avoid exceeding the model's maximum length. * Add default parameter: mode * Modified a comment. * Modified variable names. * fix yapf lint
This commit is contained in:
parent
88eb91219b
commit
59586a8b4a
@ -2,6 +2,8 @@
|
||||
# yapf: disable
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from opencompass.models.base import BaseModel, LMTemplateParser
|
||||
from opencompass.models.base_api import APITemplateParser
|
||||
from opencompass.registry import MODELS
|
||||
@ -140,6 +142,13 @@ def _set_model_kwargs_torch_dtype(model_kwargs):
|
||||
|
||||
@MODELS.register_module()
|
||||
class HuggingFacewithChatTemplate(BaseModel):
|
||||
"""Model wrapper for HuggingFace models designed for chat.
|
||||
|
||||
Args:
|
||||
mode (str, optional): The method of input truncation when input length
|
||||
exceeds max_seq_len. 'mid' represents the part of input to
|
||||
truncate. Defaults to 'none'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
@ -155,6 +164,7 @@ class HuggingFacewithChatTemplate(BaseModel):
|
||||
pad_token_id: Optional[int] = None,
|
||||
fastchat_template: Optional[str] = None,
|
||||
stop_words: Optional[str] = [],
|
||||
mode: str = 'none',
|
||||
**other_kwargs):
|
||||
|
||||
self.logger = get_logger()
|
||||
@ -168,6 +178,8 @@ class HuggingFacewithChatTemplate(BaseModel):
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.fastchat_template = fastchat_template
|
||||
self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
|
||||
assert mode in ['none', 'mid']
|
||||
self.mode = mode
|
||||
self.logger.info(f'using stop words: {self.stop_words}')
|
||||
|
||||
for k, v in other_kwargs.items():
|
||||
@ -431,6 +443,24 @@ class HuggingFacewithChatTemplate(BaseModel):
|
||||
|
||||
tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
|
||||
|
||||
if self.mode == 'mid':
|
||||
# Reserve space for the tokens to be generated in the future.
|
||||
max_prompt_len = self.max_seq_len - max_out_len
|
||||
|
||||
# Retain the first 0.5 * max_prompt_len tokens and the last 0.5 * max_prompt_len tokens, discarding the middle ones,
|
||||
# because the prompts' questions are usually at the beginning or the end.
|
||||
# To avoid the warning:
|
||||
# This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length.
|
||||
# Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
|
||||
half_max_prompt_len = max_prompt_len // 2
|
||||
if half_max_prompt_len > 0 and tokens['input_ids'].shape[1] > max_prompt_len:
|
||||
for key in tokens.keys():
|
||||
if tokens[key].shape[1] > max_prompt_len:
|
||||
field_values = tokens[key]
|
||||
tokens[key] = torch.cat(
|
||||
(field_values[:, :half_max_prompt_len], field_values[:, -half_max_prompt_len:]), dim=1
|
||||
)
|
||||
|
||||
generation_kwargs = self.generation_kwargs.copy()
|
||||
generation_kwargs.update(kwargs)
|
||||
stopping_criteria = list(set(stopping_criteria + self.stop_words))
|
||||
|
Loading…
Reference in New Issue
Block a user