# OpenCompass/opencompass/configs/models/hf_llama/modify_llama.py
import os
import pdb
import copy
import math
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from opencompass.models import HuggingFaceBaseModel
from opencompass.configs.datasets.infinitebench.infinitebench import (
infinitebench_datasets,
)
from transformers import AutoConfig
from mmengine.device import is_npu_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
LlamaAttention,
apply_rotary_pos_emb,
)
__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]


def local_heavy_hitter_mask(attn_weights, heavy_budget):
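    """Greedily build a boolean mask over `attn_weights` (bs, heads, queries, keys)
    that keeps, for each query row, the `heavy_budget` keys with the highest
    accumulated attention score plus the query's own key on the causal diagonal.
    """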
# attn_weights (BS, head, query, keys)
dtype_attn_weights = attn_weights.dtype
seq_length = attn_weights.shape[-1]
padding_length = 0
offset = torch.finfo(attn_weights.dtype).min
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
dtype_attn_weights
)
    accumulated_attention_score = torch.sum(
        tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
    )  # (bs, heads, keys)
accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
accumulated_attention_score[:, :, :padding_length] = 0
mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
mask_bottom[
:,
:,
padding_length : heavy_budget + padding_length,
padding_length : heavy_budget + padding_length,
] = True
for token_index in range(heavy_budget + padding_length, seq_length):
tmp_attn_index = nn.functional.softmax(
attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
).to(dtype_attn_weights)
_, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
        mask_bottom_index = zeros_index.scatter(
            -1, tmp_topk_index, True
        )  # (bs, heads, keys)
mask_bottom_index[:, :, token_index] = True
mask_bottom[:, :, token_index, :] = mask_bottom_index
accumulated_attention_score += tmp_attn_index
accumulated_attention_score = accumulated_attention_score * mask_bottom_index
    return mask_bottom


class LlamaAttention_heavy_hitter(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
        # NOTE: this module mirrors the legacy LlamaAttention interface
        # (tuple KV cache, additive mask, `seq_len`-style rotary call below),
        # so the rotary embedding is built per head dimension.
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim, max_position_embeddings=self.max_position_embeddings
        )
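        # Fractions of the key length to keep as heavy-hitter and recent tokens.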
self.heavy_budget_ratio = config.heavy_ratio
self.recent_budget_ratio = config.recent_ratio
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
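        """Standard LLaMA attention, except that attention weights outside the
        heavy-hitter + recent-token mask are pushed to the dtype minimum before
        the softmax (H2O-style pruning of the attention matrix)."""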
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
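        # query/key/value states: (bsz, num_heads, q_len, head_dim)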
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(
                    torch.finfo(attn_weights.dtype).min, device=attn_weights.device
                ),
            )
### Heavy + Recent
heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])
# # Heavy Hitter Mask (Based on local statistics)
# if heavy_budget > 0:
# mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
# else:
# mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
# ones = torch.ones_like(attn_weights, dtype=torch.bool)
# ones = torch.triu(ones, diagonal=-recent_budget)
# mask_bottom = torch.logical_or(mask_bottom, ones)
# mask_bottom = torch.tril(mask_bottom, diagonal=0)
# # mask_bottom = ones
# attn_weights[~mask_bottom] = torch.min(attention_mask)
# Heavy Hitter Mask (Based on global statistics)
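        # Score every key by its column sum of softmax attention, keep the
        # `heavy_budget` best-scoring keys for all query rows, then OR in a band
        # of `recent_budget` positions around the diagonal.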
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
attn_weights.dtype
)
tmp_sum = torch.sum(tmp_attn, dim=-2)
_, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)
zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
mask_bottom = mask_bottom.expand(
mask_bottom.shape[0],
mask_bottom.shape[1],
attn_weights.shape[-2],
mask_bottom.shape[-1],
)
ones = torch.ones_like(attn_weights, dtype=torch.bool)
ones = torch.tril(ones, diagonal=recent_budget)
ones = torch.triu(ones, diagonal=-recent_budget)
mask_bottom = torch.logical_or(mask_bottom, ones)
# mask_bottom = ones
attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
        return attn_output, attn_weights, past_key_value


def convert_kvcache_llama_heavy_recent(model, config):
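    """Recursively swap every `LlamaAttention` module in `model` for a freshly
    initialised `LlamaAttention_heavy_hitter` built from `config`.

    The replacement layers carry random weights, so the original attention
    weights have to be restored afterwards (e.g. with `load_state_dict`).
    """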
for name, module in reversed(model._modules.items()):
if len(list(module.children())) > 0:
model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)
if isinstance(module, LlamaAttention):
model._modules[name] = LlamaAttention_heavy_hitter(config)
    return model


ENABLE_Heavy_Hitter_FUNCTIONS = {
"llama": convert_kvcache_llama_heavy_recent,
}
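# Example usage (hedged sketch; the 0.1 ratios and the `path`/`model` names are
# placeholders, not values used by this file):
#   config = AutoConfig.from_pretrained(path)
#   config.heavy_ratio, config.recent_ratio = 0.1, 0.1
#   state_dict = model.state_dict()
#   model = ENABLE_Heavy_Hitter_FUNCTIONS["llama"](model, config)
#   model.load_state_dict(state_dict, strict=False)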
def _set_model_kwargs_torch_dtype(model_kwargs):
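    """Map a string-valued `torch_dtype` entry in `model_kwargs` to the real
    torch dtype, defaulting to torch.float16 when none is given."""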
import torch
if "torch_dtype" not in model_kwargs:
torch_dtype = torch.float16
else:
torch_dtype = {
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.float": torch.float,
"auto": "auto",
"None": None,
}.get(model_kwargs["torch_dtype"])
if torch_dtype is not None:
model_kwargs["torch_dtype"] = torch_dtype
    return model_kwargs


class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
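    """`HuggingFaceBaseModel` variant that records the H2O `heavy_ratio` and
    `recent_ratio` budgets so they can be attached to the model config while
    the checkpoint is loaded."""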
def __init__(
self,
path: str,
model_kwargs: dict = dict(),
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
peft_path: Optional[str] = None,
peft_kwargs: dict = dict(),
tokenizer_only: bool = False,
generation_kwargs: dict = dict(),
max_seq_len: Optional[int] = None,
pad_token_id: Optional[int] = None,
        stop_words: Optional[List[str]] = [],
drop_middle: bool = False,
**other_kwargs,
):
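        # H2O budgets; both keys must be supplied by the caller.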
self.heavy_ratio = other_kwargs["heavy_ratio"]
self.recent_ratio = other_kwargs["recent_ratio"]
super().__init__(
path=path,
model_kwargs=model_kwargs,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs,
peft_path=peft_path,
peft_kwargs=peft_kwargs,
tokenizer_only=tokenizer_only,
generation_kwargs=generation_kwargs,
max_seq_len=max_seq_len,
pad_token_id=pad_token_id,
stop_words=stop_words,
drop_middle=drop_middle
        )

    def _load_model(
self,
path: str,
kwargs: dict,
peft_path: Optional[str] = None,
peft_kwargs: dict = dict(),
):
from transformers import AutoModel, AutoModelForCausalLM
DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
model_kwargs = DEFAULT_MODEL_KWARGS
model_kwargs.update(kwargs)
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
self.logger.debug(f"using model_kwargs: {model_kwargs}")
if is_npu_available():
model_kwargs["device_map"] = "npu"
try:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
except ValueError:
self.model = AutoModel.from_pretrained(path, **model_kwargs)
if peft_path is not None:
from peft import PeftModel
peft_kwargs["is_trainable"] = False
self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)
config = AutoConfig.from_pretrained(path)
config.heavy_ratio = self.heavy_ratio
self.model.eval()
self.model.generation_config.do_sample = False