add h2o@infinitebench implementation

ziyang zhang 2025-02-27 14:22:07 +08:00
parent bdb2d46f59
commit 61afff8836
5 changed files with 487 additions and 10 deletions
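This change adds an H2O-style (heavy-hitter + recent-window) sparse-attention patch for Llama and an InfiniteBench evaluation config that drives it through a new H2OLLAMABenchmarkRunner wrapper, plus small compatibility tweaks to run.py, the HuggingFace model wrapper, and OpenICLInferTask. A minimal sketch of how the new pieces fit together (illustrative only, not part of the diff; the checkpoint path and ratio values mirror the config further down):

    import copy
    from transformers import AutoConfig, AutoModelForCausalLM
    from opencompass.configs.models.hf_llama.modify_llama import (
        convert_kvcache_llama_heavy_recent,
    )

    path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    config = AutoConfig.from_pretrained(path)
    config.heavy_ratio = 0.5   # fraction of keys kept as accumulated-attention "heavy hitters"
    config.recent_ratio = 0.1  # fraction of keys kept as a local recency window
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")
    # Swap every LlamaAttention for the masked LlamaAttention_heavy_hitter, then restore
    # the pretrained weights (the replacement modules start with fresh parameters).
    checkpoint = copy.deepcopy(model.state_dict())
    model = convert_kvcache_llama_heavy_recent(model, config)
    model.load_state_dict(checkpoint)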

View File

@@ -258,7 +258,7 @@ def main():
     cfg.dump(output_config_path)
     # Config is intentally reloaded here to avoid initialized
     # types cannot be serialized
-    cfg = Config.fromfile(output_config_path, format_python_code=False)
+    # cfg = Config.fromfile(output_config_path, format_python_code=False)
     # report to lark bot if specify --lark
     if not args.lark:

View File

@@ -0,0 +1,352 @@
import os
import pdb
import copy
import math
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional
from opencompass.models import HuggingFaceBaseModel
from opencompass.configs.datasets.infinitebench.infinitebench import (
infinitebench_datasets,
)
from transformers import AutoConfig
from mmengine.device import is_npu_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
LlamaAttention,
apply_rotary_pos_emb,
)
__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]
def local_heavy_hitter_mask(attn_weights, heavy_budget):
# attn_weights (BS, head, query, keys)
dtype_attn_weights = attn_weights.dtype
seq_length = attn_weights.shape[-1]
padding_length = 0
offset = torch.finfo(attn_weights.dtype).min
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
dtype_attn_weights
)
accumulated_attention_score = torch.sum(
tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
) # (BS, head, keys)
accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
accumulated_attention_score[:, :, :padding_length] = 0
mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
mask_bottom[
:,
:,
padding_length : heavy_budget + padding_length,
padding_length : heavy_budget + padding_length,
] = True
for token_index in range(heavy_budget + padding_length, seq_length):
tmp_attn_index = nn.functional.softmax(
attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
).to(dtype_attn_weights)
_, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
mask_bottom_index = zeros_index.scatter(
-1, tmp_topk_index, True
) # (BS, head, keys)
mask_bottom_index[:, :, token_index] = True
mask_bottom[:, :, token_index, :] = mask_bottom_index
accumulated_attention_score += tmp_attn_index
accumulated_attention_score = accumulated_attention_score * mask_bottom_index
return mask_bottom
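# For orientation (illustrative only): given attention scores of shape
# (batch, heads, query, keys), the returned mask is boolean with the same shape;
# each query row keeps its own position plus up to heavy_budget - 1 keys with the
# largest softmax attention accumulated over the preceding queries, e.g.
#   scores = torch.randn(1, 8, 16, 16)
#   mask = local_heavy_hitter_mask(scores, heavy_budget=4)  # mask.shape == scores.shape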
class LlamaAttention_heavy_hitter(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rotary_emb = LlamaRotaryEmbedding(
self.config
)
self.heavy_budget_ratio = config.heavy_ratio
self.recent_budget_ratio = config.recent_ratio
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# LlamaRotaryEmbedding was constructed from the config above (the newer transformers
# API); that variant is queried with position_ids rather than the older seq_len= kwarg.
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
### Heavy + Recent
heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])
# # Heavy Hitter Mask (Based on local statistics)
# if heavy_budget > 0:
# mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
# else:
# mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
# ones = torch.ones_like(attn_weights, dtype=torch.bool)
# ones = torch.triu(ones, diagonal=-recent_budget)
# mask_bottom = torch.logical_or(mask_bottom, ones)
# mask_bottom = torch.tril(mask_bottom, diagonal=0)
# # mask_bottom = ones
# attn_weights[~mask_bottom] = torch.min(attention_mask)
# Heavy Hitter Mask (Based on global statistics)
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
attn_weights.dtype
)
tmp_sum = torch.sum(tmp_attn, dim=-2)
_, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)
zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
mask_bottom = mask_bottom.expand(
mask_bottom.shape[0],
mask_bottom.shape[1],
attn_weights.shape[-2],
mask_bottom.shape[-1],
)
ones = torch.ones_like(attn_weights, dtype=torch.bool)
ones = torch.tril(ones, diagonal=recent_budget)
ones = torch.triu(ones, diagonal=-recent_budget)
mask_bottom = torch.logical_or(mask_bottom, ones)
# mask_bottom = ones
attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def convert_kvcache_llama_heavy_recent(model, config):
for name, module in reversed(model._modules.items()):
if len(list(module.children())) > 0:
model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)
if isinstance(module, LlamaAttention):
model._modules[name] = LlamaAttention_heavy_hitter(config)
return model
ENABLE_Heavy_Hitter_FUNCTIONS = {
"llama": convert_kvcache_llama_heavy_recent,
}
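# Dispatch by model family, e.g. (illustrative):
#   model = ENABLE_Heavy_Hitter_FUNCTIONS["llama"](model, config)
# where config must carry heavy_ratio and recent_ratio for the new attention modules
# (see _load_model below for the full sequence, including the weight restore).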
def _set_model_kwargs_torch_dtype(model_kwargs):
import torch
if "torch_dtype" not in model_kwargs:
torch_dtype = torch.float16
else:
torch_dtype = {
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.float": torch.float,
"auto": "auto",
"None": None,
}.get(model_kwargs["torch_dtype"])
if torch_dtype is not None:
model_kwargs["torch_dtype"] = torch_dtype
return model_kwargs
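# e.g. _set_model_kwargs_torch_dtype({"torch_dtype": "torch.bfloat16"}) returns
# {"torch_dtype": torch.bfloat16}; when no torch_dtype is given, float16 is used.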
class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
def __init__(
self,
path: str,
model_kwargs: dict = dict(),
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
peft_path: Optional[str] = None,
peft_kwargs: dict = dict(),
tokenizer_only: bool = False,
generation_kwargs: dict = dict(),
max_seq_len: Optional[int] = None,
pad_token_id: Optional[int] = None,
stop_words: Optional[List[str]] = [],
drop_middle: bool = False,
**other_kwargs,
):
self.heavy_ratio = other_kwargs["heavy_ratio"]
self.recent_ratio = other_kwargs["recent_ratio"]
super().__init__(
path=path,
model_kwargs=model_kwargs,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs,
peft_path=peft_path,
peft_kwargs=peft_kwargs,
tokenizer_only=tokenizer_only,
generation_kwargs=generation_kwargs,
max_seq_len=max_seq_len,
pad_token_id=pad_token_id,
stop_words=stop_words,
drop_middle=drop_middle
)
def _load_model(
self,
path: str,
kwargs: dict,
peft_path: Optional[str] = None,
peft_kwargs: dict = dict(),
):
# self.logger("load modified model")
# exit(0)
from transformers import AutoModel, AutoModelForCausalLM
DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
model_kwargs = DEFAULT_MODEL_KWARGS
model_kwargs.update(kwargs)
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
self.logger.debug(f"using model_kwargs: {model_kwargs}")
if is_npu_available():
model_kwargs["device_map"] = "npu"
try:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
except ValueError:
self.model = AutoModel.from_pretrained(path, **model_kwargs)
if peft_path is not None:
from peft import PeftModel
peft_kwargs["is_trainable"] = False
self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)
config = AutoConfig.from_pretrained(path)
config.heavy_ratio = self.heavy_ratio
config.recent_ratio = self.recent_ratio
# Swap every LlamaAttention for LlamaAttention_heavy_hitter and restore the pretrained
# weights, since the replacement modules are created with freshly initialized parameters.
checkpoint = copy.deepcopy(self.model.state_dict())
self.model = convert_kvcache_llama_heavy_recent(self.model, config)
self.model.load_state_dict(checkpoint)
self.model.eval()
self.model.generation_config.do_sample = False

View File

@@ -0,0 +1,121 @@
from opencompass.configs.models.hf_llama.modify_llama import H2OLLAMABenchmarkRunner
from opencompass.datasets import (
    InfiniteBenchcodedebugDataset, InfiniteBenchcoderunDataset, InfiniteBenchendiaDataset,
    InfiniteBenchenmcDataset, InfiniteBenchenqaDataset, InfiniteBenchensumDataset,
    InfiniteBenchmathcalcDataset, InfiniteBenchmathfindDataset, InfiniteBenchretrievekvDataset,
    InfiniteBenchretrievenumberDataset, InfiniteBenchretrievepasskeyDataset, InfiniteBenchzhqaDataset,
)
from opencompass.configs.datasets.infinitebench.infinitebenchcodedebug.infinitebench_codedebug_gen_276a42 import (
    InfiniteBench_codedebug_reader_cfg, InfiniteBench_codedebug_infer_cfg, InfiniteBench_codedebug_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchcoderun.infinitebench_coderun_gen_1a76bd import (
    InfiniteBench_coderun_reader_cfg, InfiniteBench_coderun_infer_cfg, InfiniteBench_coderun_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchendia.infinitebench_endia_gen_c96eb5 import (
    InfiniteBench_endia_reader_cfg, InfiniteBench_endia_infer_cfg, InfiniteBench_endia_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchenmc.infinitebench_enmc_gen_3a4102 import (
    InfiniteBench_enmc_reader_cfg, InfiniteBench_enmc_infer_cfg, InfiniteBench_enmc_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchenqa.infinitebench_enqa_gen_a1640c import (
    InfiniteBench_enqa_reader_cfg, InfiniteBench_enqa_infer_cfg, InfiniteBench_enqa_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchensum.infinitebench_ensum_gen_cfbc08 import (
    InfiniteBench_ensum_reader_cfg, InfiniteBench_ensum_infer_cfg, InfiniteBench_ensum_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchmathcalc.infinitebench_mathcalc_gen_78d17e import (
    InfiniteBench_mathcalc_reader_cfg, InfiniteBench_mathcalc_infer_cfg, InfiniteBench_mathcalc_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchmathfind.infinitebench_mathfind_gen_6d799e import (
    InfiniteBench_mathfind_reader_cfg, InfiniteBench_mathfind_infer_cfg, InfiniteBench_mathfind_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchretrievekv.infinitebench_retrievekv_gen_06b3ac import (
    InfiniteBench_retrievekv_reader_cfg, InfiniteBench_retrievekv_infer_cfg, InfiniteBench_retrievekv_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchretrievepasskey.infinitebench_retrievepasskey_gen_62ff68 import (
    InfiniteBench_retrievepasskey_reader_cfg, InfiniteBench_retrievepasskey_infer_cfg, InfiniteBench_retrievepasskey_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchretrievenumber.infinitebench_retrievenumber_gen_047436 import (
    InfiniteBench_retrievenumber_reader_cfg, InfiniteBench_retrievenumber_infer_cfg, InfiniteBench_retrievenumber_eval_cfg)
from opencompass.configs.datasets.infinitebench.infinitebenchzhqa.infinitebench_zhqa_gen_1e5293 import (
    InfiniteBench_zhqa_reader_cfg, InfiniteBench_zhqa_infer_cfg, InfiniteBench_zhqa_eval_cfg)
models = [
dict(
type=H2OLLAMABenchmarkRunner,
# abbreviation used to identify this model in the results
abbr="sparse-llama-7b",
# HuggingFace model path
path="meta-llama/Meta-Llama-3.1-8B-Instruct",
max_out_len=1024,
batch_size=1,
run_cfg=dict(
num_gpus=1,
heavy_ratio=0.5,
),
# H2O sparsity settings, consumed by H2OLLAMABenchmarkRunner via **other_kwargs
heavy_ratio=0.5,
recent_ratio=0.1,
)
]
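# Note: heavy_ratio / recent_ratio at this level are collected into
# H2OLLAMABenchmarkRunner's **other_kwargs and read in its __init__; the duplicate
# heavy_ratio inside run_cfg above is not read by the runner. The config can then be
# launched with the usual OpenCompass entry point, e.g. `python run.py <this config>`.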
datasets = [
dict(
type=InfiniteBenchcodedebugDataset,
abbr='InfiniteBench_codedebug',
path='./data/InfiniteBench/code_debug.jsonl',
reader_cfg=InfiniteBench_codedebug_reader_cfg,
infer_cfg=InfiniteBench_codedebug_infer_cfg,
eval_cfg=InfiniteBench_codedebug_eval_cfg),
dict(
type=InfiniteBenchcoderunDataset,
abbr='InfiniteBench_coderun',
path='./data/InfiniteBench/code_run.jsonl',
reader_cfg=InfiniteBench_coderun_reader_cfg,
infer_cfg=InfiniteBench_coderun_infer_cfg,
eval_cfg=InfiniteBench_coderun_eval_cfg),
dict(
type=InfiniteBenchendiaDataset,
abbr='InfiniteBench_endia',
path='./data/InfiniteBench/longdialogue_qa_eng.jsonl',
reader_cfg=InfiniteBench_endia_reader_cfg,
infer_cfg=InfiniteBench_endia_infer_cfg,
eval_cfg=InfiniteBench_endia_eval_cfg),
dict(
type=InfiniteBenchenmcDataset,
abbr='InfiniteBench_enmc',
path='./data/InfiniteBench/longbook_choice_eng.jsonl',
reader_cfg=InfiniteBench_enmc_reader_cfg,
infer_cfg=InfiniteBench_enmc_infer_cfg,
eval_cfg=InfiniteBench_enmc_eval_cfg),
dict(
type=InfiniteBenchenqaDataset,
abbr='InfiniteBench_enqa',
path='./data/InfiniteBench/longbook_qa_eng.jsonl',
reader_cfg=InfiniteBench_enqa_reader_cfg,
infer_cfg=InfiniteBench_enqa_infer_cfg,
eval_cfg=InfiniteBench_enqa_eval_cfg),
dict(
type=InfiniteBenchensumDataset,
abbr='InfiniteBench_ensum',
path='./data/InfiniteBench/longbook_sum_eng.jsonl',
reader_cfg=InfiniteBench_ensum_reader_cfg,
infer_cfg=InfiniteBench_ensum_infer_cfg,
eval_cfg=InfiniteBench_ensum_eval_cfg),
dict(
type=InfiniteBenchmathcalcDataset,
abbr='InfiniteBench_mathcalc',
path='./data/InfiniteBench/math_calc.jsonl',
reader_cfg=InfiniteBench_mathcalc_reader_cfg,
infer_cfg=InfiniteBench_mathcalc_infer_cfg,
eval_cfg=InfiniteBench_mathcalc_eval_cfg),
dict(
type=InfiniteBenchmathfindDataset,
abbr='InfiniteBench_mathfind',
path='./data/InfiniteBench/math_find.jsonl',
reader_cfg=InfiniteBench_mathfind_reader_cfg,
infer_cfg=InfiniteBench_mathfind_infer_cfg,
eval_cfg=InfiniteBench_mathfind_eval_cfg),
dict(
type=InfiniteBenchretrievekvDataset,
abbr='InfiniteBench_retrievekv',
path='./data/InfiniteBench/kv_retrieval.jsonl',
reader_cfg=InfiniteBench_retrievekv_reader_cfg,
infer_cfg=InfiniteBench_retrievekv_infer_cfg,
eval_cfg=InfiniteBench_retrievekv_eval_cfg),
dict(
type=InfiniteBenchretrievenumberDataset,
abbr='InfiniteBench_retrievenumber',
path='./data/InfiniteBench/number_string.jsonl',
reader_cfg=InfiniteBench_retrievenumber_reader_cfg,
infer_cfg=InfiniteBench_retrievenumber_infer_cfg,
eval_cfg=InfiniteBench_retrievenumber_eval_cfg),
dict(
type=InfiniteBenchretrievepasskeyDataset,
abbr='InfiniteBench_retrievepasskey',
path='./data/InfiniteBench/passkey.jsonl',
reader_cfg=InfiniteBench_retrievepasskey_reader_cfg,
infer_cfg=InfiniteBench_retrievepasskey_infer_cfg,
eval_cfg=InfiniteBench_retrievepasskey_eval_cfg),
dict(
type=InfiniteBenchzhqaDataset,
abbr='InfiniteBench_zhqa',
path='./data/InfiniteBench/longbook_qa_chn.jsonl',
reader_cfg=InfiniteBench_zhqa_reader_cfg,
infer_cfg=InfiniteBench_zhqa_infer_cfg,
eval_cfg=InfiniteBench_zhqa_eval_cfg)
]

View File

@@ -9,8 +9,10 @@ from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
+from opencompass.models import HuggingFaceCausalLM
 from opencompass.utils.prompt import PromptList

 PromptType = Union[PromptList, str]

@@ -506,7 +508,7 @@ def _convert_base_messages(inputs):
     return outputs


-class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
+class HuggingFaceBaseModel(HuggingFaceCausalLM):

     def __init__(self,
                  path: str,

@@ -529,7 +531,8 @@ class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
         self.template_parser = LMTemplateParser()
         self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
         self.drop_middle = drop_middle
-        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
+        self.pad_token_id = None
+        self._load_tokenizer(tokenizer_path or path, None, tokenizer_kwargs)
         if not tokenizer_only:
             self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
         self.generation_kwargs = generation_kwargs

View File

@@ -92,13 +92,14 @@ class OpenICLInferTask(BaseTask):
         self.logger.info(
             f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')

-        assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
+        assert 'ice_template' in self.infer_cfg or 'prompt_template' in self.infer_cfg, \
             'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501

-        if hasattr(self.infer_cfg, 'ice_template'):
+        if 'ice_template' in self.infer_cfg:
             ice_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['ice_template'])

-        if hasattr(self.infer_cfg, 'prompt_template'):
+        if 'prompt_template' in self.infer_cfg:
             prompt_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['prompt_template'])

@@ -123,15 +124,15 @@ class OpenICLInferTask(BaseTask):
         out_dir, out_file = osp.split(out_path)
         mkdir_or_exist(out_dir)

-        if hasattr(self.infer_cfg, 'prompt_template') and \
-                hasattr(self.infer_cfg, 'ice_template'):
+        if 'prompt_template' in self.infer_cfg and \
+                'ice_template' in self.infer_cfg:
             inferencer.inference(retriever,
                                  ice_template=ice_template,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
-        elif hasattr(self.infer_cfg, 'prompt_template'):
+        elif "prompt_template" in self.infer_cfg:
             inferencer.inference(retriever,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)