mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

add h2o@infinitebench implementation

This commit is contained in:
parent bdb2d46f59
commit 61afff8836
@@ -258,7 +258,7 @@ def main():
     cfg.dump(output_config_path)
     # Config is intentally reloaded here to avoid initialized
     # types cannot be serialized
-    cfg = Config.fromfile(output_config_path, format_python_code=False)
+    # cfg = Config.fromfile(output_config_path, format_python_code=False)
 
     # report to lark bot if specify --lark
     if not args.lark:
352  opencompass/configs/models/hf_llama/modify_llama.py  Normal file
@@ -0,0 +1,352 @@
import os
import pdb
import copy
import math
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from typing import Optional

from opencompass.models import HuggingFaceBaseModel
from opencompass.configs.datasets.infinitebench.infinitebench import (
    infinitebench_datasets,
)
from transformers import AutoConfig

from mmengine.device import is_npu_available


from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    LlamaAttention,
    apply_rotary_pos_emb,
)


__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]


def local_heavy_hitter_mask(attn_weights, heavy_budget):
    # attn_weights (BS, head, query, keys)
    dtype_attn_weights = attn_weights.dtype
    seq_length = attn_weights.shape[-1]
    padding_length = 0

    offset = torch.finfo(attn_weights.dtype).min
    tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        dtype_attn_weights
    )

    accumulated_attention_score = torch.sum(
        tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
    )  # (head, keys)
    accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
    accumulated_attention_score[:, :, :padding_length] = 0

    mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
    mask_bottom[
        :,
        :,
        padding_length : heavy_budget + padding_length,
        padding_length : heavy_budget + padding_length,
    ] = True

    for token_index in range(heavy_budget + padding_length, seq_length):
        tmp_attn_index = nn.functional.softmax(
            attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
        ).to(dtype_attn_weights)
        _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
        zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
        mask_bottom_index = zeros_index.scatter(
            -1, tmp_topk_index, True
        )  # (head, keys)
        mask_bottom_index[:, :, token_index] = True

        mask_bottom[:, :, token_index, :] = mask_bottom_index
        accumulated_attention_score += tmp_attn_index
        accumulated_attention_score = accumulated_attention_score * mask_bottom_index

    return mask_bottom


class LlamaAttention_heavy_hitter(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            self.config
        )

        self.heavy_budget_ratio = config.heavy_ratio
        self.recent_budget_ratio = config.recent_ratio

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        ### Heavy + Recent
        heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
        recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])

        # # Heavy Hitter Mask (Based on local statistics)
        # if heavy_budget > 0:
        #     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget)  # Default: No padding applied to input
        # else:
        #     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)

        # ones = torch.ones_like(attn_weights, dtype=torch.bool)
        # ones = torch.triu(ones, diagonal=-recent_budget)
        # mask_bottom = torch.logical_or(mask_bottom, ones)

        # mask_bottom = torch.tril(mask_bottom, diagonal=0)

        # # mask_bottom = ones
        # attn_weights[~mask_bottom] = torch.min(attention_mask)

        # Heavy Hitter Mask (Based on global statistics)
        tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            attn_weights.dtype
        )
        tmp_sum = torch.sum(tmp_attn, dim=-2)
        _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)

        zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
        mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
        mask_bottom = mask_bottom.expand(
            mask_bottom.shape[0],
            mask_bottom.shape[1],
            attn_weights.shape[-2],
            mask_bottom.shape[-1],
        )

        ones = torch.ones_like(attn_weights, dtype=torch.bool)
        ones = torch.tril(ones, diagonal=recent_budget)
        ones = torch.triu(ones, diagonal=-recent_budget)
        mask_bottom = torch.logical_or(mask_bottom, ones)
        # mask_bottom = ones
        attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def convert_kvcache_llama_heavy_recent(model, config):
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)

        if isinstance(module, LlamaAttention):
            model._modules[name] = LlamaAttention_heavy_hitter(config)

    return model


ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
}


def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch

    if "torch_dtype" not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            "torch.float16": torch.float16,
            "torch.bfloat16": torch.bfloat16,
            "torch.float": torch.float,
            "auto": "auto",
            "None": None,
        }.get(model_kwargs["torch_dtype"])
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype
    return model_kwargs


class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
    def __init__(
        self,
        path: str,
        model_kwargs: dict = dict(),
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: dict = dict(),
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
        tokenizer_only: bool = False,
        generation_kwargs: dict = dict(),
        max_seq_len: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        stop_words: Optional[str] = [],
        drop_middle: bool = False,
        **other_kwargs,
    ):
        self.heavy_ratio = other_kwargs["heavy_ratio"]
        self.recent_ratio = other_kwargs["recent_ratio"]
        super().__init__(
            path=path,
            model_kwargs=model_kwargs,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs,
            peft_path=peft_path,
            peft_kwargs=peft_kwargs,
            tokenizer_only=tokenizer_only,
            generation_kwargs=generation_kwargs,
            max_seq_len=max_seq_len,
            pad_token_id=pad_token_id,
            stop_words=stop_words,
            drop_middle=drop_middle,
        )

    def _load_model(
        self,
        path: str,
        kwargs: dict,
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
    ):
        # self.logger("load modified model")
        # exit(0)
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f"using model_kwargs: {model_kwargs}")
        if is_npu_available():
            model_kwargs["device_map"] = "npu"

        try:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)

        if peft_path is not None:
            from peft import PeftModel

            peft_kwargs["is_trainable"] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        config = AutoConfig.from_pretrained(path)
        config.heavy_ratio = self.heavy_ratio
        self.model.eval()
        self.model.generation_config.do_sample = False
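Editor's note: the file above implements the H2O (heavy-hitter) idea as an attention-time mask: per head, the keys with the largest accumulated softmax attention (a `heavy_ratio` fraction) plus a `recent_ratio` band around the diagonal are kept, and everything else is set to the dtype minimum before the final softmax. The visible part of `_load_model` stops right after `config.heavy_ratio` is set, so the snippet below is only a hedged sketch of how the conversion helpers above would typically be wired in, following the usual H2O recipe; the checkpoint copy and the `recent_ratio` assignment are assumptions, not part of the shown diff.

# Hedged sketch (not part of this commit): enabling the masked attention on a
# loaded checkpoint via the helpers defined above.
import copy
from transformers import AutoConfig, AutoModelForCausalLM

path = "meta-llama/Meta-Llama-3.1-8B-Instruct"   # example checkpoint
config = AutoConfig.from_pretrained(path)
config.heavy_ratio = 0.5    # fraction of keys kept as heavy hitters
config.recent_ratio = 0.1   # fraction of keys kept from the recent window

model = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto")
weights = copy.deepcopy(model.state_dict())                      # keep original weights
model = ENABLE_Heavy_Hitter_FUNCTIONS["llama"](model, config)    # swap in LlamaAttention_heavy_hitter
model.load_state_dict(weights)                                   # reload weights into the new modules
model.eval()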
121  opencompass/configs/models/hf_llama/sparse_ben.py  Normal file
@@ -0,0 +1,121 @@
from opencompass.configs.models.hf_llama.modify_llama import H2OLLAMABenchmarkRunner
from opencompass.datasets import InfiniteBenchcodedebugDataset, InfiniteBenchcoderunDataset, InfiniteBenchendiaDataset, InfiniteBenchenmcDataset, InfiniteBenchenqaDataset, InfiniteBenchensumDataset, InfiniteBenchmathcalcDataset, InfiniteBenchmathfindDataset, InfiniteBenchretrievekvDataset, InfiniteBenchretrievenumberDataset, InfiniteBenchretrievepasskeyDataset, InfiniteBenchzhqaDataset
from opencompass.configs.datasets.infinitebench.infinitebenchcodedebug.infinitebench_codedebug_gen_276a42 import InfiniteBench_codedebug_reader_cfg, InfiniteBench_codedebug_infer_cfg, InfiniteBench_codedebug_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchcoderun.infinitebench_coderun_gen_1a76bd import InfiniteBench_coderun_reader_cfg, InfiniteBench_coderun_infer_cfg, InfiniteBench_coderun_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchendia.infinitebench_endia_gen_c96eb5 import InfiniteBench_endia_reader_cfg, InfiniteBench_endia_infer_cfg, InfiniteBench_endia_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchenmc.infinitebench_enmc_gen_3a4102 import InfiniteBench_enmc_reader_cfg, InfiniteBench_enmc_infer_cfg, InfiniteBench_enmc_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchenqa.infinitebench_enqa_gen_a1640c import InfiniteBench_enqa_reader_cfg, InfiniteBench_enqa_infer_cfg, InfiniteBench_enqa_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchensum.infinitebench_ensum_gen_cfbc08 import InfiniteBench_ensum_reader_cfg, InfiniteBench_ensum_infer_cfg, InfiniteBench_ensum_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchmathcalc.infinitebench_mathcalc_gen_78d17e import InfiniteBench_mathcalc_reader_cfg, InfiniteBench_mathcalc_infer_cfg, InfiniteBench_mathcalc_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchmathfind.infinitebench_mathfind_gen_6d799e import InfiniteBench_mathfind_reader_cfg, InfiniteBench_mathfind_infer_cfg, InfiniteBench_mathfind_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievekv.infinitebench_retrievekv_gen_06b3ac import InfiniteBench_retrievekv_reader_cfg, InfiniteBench_retrievekv_infer_cfg, InfiniteBench_retrievekv_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievepasskey.infinitebench_retrievepasskey_gen_62ff68 import InfiniteBench_retrievepasskey_reader_cfg, InfiniteBench_retrievepasskey_infer_cfg, InfiniteBench_retrievepasskey_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievenumber.infinitebench_retrievenumber_gen_047436 import InfiniteBench_retrievenumber_reader_cfg, InfiniteBench_retrievenumber_infer_cfg, InfiniteBench_retrievenumber_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchzhqa.infinitebench_zhqa_gen_1e5293 import InfiniteBench_zhqa_reader_cfg, InfiniteBench_zhqa_infer_cfg, InfiniteBench_zhqa_eval_cfg


models = [
    dict(
        type=H2OLLAMABenchmarkRunner,
        # for represent in result
        abbr="sparse-llama-7b",
        # for huggingface
        path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        max_out_len=1024,
        batch_size=1,
        run_cfg=dict(
            num_gpus=1,
            heavy_ratio=0.5,
        ),
        # modify here
        heavy_ratio=0.5,
        recent_ratio=0.1,
    )
]
datasets = [
    dict(
        type=InfiniteBenchcodedebugDataset,
        abbr='InfiniteBench_codedebug',
        path='./data/InfiniteBench/code_debug.jsonl',
        reader_cfg=InfiniteBench_codedebug_reader_cfg,
        infer_cfg=InfiniteBench_codedebug_infer_cfg,
        eval_cfg=InfiniteBench_codedebug_eval_cfg),
    dict(
        type=InfiniteBenchcoderunDataset,
        abbr='InfiniteBench_coderun',
        path='./data/InfiniteBench/code_run.jsonl',
        reader_cfg=InfiniteBench_coderun_reader_cfg,
        infer_cfg=InfiniteBench_coderun_infer_cfg,
        eval_cfg=InfiniteBench_coderun_eval_cfg),
    dict(
        type=InfiniteBenchendiaDataset,
        abbr='InfiniteBench_endia',
        path='./data/InfiniteBench/longdialogue_qa_eng.jsonl',
        reader_cfg=InfiniteBench_endia_reader_cfg,
        infer_cfg=InfiniteBench_endia_infer_cfg,
        eval_cfg=InfiniteBench_endia_eval_cfg),
    dict(
        type=InfiniteBenchenmcDataset,
        abbr='InfiniteBench_enmc',
        path='./data/InfiniteBench/longbook_choice_eng.jsonl',
        reader_cfg=InfiniteBench_enmc_reader_cfg,
        infer_cfg=InfiniteBench_enmc_infer_cfg,
        eval_cfg=InfiniteBench_enmc_eval_cfg),
    dict(
        type=InfiniteBenchenqaDataset,
        abbr='InfiniteBench_enqa',
        path='./data/InfiniteBench/longbook_qa_eng.jsonl',
        reader_cfg=InfiniteBench_enqa_reader_cfg,
        infer_cfg=InfiniteBench_enqa_infer_cfg,
        eval_cfg=InfiniteBench_enqa_eval_cfg),
    dict(
        type=InfiniteBenchensumDataset,
        abbr='InfiniteBench_ensum',
        path='./data/InfiniteBench/longbook_sum_eng.jsonl',
        reader_cfg=InfiniteBench_ensum_reader_cfg,
        infer_cfg=InfiniteBench_ensum_infer_cfg,
        eval_cfg=InfiniteBench_ensum_eval_cfg),
    dict(
        type=InfiniteBenchmathcalcDataset,
        abbr='InfiniteBench_mathcalc',
        path='./data/InfiniteBench/math_calc.jsonl',
        reader_cfg=InfiniteBench_mathcalc_reader_cfg,
        infer_cfg=InfiniteBench_mathcalc_infer_cfg,
        eval_cfg=InfiniteBench_mathcalc_eval_cfg),
    dict(
        type=InfiniteBenchmathfindDataset,
        abbr='InfiniteBench_mathfind',
        path='./data/InfiniteBench/math_find.jsonl',
        reader_cfg=InfiniteBench_mathfind_reader_cfg,
        infer_cfg=InfiniteBench_mathfind_infer_cfg,
        eval_cfg=InfiniteBench_mathfind_eval_cfg),
    dict(
        type=InfiniteBenchretrievekvDataset,
        abbr='InfiniteBench_retrievekv',
        path='./data/InfiniteBench/kv_retrieval.jsonl',
        reader_cfg=InfiniteBench_retrievekv_reader_cfg,
        infer_cfg=InfiniteBench_retrievekv_infer_cfg,
        eval_cfg=InfiniteBench_retrievekv_eval_cfg),
    dict(
        type=InfiniteBenchretrievenumberDataset,
        abbr='InfiniteBench_retrievenumber',
        path='./data/InfiniteBench/number_string.jsonl',
        reader_cfg=InfiniteBench_retrievenumber_reader_cfg,
        infer_cfg=InfiniteBench_retrievenumber_infer_cfg,
        eval_cfg=InfiniteBench_retrievenumber_eval_cfg),
    dict(
        type=InfiniteBenchretrievepasskeyDataset,
        abbr='InfiniteBench_retrievepasskey',
        path='./data/InfiniteBench/passkey.jsonl',
        reader_cfg=InfiniteBench_retrievepasskey_reader_cfg,
        infer_cfg=InfiniteBench_retrievepasskey_infer_cfg,
        eval_cfg=InfiniteBench_retrievepasskey_eval_cfg),
    dict(
        type=InfiniteBenchzhqaDataset,
        abbr='InfiniteBench_zhqa',
        path='./data/InfiniteBench/longbook_qa_chn.jsonl',
        reader_cfg=InfiniteBench_zhqa_reader_cfg,
        infer_cfg=InfiniteBench_zhqa_infer_cfg,
        eval_cfg=InfiniteBench_zhqa_eval_cfg)
]
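Editor's note: this config pairs the H2O runner with every InfiniteBench subset, and the `heavy_ratio` / `recent_ratio` keys sit directly in the model dict. The sketch below is a hedged illustration of how those keys reach the runner and of a typical launch; the instantiation flow and the command line are assumptions about how the file would be used, not part of the commit.

# Hedged sketch: OpenCompass builds each `models` entry by calling its `type`
# with the remaining config keys, so heavy_ratio / recent_ratio arrive in
# **other_kwargs of H2OLLAMABenchmarkRunner.__init__ and are stashed on self
# before super().__init__ loads the model.
model_cfg = dict(
    type=H2OLLAMABenchmarkRunner,
    path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    heavy_ratio=0.5,
    recent_ratio=0.1,
)
builder = model_cfg.pop("type")
runner = builder(**model_cfg)  # reads heavy_ratio / recent_ratio from other_kwargs

# Assumed launch command (adjust the config path to your checkout):
#   python run.py opencompass/configs/models/hf_llama/sparse_ben.py --debug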
@@ -9,8 +9,10 @@ from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
+from opencompass.models import HuggingFaceCausalLM
 from opencompass.utils.prompt import PromptList
 
 
 PromptType = Union[PromptList, str]
+
 
@@ -506,7 +508,7 @@ def _convert_base_messages(inputs):
     return outputs
 
 
-class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
+class HuggingFaceBaseModel(HuggingFaceCausalLM):
 
     def __init__(self,
                  path: str,
@@ -529,7 +531,8 @@ class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
         self.template_parser = LMTemplateParser()
         self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
         self.drop_middle = drop_middle
-        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
+        self.pad_token_id = None
+        self._load_tokenizer(tokenizer_path or path, None, tokenizer_kwargs)
         if not tokenizer_only:
             self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
         self.generation_kwargs = generation_kwargs
@@ -92,13 +92,14 @@ class OpenICLInferTask(BaseTask):
         self.logger.info(
             f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
 
-        assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
+
+        assert 'ice_template' in self.infer_cfg or 'prompt_template' in self.infer_cfg, \
             'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501
-        if hasattr(self.infer_cfg, 'ice_template'):
+        if 'ice_template' in self.infer_cfg:
             ice_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['ice_template'])
 
-        if hasattr(self.infer_cfg, 'prompt_template'):
+        if 'prompt_template' in self.infer_cfg:
             prompt_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['prompt_template'])
 
@@ -123,15 +124,15 @@ class OpenICLInferTask(BaseTask):
         out_dir, out_file = osp.split(out_path)
         mkdir_or_exist(out_dir)
 
-        if hasattr(self.infer_cfg, 'prompt_template') and \
-                hasattr(self.infer_cfg, 'ice_template'):
+        if 'prompt_template' in self.infer_cfg and \
+                'ice_template' in self.infer_cfg:
             inferencer.inference(retriever,
                                  ice_template=ice_template,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
-        elif hasattr(self.infer_cfg, 'prompt_template'):
+        elif "prompt_template" in self.infer_cfg:
             inferencer.inference(retriever,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
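Editor's note: the hunks above replace `hasattr` checks on `self.infer_cfg` with key-membership checks, presumably because, with the run.py change at the top of this commit, the config can arrive as a plain dict rather than an attribute-style Config object; `in` behaves the same for both. A minimal illustration (the dict below is a made-up stand-in):

# Plain dicts expose no attributes for their keys, so hasattr() misses keys
# that are present; key membership works for dicts and for mmengine Config
# objects alike.
plain = {"prompt_template": dict(type="PromptTemplate")}
hasattr(plain, "prompt_template")   # False: keys are not attributes on a dict
"prompt_template" in plain          # True: membership test sees the key
"ice_template" in plain             # False: missing keys reported cleanly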