diff --git a/opencompass/cli/main.py b/opencompass/cli/main.py
index 21308e10..647fff73 100644
--- a/opencompass/cli/main.py
+++ b/opencompass/cli/main.py
@@ -258,7 +258,7 @@ def main():
     cfg.dump(output_config_path)
     # Config is intentally reloaded here to avoid initialized
     # types cannot be serialized
-    cfg = Config.fromfile(output_config_path, format_python_code=False)
+    # cfg = Config.fromfile(output_config_path, format_python_code=False)
 
     # report to lark bot if specify --lark
     if not args.lark:
diff --git a/opencompass/configs/models/hf_llama/modify_llama.py b/opencompass/configs/models/hf_llama/modify_llama.py
new file mode 100644
index 00000000..1c8ff4c5
--- /dev/null
+++ b/opencompass/configs/models/hf_llama/modify_llama.py
@@ -0,0 +1,352 @@
+import os
+import pdb
+import copy
+import math
+import numpy as np
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from typing import Optional
+
+from opencompass.models import HuggingFaceBaseModel
+from opencompass.configs.datasets.infinitebench.infinitebench import (
+    infinitebench_datasets,
+)
+from transformers import AutoConfig
+
+from mmengine.device import is_npu_available
+
+
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaRotaryEmbedding,
+    LlamaAttention,
+    apply_rotary_pos_emb,
+)
+
+
+__all__ = ["convert_kvcache_llama_heavy_recent", "LlamaAttention_heavy_hitter"]
+
+
+def local_heavy_hitter_mask(attn_weights, heavy_budget):
+    # attn_weights (BS, head, query, keys)
+    dtype_attn_weights = attn_weights.dtype
+    seq_length = attn_weights.shape[-1]
+    padding_length = 0
+
+    offset = torch.finfo(attn_weights.dtype).min
+    tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        dtype_attn_weights
+    )
+
+    accumulated_attention_score = torch.sum(
+        tmp_attn[:, :, padding_length : heavy_budget + padding_length, :], dim=-2
+    )  # (head, keys)
+    accumulated_attention_score[:, :, heavy_budget + padding_length :] = 0
+    accumulated_attention_score[:, :, :padding_length] = 0
+
+    mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
+    mask_bottom[
+        :,
+        :,
+        padding_length : heavy_budget + padding_length,
+        padding_length : heavy_budget + padding_length,
+    ] = True
+
+    for token_index in range(heavy_budget + padding_length, seq_length):
+        tmp_attn_index = nn.functional.softmax(
+            attn_weights[:, :, token_index, :], dim=-1, dtype=torch.float32
+        ).to(dtype_attn_weights)
+        _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget - 1, dim=-1)
+        zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
+        mask_bottom_index = zeros_index.scatter(
+            -1, tmp_topk_index, True
+        )  # (head, keys)
+        mask_bottom_index[:, :, token_index] = True
+
+        mask_bottom[:, :, token_index, :] = mask_bottom_index
+        accumulated_attention_score += tmp_attn_index
+        accumulated_attention_score = accumulated_attention_score * mask_bottom_index
+
+    return mask_bottom
+
+
+class LlamaAttention_heavy_hitter(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        )
+        self.rotary_emb = LlamaRotaryEmbedding(
+            self.config
+        )
+
+        self.heavy_budget_ratio = config.heavy_ratio
+        self.recent_budget_ratio = config.recent_ratio
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
+        )
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
+
+        ### Heavy + Recent
+        heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
+        recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])
+
+        # # Heavy Hitter Mask (Based on local statistics)
+        # if heavy_budget > 0:
+        #     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget)  # Default: No padding applied to input
+        # else:
+        #     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
+
+        # ones = torch.ones_like(attn_weights, dtype=torch.bool)
+        # ones = torch.triu(ones, diagonal=-recent_budget)
+        # mask_bottom = torch.logical_or(mask_bottom, ones)
+
+        # mask_bottom = torch.tril(mask_bottom, diagonal=0)
+
+        # # mask_bottom = ones
+        # attn_weights[~mask_bottom] = torch.min(attention_mask)
+
+        # Heavy Hitter Mask (Based on global statistics)
+        tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+            attn_weights.dtype
+        )
+        tmp_sum = torch.sum(tmp_attn, dim=-2)
+        _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)
+
+        zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
+        mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
+        mask_bottom = mask_bottom.expand(
+            mask_bottom.shape[0],
+            mask_bottom.shape[1],
+            attn_weights.shape[-2],
+            mask_bottom.shape[-1],
+        )
+
+        ones = torch.ones_like(attn_weights, dtype=torch.bool)
+        ones = torch.tril(ones, diagonal=recent_budget)
+        ones = torch.triu(ones, diagonal=-recent_budget)
+        mask_bottom = torch.logical_or(mask_bottom, ones)
+        # mask_bottom = ones
+        attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+def convert_kvcache_llama_heavy_recent(model, config):
+    for name, module in reversed(model._modules.items()):
+        if len(list(module.children())) > 0:
+            model._modules[name] = convert_kvcache_llama_heavy_recent(module, config)
+
+        if isinstance(module, LlamaAttention):
+            model._modules[name] = LlamaAttention_heavy_hitter(config)
+
+    return model
+
+
+ENABLE_Heavy_Hitter_FUNCTIONS = {
+    "llama": convert_kvcache_llama_heavy_recent,
+}
+
+
+def _set_model_kwargs_torch_dtype(model_kwargs):
+    import torch
+
+    if "torch_dtype" not in model_kwargs:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = {
+            "torch.float16": torch.float16,
+            "torch.bfloat16": torch.bfloat16,
+            "torch.float": torch.float,
+            "auto": "auto",
+            "None": None,
+        }.get(model_kwargs["torch_dtype"])
+    if torch_dtype is not None:
+        model_kwargs["torch_dtype"] = torch_dtype
+    return model_kwargs
+
+
+class H2OLLAMABenchmarkRunner(HuggingFaceBaseModel):
+    def __init__(
+        self,
+        path: str,
+        model_kwargs: dict = dict(),
+        tokenizer_path: Optional[str] = None,
+        tokenizer_kwargs: dict = dict(),
+        peft_path: Optional[str] = None,
+        peft_kwargs: dict = dict(),
+        tokenizer_only: bool = False,
+        generation_kwargs: dict = dict(),
+        max_seq_len: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        stop_words: Optional[str] = [],
+        drop_middle: bool = False,
+        **other_kwargs,
+    ):
+        self.heavy_ratio = other_kwargs["heavy_ratio"]
+        self.recent_ratio = other_kwargs["recent_ratio"]
+        super().__init__(
+            path=path,
+            model_kwargs=model_kwargs,
+            tokenizer_path=tokenizer_path,
+            tokenizer_kwargs=tokenizer_kwargs,
+            peft_path=peft_path,
+            peft_kwargs=peft_kwargs,
+            tokenizer_only=tokenizer_only,
+            generation_kwargs=generation_kwargs,
+            max_seq_len=max_seq_len,
+            pad_token_id=pad_token_id,
+            stop_words=stop_words,
+            drop_middle=drop_middle
+        )
+
+    def _load_model(
+        self,
+        path: str,
+        kwargs: dict,
+        peft_path: Optional[str] = None,
+        peft_kwargs: dict = dict(),
+    ):
+        # self.logger("load modified model")
+        # exit(0)
+        from transformers import AutoModel, AutoModelForCausalLM
+
+        DEFAULT_MODEL_KWARGS = dict(device_map="auto", trust_remote_code=True)
+        model_kwargs = DEFAULT_MODEL_KWARGS
+        model_kwargs.update(kwargs)
+        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
+        self.logger.debug(f"using model_kwargs: {model_kwargs}")
+        if is_npu_available():
+            model_kwargs["device_map"] = "npu"
+
+        try:
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
+        except ValueError:
+            self.model = AutoModel.from_pretrained(path, **model_kwargs)
+
+        if peft_path is not None:
+            from peft import PeftModel
+
+            peft_kwargs["is_trainable"] = False
+            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)
+
+        config = AutoConfig.from_pretrained(path)
+        config.heavy_ratio = self.heavy_ratio
+        self.model.eval()
+        self.model.generation_config.do_sample = False
diff --git a/opencompass/configs/models/hf_llama/sparse_ben.py b/opencompass/configs/models/hf_llama/sparse_ben.py
new file mode 100644
index 00000000..10fd72f5
--- /dev/null
+++ b/opencompass/configs/models/hf_llama/sparse_ben.py
@@ -0,0 +1,121 @@
+from opencompass.configs.models.hf_llama.modify_llama import H2OLLAMABenchmarkRunner
+from opencompass.datasets import InfiniteBenchcodedebugDataset, InfiniteBenchcoderunDataset, InfiniteBenchendiaDataset, InfiniteBenchenmcDataset, InfiniteBenchenqaDataset, InfiniteBenchensumDataset, InfiniteBenchmathcalcDataset, InfiniteBenchmathfindDataset, InfiniteBenchretrievekvDataset, InfiniteBenchretrievenumberDataset, InfiniteBenchretrievepasskeyDataset, InfiniteBenchzhqaDataset
+from opencompass.configs.datasets.infinitebench.infinitebenchcodedebug.infinitebench_codedebug_gen_276a42 import InfiniteBench_codedebug_reader_cfg, InfiniteBench_codedebug_infer_cfg, InfiniteBench_codedebug_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchcoderun.infinitebench_coderun_gen_1a76bd import InfiniteBench_coderun_reader_cfg, InfiniteBench_coderun_infer_cfg, InfiniteBench_coderun_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchendia.infinitebench_endia_gen_c96eb5 import InfiniteBench_endia_reader_cfg, InfiniteBench_endia_infer_cfg, InfiniteBench_endia_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchenmc.infinitebench_enmc_gen_3a4102 import InfiniteBench_enmc_reader_cfg, InfiniteBench_enmc_infer_cfg, InfiniteBench_enmc_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchenqa.infinitebench_enqa_gen_a1640c import InfiniteBench_enqa_reader_cfg, InfiniteBench_enqa_infer_cfg, InfiniteBench_enqa_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchensum.infinitebench_ensum_gen_cfbc08 import InfiniteBench_ensum_reader_cfg, InfiniteBench_ensum_infer_cfg, InfiniteBench_ensum_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchmathcalc.infinitebench_mathcalc_gen_78d17e import InfiniteBench_mathcalc_reader_cfg, InfiniteBench_mathcalc_infer_cfg, InfiniteBench_mathcalc_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchmathfind.infinitebench_mathfind_gen_6d799e import InfiniteBench_mathfind_reader_cfg, InfiniteBench_mathfind_infer_cfg, InfiniteBench_mathfind_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchretrievekv.infinitebench_retrievekv_gen_06b3ac import InfiniteBench_retrievekv_reader_cfg, InfiniteBench_retrievekv_infer_cfg, InfiniteBench_retrievekv_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchretrievepasskey.infinitebench_retrievepasskey_gen_62ff68 import InfiniteBench_retrievepasskey_reader_cfg, InfiniteBench_retrievepasskey_infer_cfg, InfiniteBench_retrievepasskey_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchretrievenumber.infinitebench_retrievenumber_gen_047436 import InfiniteBench_retrievenumber_reader_cfg, InfiniteBench_retrievenumber_infer_cfg, InfiniteBench_retrievenumber_eval_cfg
+from opencompass.configs.datasets.infinitebench.infinitebenchzhqa.infinitebench_zhqa_gen_1e5293 import InfiniteBench_zhqa_reader_cfg, InfiniteBench_zhqa_infer_cfg, InfiniteBench_zhqa_eval_cfg
+
+
+models = [
+    dict(
+        type=H2OLLAMABenchmarkRunner,
+        # abbreviation shown in the results table
+        abbr="sparse-llama-7b",
+        # HuggingFace model path
+        path="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        max_out_len=1024,
+        batch_size=1,
+        run_cfg=dict(
+            num_gpus=1,
+            heavy_ratio=0.5,
+        ),
+        # H2O sparsity ratios; modify here
+        heavy_ratio=0.5,
+        recent_ratio=0.1,
+    )
+]
+datasets = [
+
+    dict(
+        type=InfiniteBenchcodedebugDataset,
+        abbr='InfiniteBench_codedebug',
+        path='./data/InfiniteBench/code_debug.jsonl',
+        reader_cfg=InfiniteBench_codedebug_reader_cfg,
+        infer_cfg=InfiniteBench_codedebug_infer_cfg,
+        eval_cfg=InfiniteBench_codedebug_eval_cfg),
+    dict(
+        type=InfiniteBenchcoderunDataset,
+        abbr='InfiniteBench_coderun',
+        path='./data/InfiniteBench/code_run.jsonl',
+        reader_cfg=InfiniteBench_coderun_reader_cfg,
+        infer_cfg=InfiniteBench_coderun_infer_cfg,
+        eval_cfg=InfiniteBench_coderun_eval_cfg),
+    dict(
+        type=InfiniteBenchendiaDataset,
+        abbr='InfiniteBench_endia',
+        path='./data/InfiniteBench/longdialogue_qa_eng.jsonl',
+        reader_cfg=InfiniteBench_endia_reader_cfg,
+        infer_cfg=InfiniteBench_endia_infer_cfg,
+        eval_cfg=InfiniteBench_endia_eval_cfg),
+    dict(
+        type=InfiniteBenchenmcDataset,
+        abbr='InfiniteBench_enmc',
+        path='./data/InfiniteBench/longbook_choice_eng.jsonl',
+        reader_cfg=InfiniteBench_enmc_reader_cfg,
+        infer_cfg=InfiniteBench_enmc_infer_cfg,
+        eval_cfg=InfiniteBench_enmc_eval_cfg),
+    dict(
+        type=InfiniteBenchenqaDataset,
+        abbr='InfiniteBench_enqa',
+        path='./data/InfiniteBench/longbook_qa_eng.jsonl',
+        reader_cfg=InfiniteBench_enqa_reader_cfg,
+        infer_cfg=InfiniteBench_enqa_infer_cfg,
+        eval_cfg=InfiniteBench_enqa_eval_cfg),
+    dict(
+        type=InfiniteBenchensumDataset,
+        abbr='InfiniteBench_ensum',
+        path='./data/InfiniteBench/longbook_sum_eng.jsonl',
+        reader_cfg=InfiniteBench_ensum_reader_cfg,
+        infer_cfg=InfiniteBench_ensum_infer_cfg,
+        eval_cfg=InfiniteBench_ensum_eval_cfg),
+    dict(
+        type=InfiniteBenchmathcalcDataset,
+        abbr='InfiniteBench_mathcalc',
+        path='./data/InfiniteBench/math_calc.jsonl',
+        reader_cfg=InfiniteBench_mathcalc_reader_cfg,
+        infer_cfg=InfiniteBench_mathcalc_infer_cfg,
+        eval_cfg=InfiniteBench_mathcalc_eval_cfg),
+    dict(
+        type=InfiniteBenchmathfindDataset,
+        abbr='InfiniteBench_mathfind',
+        path='./data/InfiniteBench/math_find.jsonl',
+        reader_cfg=InfiniteBench_mathfind_reader_cfg,
+        infer_cfg=InfiniteBench_mathfind_infer_cfg,
+        eval_cfg=InfiniteBench_mathfind_eval_cfg),
+    dict(
+        type=InfiniteBenchretrievekvDataset,
+        abbr='InfiniteBench_retrievekv',
+        path='./data/InfiniteBench/kv_retrieval.jsonl',
+        reader_cfg=InfiniteBench_retrievekv_reader_cfg,
+        infer_cfg=InfiniteBench_retrievekv_infer_cfg,
+        eval_cfg=InfiniteBench_retrievekv_eval_cfg),
+    dict(
+        type=InfiniteBenchretrievenumberDataset,
+        abbr='InfiniteBench_retrievenumber',
+        path='./data/InfiniteBench/number_string.jsonl',
+        reader_cfg=InfiniteBench_retrievenumber_reader_cfg,
+        infer_cfg=InfiniteBench_retrievenumber_infer_cfg,
+        eval_cfg=InfiniteBench_retrievenumber_eval_cfg),
+    dict(
+        type=InfiniteBenchretrievepasskeyDataset,
+        abbr='InfiniteBench_retrievepasskey',
+        path='./data/InfiniteBench/passkey.jsonl',
+        reader_cfg=InfiniteBench_retrievepasskey_reader_cfg,
+        infer_cfg=InfiniteBench_retrievepasskey_infer_cfg,
+        eval_cfg=InfiniteBench_retrievepasskey_eval_cfg),
+    dict(
+        type=InfiniteBenchzhqaDataset,
+        abbr='InfiniteBench_zhqa',
+        path='./data/InfiniteBench/longbook_qa_chn.jsonl',
+        reader_cfg=InfiniteBench_zhqa_reader_cfg,
+        infer_cfg=InfiniteBench_zhqa_infer_cfg,
+        eval_cfg=InfiniteBench_zhqa_eval_cfg)
+]
diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 5cd38b4a..1898f449 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -9,8 +9,10 @@ from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
+from opencompass.models import HuggingFaceCausalLM
 from opencompass.utils.prompt import PromptList
 
+
 PromptType = Union[PromptList, str]
 
@@ -506,7 +508,7 @@ def _convert_base_messages(inputs):
     return outputs
 
 
-class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
+class HuggingFaceBaseModel(HuggingFaceCausalLM):
 
     def __init__(self,
                  path: str,
@@ -529,7 +531,8 @@ class HuggingFaceBaseModel(HuggingFacewithChatTemplate):
         self.template_parser = LMTemplateParser()
         self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
         self.drop_middle = drop_middle
-        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)
+        self.pad_token_id = None
+        self._load_tokenizer(tokenizer_path or path, None, tokenizer_kwargs)
         if not tokenizer_only:
             self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
         self.generation_kwargs = generation_kwargs
diff --git a/opencompass/tasks/openicl_infer.py b/opencompass/tasks/openicl_infer.py
index 1c89c305..b30523b1 100644
--- a/opencompass/tasks/openicl_infer.py
+++ b/opencompass/tasks/openicl_infer.py
@@ -92,13 +92,14 @@ class OpenICLInferTask(BaseTask):
         self.logger.info(
             f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
 
-        assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
+
+        assert 'ice_template' in self.infer_cfg or 'prompt_template' in self.infer_cfg, \
             'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501
 
-        if hasattr(self.infer_cfg, 'ice_template'):
+        if 'ice_template' in self.infer_cfg:
             ice_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['ice_template'])
 
-        if hasattr(self.infer_cfg, 'prompt_template'):
+        if 'prompt_template' in self.infer_cfg:
             prompt_template = ICL_PROMPT_TEMPLATES.build(
                 self.infer_cfg['prompt_template'])
 
@@ -123,15 +124,15 @@ class OpenICLInferTask(BaseTask):
         out_dir, out_file = osp.split(out_path)
         mkdir_or_exist(out_dir)
 
-        if hasattr(self.infer_cfg, 'prompt_template') and \
-                hasattr(self.infer_cfg, 'ice_template'):
+        if 'prompt_template' in self.infer_cfg and \
+                'ice_template' in self.infer_cfg:
             inferencer.inference(retriever,
                                  ice_template=ice_template,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
-        elif hasattr(self.infer_cfg, 'prompt_template'):
-            inferencer.inference(retriever,
+        elif 'prompt_template' in self.infer_cfg:
+            inferencer.inference(retriever,
                                  prompt_template=prompt_template,
                                  output_json_filepath=out_dir,
                                  output_json_filename=out_file)
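
A short standalone sketch of the masking scheme used in `LlamaAttention_heavy_hitter.forward` may help when reviewing the patch. It applies the global-statistics variant to a toy attention matrix: key columns are ranked by their accumulated softmax mass across queries, the top `heavy_budget` columns plus a `recent_budget`-wide diagonal band are kept, and every other score is pushed to the dtype minimum before the final softmax. The shapes, seed, and ratios below are illustrative only and are not taken from the patch.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# Toy dimensions: 1 sequence, 2 heads, 8 query/key positions.
bsz, num_heads, seq_len, head_dim = 1, 2, 8, 4
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

heavy_ratio, recent_ratio = 0.5, 0.1
heavy_budget = int(heavy_ratio * attn_weights.shape[-1])
recent_budget = int(recent_ratio * attn_weights.shape[-1])

# Heavy-hitter mask from global statistics: rank key columns by their
# accumulated softmax mass over all queries and keep the top heavy_budget.
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
tmp_sum = tmp_attn.sum(dim=-2)                      # (bsz, heads, keys)
_, topk = tmp_sum.topk(k=heavy_budget, dim=-1)
heavy_mask = torch.zeros_like(tmp_sum, dtype=torch.bool).scatter(-1, topk, True)
heavy_mask = heavy_mask.unsqueeze(2).expand(bsz, num_heads, seq_len, seq_len)

# Recent mask: a band of width recent_budget around the diagonal.
band = torch.ones_like(attn_weights, dtype=torch.bool)
band = torch.tril(band, diagonal=recent_budget)
band = torch.triu(band, diagonal=-recent_budget)

# Keep heavy hitters and the recent band, mask everything else out.
keep = heavy_mask | band
attn_weights = attn_weights.masked_fill(~keep, torch.finfo(attn_weights.dtype).min)
probs = nn.functional.softmax(attn_weights, dim=-1)

print("kept entries per row:", keep[0, 0].sum(dim=-1).tolist())
print("row sums after softmax:", probs[0, 0].sum(dim=-1).tolist())
```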
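For testing the patched attention outside the OpenCompass runner, the conversion entry point can be applied directly to a loaded checkpoint. As committed, `_load_model` builds the patched config; the sketch below shows one way the conversion itself can be wired in. The checkpoint name and ratios are placeholders, the state-dict round trip is included because the replacement modules are freshly initialised, and the whole thing assumes a transformers version whose `LlamaAttention`/`LlamaRotaryEmbedding` interfaces match the calls made in `modify_llama.py` (in particular the `seq_len=` argument passed to the rotary embedding in `forward`).

```python
import copy

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from opencompass.configs.models.hf_llama.modify_llama import (
    convert_kvcache_llama_heavy_recent,
)

path = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

config = AutoConfig.from_pretrained(path)
config.heavy_ratio = 0.5   # fraction of key positions kept as heavy hitters
config.recent_ratio = 0.1  # fraction of key positions kept as the recent window

model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)

# convert_kvcache_llama_heavy_recent recursively swaps every LlamaAttention
# module for LlamaAttention_heavy_hitter, whose projection layers are freshly
# initialised, so the pretrained weights are saved and restored around the swap.
state_dict = copy.deepcopy(model.state_dict())
model = convert_kvcache_llama_heavy_recent(model, config)
model.load_state_dict(state_dict)
model.eval()
```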