[Feature] Enable Truncation of Mid-Section for Long Prompts in huggingface_above_v4_33.py (#1373)
* Retain the first and last halves of the tokens from the prompt, discarding the middle, to avoid exceeding the model's maximum length.
* Add default parameter: mode
* Modify a comment.
* Modify variable names.
* Fix yapf lint.
parent 88eb91219b
commit 59586a8b4a
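Before the diff itself, a minimal standalone sketch of the truncation idea: keep the first and last halves of the prompt tokens and drop the middle. The helper name truncate_mid and the toy tensor are illustrative only; the in-tree change (shown in the diff below) applies the same slicing to every field of the tokenizer output.

import torch

def truncate_mid(input_ids: torch.Tensor, max_prompt_len: int) -> torch.Tensor:
    # Keep the first and last halves of the prompt tokens; discard the middle.
    half = max_prompt_len // 2
    if half > 0 and input_ids.shape[1] > max_prompt_len:
        return torch.cat((input_ids[:, :half], input_ids[:, -half:]), dim=1)
    return input_ids

ids = torch.arange(10).unsqueeze(0)   # a (1, 10) toy "prompt"
print(truncate_mid(ids, 6))           # tensor([[0, 1, 2, 7, 8, 9]])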
opencompass/models/huggingface_above_v4_33.py

@@ -2,6 +2,8 @@
 # yapf: disable
 from typing import Dict, List, Optional, Union
 
+import torch
+
 from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
@@ -140,6 +142,13 @@ def _set_model_kwargs_torch_dtype(model_kwargs):
 
 @MODELS.register_module()
 class HuggingFacewithChatTemplate(BaseModel):
+    """Model wrapper for HuggingFace models designed for chat.
+
+    Args:
+        mode (str, optional): The method of input truncation when input length
+            exceeds max_seq_len. 'mid' represents the part of input to
+            truncate. Defaults to 'none'.
+    """
 
     def __init__(self,
                  path: str,
@@ -155,6 +164,7 @@
                  pad_token_id: Optional[int] = None,
                  fastchat_template: Optional[str] = None,
                  stop_words: Optional[str] = [],
+                 mode: str = 'none',
                  **other_kwargs):
 
         self.logger = get_logger()
@@ -168,6 +178,8 @@
         self.generation_kwargs = generation_kwargs
         self.fastchat_template = fastchat_template
         self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
+        assert mode in ['none', 'mid']
+        self.mode = mode
         self.logger.info(f'using stop words: {self.stop_words}')
 
         for k, v in other_kwargs.items():
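With the new argument wired in, enabling mid-truncation from an OpenCompass model config might look like the following sketch; the model path and length values are placeholders, not part of this commit.

from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        path='org/model-name',   # placeholder HuggingFace model id
        max_seq_len=4096,        # placeholder context budget
        max_out_len=512,         # placeholder generation budget
        mode='mid',              # new: truncate the middle of over-long prompts
    )
]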
@@ -431,6 +443,24 @@
 
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
 
+        if self.mode == 'mid':
+            # Reserve space for the tokens to be generated in the future.
+            max_prompt_len = self.max_seq_len - max_out_len
+
+            # Retain the first 0.5 * max_prompt_len tokens and the last 0.5 * max_prompt_len tokens, discarding the middle ones,
+            # because the prompts' questions are usually at the beginning or the end.
+            # To avoid the warning:
+            # This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length.
+            # Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
+            half_max_prompt_len = max_prompt_len // 2
+            if half_max_prompt_len > 0 and tokens['input_ids'].shape[1] > max_prompt_len:
+                for key in tokens.keys():
+                    if tokens[key].shape[1] > max_prompt_len:
+                        field_values = tokens[key]
+                        tokens[key] = torch.cat(
+                            (field_values[:, :half_max_prompt_len], field_values[:, -half_max_prompt_len:]), dim=1
+                        )
+
         generation_kwargs = self.generation_kwargs.copy()
         generation_kwargs.update(kwargs)
         stopping_criteria = list(set(stopping_criteria + self.stop_words))
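A quick worked example of the length budget above, with illustrative numbers:

max_seq_len, max_out_len = 4096, 512
max_prompt_len = max_seq_len - max_out_len   # 3584 tokens left for the prompt
half_max_prompt_len = max_prompt_len // 2    # 1792 tokens kept from each end
# The truncated prompt holds 2 * half_max_prompt_len tokens; when max_prompt_len
# is odd this is max_prompt_len - 1, which still fits within the budget.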