# Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
import os
import pdb
import copy
import math
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from opencompass.models import HuggingFaceBaseModel
from opencompass.configs.datasets.infinitebench.infinitebench import (
    infinitebench_datasets,
)
from transformers import AutoConfig

from mmengine.device import is_npu_available

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    LlamaAttention,
    apply_rotary_pos_emb,
)


__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]

def local_heavy_hitter_mask(attn_weights, heavy_budget):

    # attn_weights: (BS, head, query, keys), pre-softmax attention scores
    dtype_attn_weights = attn_weights.dtype
    seq_length = attn_weights.shape[-1]
    padding_length = 0

    offset = torch.finfo(attn_weights.dtype).min
    tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        dtype_attn_weights
    )

    accumulated_attention_score = torch.sum(
        tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
    )  # (BS, head, keys)
    accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
    accumulated_attention_score[:, :, :padding_length] = 0

    mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
    mask_bottom[
        :,
        :,
        padding_length : heavy_budget + padding_length,
        padding_length : heavy_budget + padding_length,
    ] = True

    for token_index in range(heavy_budget + padding_length, seq_length):
        tmp_attn_index = nn.functional.softmax(
            attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
        ).to(dtype_attn_weights)
        _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
        zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
        mask_bottom_index = zeros_index.scatter(
            -1, tmp_topk_index, True
        )  # (BS, head, keys)
        mask_bottom_index[:, :, token_index] = True

        mask_bottom[:, :, token_index, :] = mask_bottom_index
        accumulated_attention_score += tmp_attn_index
        accumulated_attention_score = accumulated_attention_score * mask_bottom_index

    return mask_bottom

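
# A minimal usage sketch (added for illustration; not part of the original H2O
# code). `local_heavy_hitter_mask` takes pre-softmax attention scores of shape
# (batch, heads, query, keys) and returns a boolean mask of the same shape that
# keeps the first `heavy_budget` positions plus, for each later query, the
# top-(heavy_budget - 1) keys by accumulated attention mass and the query itself.
def _demo_local_heavy_hitter_mask():
    scores = torch.randn(1, 2, 8, 8)  # toy (batch, heads, query, keys) scores
    keep = local_heavy_hitter_mask(scores, heavy_budget=3)
    assert keep.shape == scores.shape and keep.dtype == torch.bool
    return keep
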
class LlamaAttention_heavy_hitter(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )
        # Config-based constructor assumes a recent transformers release; older
        # releases expect LlamaRotaryEmbedding(head_dim, max_position_embeddings).
        self.rotary_emb = LlamaRotaryEmbedding(self.config)

        # Fractions of the key sequence kept as heavy-hitter and recent tokens.
        self.heavy_budget_ratio = config.heavy_ratio
        self.recent_budget_ratio = config.recent_ratio

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        # NOTE: the `seq_len` keyword matches the older LlamaRotaryEmbedding API;
        # newer transformers releases expect self.rotary_emb(value_states, position_ids).
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        ### Heavy + Recent
        # Budgets are fractions of the key length, e.g. heavy_ratio = 0.1 on a
        # 2048-token context keeps int(0.1 * 2048) = 204 heavy-hitter keys.
        heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
        recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])

        # # Heavy Hitter Mask (Based on local statistics)
        # if heavy_budget > 0:
        #     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget)  # Default: No padding applied to input
        # else:
        #     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)

        # ones = torch.ones_like(attn_weights, dtype=torch.bool)
        # ones = torch.triu(ones, diagonal=-recent_budget)
        # mask_bottom = torch.logical_or(mask_bottom, ones)

        # mask_bottom = torch.tril(mask_bottom, diagonal=0)

        # # mask_bottom = ones
        # attn_weights[~mask_bottom] = torch.min(attention_mask)

        # Heavy Hitter Mask (Based on global statistics)
        tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            attn_weights.dtype
        )
        tmp_sum = torch.sum(tmp_attn, dim=-2)
        _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)

        zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
        mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
        mask_bottom = mask_bottom.expand(
            mask_bottom.shape[0],
            mask_bottom.shape[1],
            attn_weights.shape[-2],
            mask_bottom.shape[-1],
        )

        # Recent window: keep a band of +/- recent_budget keys around each query
        # position (future positions are already masked out by attention_mask).
        ones = torch.ones_like(attn_weights, dtype=torch.bool)
        ones = torch.tril(ones, diagonal=recent_budget)
        ones = torch.triu(ones, diagonal=-recent_budget)
        mask_bottom = torch.logical_or(mask_bottom, ones)
        # mask_bottom = ones
        attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

def convert_kvcache_llama_heavy_recent(model, config):
    # Recursively swap every LlamaAttention module for LlamaAttention_heavy_hitter.
    # The replacement modules are freshly initialized; no weights are copied here,
    # so the caller is expected to reload the pretrained state dict afterwards.
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)

        if isinstance(module, LlamaAttention):
            model._modules[name] = LlamaAttention_heavy_hitter(config)

    return model

ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
}

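
# A hedged usage sketch (added for illustration; the helper name and its
# parameters are illustrative, not part of the original module). It shows how
# the converter registered above might be applied: cache the pretrained weights,
# swap in the heavy-hitter attention modules, then restore the weights, since
# the swapped-in modules start from random initialization. Depending on the
# transformers version, `load_state_dict` may need `strict=False`.
def _demo_enable_heavy_hitter(model, model_name_or_path, heavy_ratio=0.1, recent_ratio=0.1):
    config = AutoConfig.from_pretrained(model_name_or_path)
    config.heavy_ratio = heavy_ratio
    config.recent_ratio = recent_ratio
    checkpoint = copy.deepcopy(model.state_dict())  # keep the pretrained weights
    model = ENABLE_Heavy_Hitter_FUNCTIONS["llama"](model, config)
    model.load_state_dict(checkpoint, strict=False)  # restore weights into the new modules
    return model.eval()
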
def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch

    if "torch_dtype" not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            "torch.float16": torch.float16,
            "torch.bfloat16": torch.bfloat16,
            "torch.float": torch.float,
            "auto": "auto",
            "None": None,
        }.get(model_kwargs["torch_dtype"])
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype
    return model_kwargs

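
# Illustrative example (not in the original file): an absent "torch_dtype" key
# defaults to torch.float16, and string names are mapped to real dtypes, e.g.
#     _set_model_kwargs_torch_dtype({"torch_dtype": "torch.bfloat16"})
#     -> {"torch_dtype": torch.bfloat16}
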
class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
    def __init__(
        self,
        path: str,
        model_kwargs: dict = dict(),
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: dict = dict(),
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
        tokenizer_only: bool = False,
        generation_kwargs: dict = dict(),
        max_seq_len: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        stop_words: Optional[List[str]] = [],
        drop_middle: bool = False,
        **other_kwargs,
    ):
        # H2O-specific KV-cache budgets must be supplied by the caller.
        self.heavy_ratio = other_kwargs["heavy_ratio"]
        self.recent_ratio = other_kwargs["recent_ratio"]
        super().__init__(
            path=path,
            model_kwargs=model_kwargs,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs,
            peft_path=peft_path,
            peft_kwargs=peft_kwargs,
            tokenizer_only=tokenizer_only,
            generation_kwargs=generation_kwargs,
            max_seq_len=max_seq_len,
            pad_token_id=pad_token_id,
            stop_words=stop_words,
            drop_middle=drop_middle,
        )

    def _load_model(
        self,
        path: str,
        kwargs: dict,
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
    ):
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f"using model_kwargs: {model_kwargs}")
        if is_npu_available():
            model_kwargs["device_map"] = "npu"

        try:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)

        if peft_path is not None:
            from peft import PeftModel

            peft_kwargs["is_trainable"] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        config = AutoConfig.from_pretrained(path)
        config.heavy_ratio = self.heavy_ratio
        self.model.eval()
        self.model.generation_config.do_sample = False