mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

add h2o@infinitebench implementation

This commit is contained in:
  parent bdb2d46f59
  commit 61afff8836
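This commit adds an H2O (heavy-hitter oracle) sparse-attention path for Llama models, an InfiniteBench configuration that exercises it, and small compatibility fixes to the HuggingFace wrapper and the inference task. The modified attention keeps only two groups of keys per query: the columns with the largest accumulated attention mass ("heavy hitters") and a local window of recent tokens; everything else is masked to the dtype minimum before the softmax. As a rough sketch of the budgets used in modify_llama.py below (a sketch only; in the diff the ratios live on the config as heavy_ratio / recent_ratio and on the module as heavy_budget_ratio / recent_budget_ratio):

    heavy_budget  = int(heavy_ratio * kv_seq_len)   # keys kept by accumulated attention mass
    recent_budget = int(recent_ratio * kv_seq_len)  # band of recent keys around the diagonal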
@@ -258,7 +258,7 @@ def main():
     cfg.dump(output_config_path)
     # Config is intentionally reloaded here to avoid initialized
     # types cannot be serialized
-    cfg = Config.fromfile(output_config_path, format_python_code=False)
+    # cfg = Config.fromfile(output_config_path, format_python_code=False)

     # report to lark bot if specify --lark
     if not args.lark:
352  opencompass/configs/models/hf_llama/modify_llama.py  Normal file
@@ -0,0 +1,352 @@
import os
import pdb
import copy
import math
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from typing import Optional

from opencompass.models import HuggingFaceBaseModel
from opencompass.configs.datasets.infinitebench.infinitebench import (
    infinitebench_datasets,
)
from transformers import AutoConfig

from mmengine.device import is_npu_available


from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    LlamaAttention,
    apply_rotary_pos_emb,
)


__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]


def local_heavy_hitter_mask(attn_weights, heavy_budget):
    # attn_weights (BS, head, query, keys)
    # Greedy "local statistics" variant: walk the sequence and, for each query position
    # past the heavy budget, keep the top-(heavy_budget - 1) keys by accumulated
    # attention mass plus the current token itself.
    dtype_attn_weights = attn_weights.dtype
    seq_length = attn_weights.shape[-1]
    padding_length = 0

    offset = torch.finfo(attn_weights.dtype).min
    tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        dtype_attn_weights
    )

    accumulated_attention_score = torch.sum(
        tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
    )  # (head, keys)
    accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
    accumulated_attention_score[:, :, :padding_length] = 0

    mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
    mask_bottom[
        :,
        :,
        padding_length : heavy_budget + padding_length,
        padding_length : heavy_budget + padding_length,
    ] = True

    for token_index in range(heavy_budget + padding_length, seq_length):
        tmp_attn_index = nn.functional.softmax(
            attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
        ).to(dtype_attn_weights)
        _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
        zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
        mask_bottom_index = zeros_index.scatter(
            -1, tmp_topk_index, True
        )  # (head, keys)
        mask_bottom_index[:, :, token_index] = True

        mask_bottom[:, :, token_index, :] = mask_bottom_index
        accumulated_attention_score += tmp_attn_index
        accumulated_attention_score = accumulated_attention_score * mask_bottom_index

    return mask_bottom


class LlamaAttention_heavy_hitter(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            self.config
        )

        self.heavy_budget_ratio = config.heavy_ratio
        self.recent_budget_ratio = config.recent_ratio

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        ### Heavy + Recent
        heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
        recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])

        # # Heavy Hitter Mask (Based on local statistics)
        # if heavy_budget > 0:
        #     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget)  # Default: No padding applied to input
        # else:
        #     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)

        # ones = torch.ones_like(attn_weights, dtype=torch.bool)
        # ones = torch.triu(ones, diagonal=-recent_budget)
        # mask_bottom = torch.logical_or(mask_bottom, ones)

        # mask_bottom = torch.tril(mask_bottom, diagonal=0)

        # # mask_bottom = ones
        # attn_weights[~mask_bottom] = torch.min(attention_mask)

        # Heavy Hitter Mask (Based on global statistics)
        tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            attn_weights.dtype
        )
        tmp_sum = torch.sum(tmp_attn, dim=-2)
        _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)

        zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
        mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
        mask_bottom = mask_bottom.expand(
            mask_bottom.shape[0],
            mask_bottom.shape[1],
            attn_weights.shape[-2],
            mask_bottom.shape[-1],
        )

        # Recent window: keep a band of width recent_budget around the diagonal
        ones = torch.ones_like(attn_weights, dtype=torch.bool)
        ones = torch.tril(ones, diagonal=recent_budget)
        ones = torch.triu(ones, diagonal=-recent_budget)
        mask_bottom = torch.logical_or(mask_bottom, ones)
        # mask_bottom = ones
        attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def convert_kvcache_llama_heavy_recent(model, config):
    # Recursively replace every LlamaAttention module with the heavy-hitter variant above.
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)

        if isinstance(module, LlamaAttention):
            model._modules[name] = LlamaAttention_heavy_hitter(config)

    return model


ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
}


def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch

    if "torch_dtype" not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            "torch.float16": torch.float16,
            "torch.bfloat16": torch.bfloat16,
            "torch.float": torch.float,
            "auto": "auto",
            "None": None,
        }.get(model_kwargs["torch_dtype"])
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype
    return model_kwargs


class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
    def __init__(
        self,
        path: str,
        model_kwargs: dict = dict(),
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: dict = dict(),
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
        tokenizer_only: bool = False,
        generation_kwargs: dict = dict(),
        max_seq_len: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        stop_words: Optional[str] = [],
        drop_middle: bool = False,
        **other_kwargs,
    ):
        self.heavy_ratio = other_kwargs["heavy_ratio"]
        self.recent_ratio = other_kwargs["recent_ratio"]
        super().__init__(
            path=path,
            model_kwargs=model_kwargs,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs,
            peft_path=peft_path,
            peft_kwargs=peft_kwargs,
            tokenizer_only=tokenizer_only,
            generation_kwargs=generation_kwargs,
            max_seq_len=max_seq_len,
            pad_token_id=pad_token_id,
            stop_words=stop_words,
            drop_middle=drop_middle,
        )

    def _load_model(
        self,
        path: str,
        kwargs: dict,
        peft_path: Optional[str] = None,
        peft_kwargs: dict = dict(),
    ):
        # self.logger("load modified model")
        # exit(0)
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f"using model_kwargs: {model_kwargs}")
        if is_npu_available():
            model_kwargs["device_map"] = "npu"

        try:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        except ValueError:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)

        if peft_path is not None:
            from peft import PeftModel

            peft_kwargs["is_trainable"] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        config = AutoConfig.from_pretrained(path)
        config.heavy_ratio = self.heavy_ratio
        self.model.eval()
        self.model.generation_config.do_sample = False
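The rendered hunk appears to cut off here, before the end of the 352-line file. For orientation only, a hedged sketch of how the pieces defined above are conventionally wired together in H2O-style setups; the recent_ratio assignment and the attention swap are assumptions, not lines visible in this view:

    # illustration only, not part of the rendered hunk
    config = AutoConfig.from_pretrained(path)
    config.heavy_ratio = self.heavy_ratio
    config.recent_ratio = self.recent_ratio  # assumed counterpart of heavy_ratio
    self.model = ENABLE_Heavy_Hitter_FUNCTIONS["llama"](self.model, config)  # swap LlamaAttention modules
    self.model.eval()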
121  opencompass/configs/models/hf_llama/sparse_ben.py  Normal file
@@ -0,0 +1,121 @@
from opencompass.configs.models.hf_llama.modify_llama import H2OLLAMABenchmarkRunner
from opencompass.datasets import (
    InfiniteBenchcodedebugDataset,
    InfiniteBenchcoderunDataset,
    InfiniteBenchendiaDataset,
    InfiniteBenchenmcDataset,
    InfiniteBenchenqaDataset,
    InfiniteBenchensumDataset,
    InfiniteBenchmathcalcDataset,
    InfiniteBenchmathfindDataset,
    InfiniteBenchretrievekvDataset,
    InfiniteBenchretrievenumberDataset,
    InfiniteBenchretrievepasskeyDataset,
    InfiniteBenchzhqaDataset,
)
from opencompass.configs.datasets.infinitebench.infinitebenchcodedebug.infinitebench_codedebug_gen_276a42 import InfiniteBench_codedebug_reader_cfg, InfiniteBench_codedebug_infer_cfg, InfiniteBench_codedebug_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchcoderun.infinitebench_coderun_gen_1a76bd import InfiniteBench_coderun_reader_cfg, InfiniteBench_coderun_infer_cfg, InfiniteBench_coderun_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchendia.infinitebench_endia_gen_c96eb5 import InfiniteBench_endia_reader_cfg, InfiniteBench_endia_infer_cfg, InfiniteBench_endia_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchenmc.infinitebench_enmc_gen_3a4102 import InfiniteBench_enmc_reader_cfg, InfiniteBench_enmc_infer_cfg, InfiniteBench_enmc_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchenqa.infinitebench_enqa_gen_a1640c import InfiniteBench_enqa_reader_cfg, InfiniteBench_enqa_infer_cfg, InfiniteBench_enqa_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchensum.infinitebench_ensum_gen_cfbc08 import InfiniteBench_ensum_reader_cfg, InfiniteBench_ensum_infer_cfg, InfiniteBench_ensum_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchmathcalc.infinitebench_mathcalc_gen_78d17e import InfiniteBench_mathcalc_reader_cfg, InfiniteBench_mathcalc_infer_cfg, InfiniteBench_mathcalc_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchmathfind.infinitebench_mathfind_gen_6d799e import InfiniteBench_mathfind_reader_cfg, InfiniteBench_mathfind_infer_cfg, InfiniteBench_mathfind_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievekv.infinitebench_retrievekv_gen_06b3ac import InfiniteBench_retrievekv_reader_cfg, InfiniteBench_retrievekv_infer_cfg, InfiniteBench_retrievekv_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievepasskey.infinitebench_retrievepasskey_gen_62ff68 import InfiniteBench_retrievepasskey_reader_cfg, InfiniteBench_retrievepasskey_infer_cfg, InfiniteBench_retrievepasskey_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchretrievenumber.infinitebench_retrievenumber_gen_047436 import InfiniteBench_retrievenumber_reader_cfg, InfiniteBench_retrievenumber_infer_cfg, InfiniteBench_retrievenumber_eval_cfg
from opencompass.configs.datasets.infinitebench.infinitebenchzhqa.infinitebench_zhqa_gen_1e5293 import InfiniteBench_zhqa_reader_cfg, InfiniteBench_zhqa_infer_cfg, InfiniteBench_zhqa_eval_cfg


models = [
    dict(
        type=H2OLLAMABenchmarkRunner,
        # abbreviation used to represent the model in results
        abbr="sparse-llama-7b",
        # HuggingFace model path
        path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        max_out_len=1024,
        batch_size=1,
        run_cfg=dict(
            num_gpus=1,
            heavy_ratio=0.5,
        ),
        # modify here
        heavy_ratio=0.5,
        recent_ratio=0.1,
    )
]
datasets = [

    dict(
        type=InfiniteBenchcodedebugDataset,
        abbr='InfiniteBench_codedebug',
        path='./data/InfiniteBench/code_debug.jsonl',
        reader_cfg=InfiniteBench_codedebug_reader_cfg,
        infer_cfg=InfiniteBench_codedebug_infer_cfg,
        eval_cfg=InfiniteBench_codedebug_eval_cfg),
    dict(
        type=InfiniteBenchcoderunDataset,
        abbr='InfiniteBench_coderun',
        path='./data/InfiniteBench/code_run.jsonl',
        reader_cfg=InfiniteBench_coderun_reader_cfg,
        infer_cfg=InfiniteBench_coderun_infer_cfg,
        eval_cfg=InfiniteBench_coderun_eval_cfg),
    dict(
        type=InfiniteBenchendiaDataset,
        abbr='InfiniteBench_endia',
        path='./data/InfiniteBench/longdialogue_qa_eng.jsonl',
        reader_cfg=InfiniteBench_endia_reader_cfg,
        infer_cfg=InfiniteBench_endia_infer_cfg,
        eval_cfg=InfiniteBench_endia_eval_cfg),
    dict(
        type=InfiniteBenchenmcDataset,
        abbr='InfiniteBench_enmc',
        path='./data/InfiniteBench/longbook_choice_eng.jsonl',
        reader_cfg=InfiniteBench_enmc_reader_cfg,
        infer_cfg=InfiniteBench_enmc_infer_cfg,
        eval_cfg=InfiniteBench_enmc_eval_cfg),
    dict(
        type=InfiniteBenchenqaDataset,
        abbr='InfiniteBench_enqa',
        path='./data/InfiniteBench/longbook_qa_eng.jsonl',
        reader_cfg=InfiniteBench_enqa_reader_cfg,
        infer_cfg=InfiniteBench_enqa_infer_cfg,
        eval_cfg=InfiniteBench_enqa_eval_cfg),
    dict(
        type=InfiniteBenchensumDataset,
        abbr='InfiniteBench_ensum',
        path='./data/InfiniteBench/longbook_sum_eng.jsonl',
        reader_cfg=InfiniteBench_ensum_reader_cfg,
        infer_cfg=InfiniteBench_ensum_infer_cfg,
        eval_cfg=InfiniteBench_ensum_eval_cfg),
    dict(
        type=InfiniteBenchmathcalcDataset,
        abbr='InfiniteBench_mathcalc',
        path='./data/InfiniteBench/math_calc.jsonl',
        reader_cfg=InfiniteBench_mathcalc_reader_cfg,
        infer_cfg=InfiniteBench_mathcalc_infer_cfg,
        eval_cfg=InfiniteBench_mathcalc_eval_cfg),
    dict(
        type=InfiniteBenchmathfindDataset,
        abbr='InfiniteBench_mathfind',
        path='./data/InfiniteBench/math_find.jsonl',
        reader_cfg=InfiniteBench_mathfind_reader_cfg,
        infer_cfg=InfiniteBench_mathfind_infer_cfg,
        eval_cfg=InfiniteBench_mathfind_eval_cfg),
    dict(
        type=InfiniteBenchretrievekvDataset,
        abbr='InfiniteBench_retrievekv',
        path='./data/InfiniteBench/kv_retrieval.jsonl',
        reader_cfg=InfiniteBench_retrievekv_reader_cfg,
        infer_cfg=InfiniteBench_retrievekv_infer_cfg,
        eval_cfg=InfiniteBench_retrievekv_eval_cfg),
    dict(
        type=InfiniteBenchretrievenumberDataset,
        abbr='InfiniteBench_retrievenumber',
        path='./data/InfiniteBench/number_string.jsonl',
        reader_cfg=InfiniteBench_retrievenumber_reader_cfg,
        infer_cfg=InfiniteBench_retrievenumber_infer_cfg,
        eval_cfg=InfiniteBench_retrievenumber_eval_cfg),
    dict(
        type=InfiniteBenchretrievepasskeyDataset,
        abbr='InfiniteBench_retrievepasskey',
        path='./data/InfiniteBench/passkey.jsonl',
        reader_cfg=InfiniteBench_retrievepasskey_reader_cfg,
        infer_cfg=InfiniteBench_retrievepasskey_infer_cfg,
        eval_cfg=InfiniteBench_retrievepasskey_eval_cfg),
    dict(
        type=InfiniteBenchzhqaDataset,
        abbr='InfiniteBench_zhqa',
        path='./data/InfiniteBench/longbook_qa_chn.jsonl',
        reader_cfg=InfiniteBench_zhqa_reader_cfg,
        infer_cfg=InfiniteBench_zhqa_infer_cfg,
        eval_cfg=InfiniteBench_zhqa_eval_cfg)
]
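For reference, a config like this is normally launched through the standard OpenCompass entry point; the exact invocation is not part of the commit, so treat the following as a hedged example rather than documented usage:

    python run.py opencompass/configs/models/hf_llama/sparse_ben.py

Note that heavy_ratio appears both at the top level of the model dict (where, judging from H2OLLAMABenchmarkRunner.__init__ above, it reaches the model through **other_kwargs) and inside run_cfg; only the top-level values appear to be consumed by the runner.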
@@ -9,8 +9,10 @@ from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
+from opencompass.models import HuggingFaceCausalLM
 from opencompass.utils.prompt import PromptList


 PromptType = Union[PromptList, str]


@@ -506,7 +508,7 @@ def _convert_base_messages(inputs):
     return outputs


-class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
+class HuggingFaceBaseModel(HuggingFaceCausalLM):

     def __init__(self,
                  path: str,
@@ -529,7 +531,8 @@ class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
         self.template_parser = LMTemplateParser()
         self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
         self.drop_middle = drop_middle
-        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
+        self.pad_token_id = None
+        self._load_tokenizer(tokenizer_path or path, None, tokenizer_kwargs)
         if not tokenizer_only:
             self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
         self.generation_kwargs = generation_kwargs
@@ -92,13 +92,14 @@ class OpenICLInferTask(BaseTask):
         self.logger.info(
             f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')

-        assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
+        assert 'ice_template' in self.infer_cfg or 'prompt_template' in self.infer_cfg, \
             'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501
-        if hasattr(self.infer_cfg, 'ice_template'):
+        if 'ice_template' in self.infer_cfg:
             ice_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['ice_template'])

-        if hasattr(self.infer_cfg, 'prompt_template'):
+        if 'prompt_template' in self.infer_cfg:
             prompt_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['prompt_template'])

@@ -123,14 +124,14 @@ class OpenICLInferTask(BaseTask):
         out_dir, out_file = osp.split(out_path)
         mkdir_or_exist(out_dir)

-        if hasattr(self.infer_cfg, 'prompt_template') and \
-                hasattr(self.infer_cfg, 'ice_template'):
+        if 'prompt_template' in self.infer_cfg and \
+                'ice_template' in self.infer_cfg:
             inferencer.inference(retriever,
                                  ice_template=ice_template,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
-        elif hasattr(self.infer_cfg, 'prompt_template'):
+        elif "prompt_template" in self.infer_cfg:
             inferencer.inference(retriever,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,