diff --git a/configs/datasets/inference_ppl/README.md b/configs/datasets/inference_ppl/README.md
new file mode 100644
index 00000000..bec80f52
--- /dev/null
+++ b/configs/datasets/inference_ppl/README.md
@@ -0,0 +1,26 @@
+# Inference-PPL Datasets
+
+- **Description**: Compute the loss only on labeled positions; particularly useful for reasoning corpora.
+- **Datasets**: cn-reasoning-val.jsonl (an example dataset; inference-ppl can be generalized to other corpora).
+
+# PPL Computation
+
+$$ \text{ppl} = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{\text{vocab\_size}} y_{i,c} \log p_{i,c} \tag{1} $$
+
+where $n$ is the number of tokens, $y_{i,c}$ is the one-hot label and $p_{i,c}$ is the predicted probability of class $c$ at position $i$. Eq. (1) is the standard mean ppl computation; for inference-ppl, the average is taken only over the pre-labeled positions.
+
+# Quick Start
+
+```shell
+cd opencompass
+python run.py configs/eval_inference_ppl.py
+```
+
+# Results
+
+| Model | Result |
+| ----------- | ----------- |
+| Qwen1.5-7b | 0.59 |
+| Qwen1.5-14b | 0.54 |
+| Llama2-7b | 0.49 |
+| Llama2-13b | 0.43 |
diff --git a/configs/datasets/inference_ppl/inference_ppl.py b/configs/datasets/inference_ppl/inference_ppl.py
new file mode 100644
index 00000000..45b88595
--- /dev/null
+++ b/configs/datasets/inference_ppl/inference_ppl.py
@@ -0,0 +1,38 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import InferencePPLOnlyInferencer
+from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator
+
+from opencompass.datasets import InferencePPLDataset
+
+# Build InferencePPLDataset
+inference_ppl_datasets = []
+
+llm_cmp_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template='{text}',
+    ),
+    # No in-context example, use ZeroRetriever
+    retriever=dict(type=ZeroRetriever),
+    # Compute inference-ppl
+    inferencer=dict(type=InferencePPLOnlyInferencer),
+)
+
+# Average the inference-ppl scores
+llm_cmp_eval_cfg = dict(evaluator=dict(type=AverageInferencePPLEvaluator))
+
+inference_ppl_datasets.append(
+    dict(
+        abbr='inference-ppl',
+        type=InferencePPLDataset,
+        path='./data/inference_ppl',
+        name='cn-reasoning-val',
+        samples=None,  # Set to a small number for quick testing
+        reader_cfg=dict(
+            input_columns=['text'],
+            output_column=None,
+        ),
+        infer_cfg=llm_cmp_infer_cfg,
+        eval_cfg=llm_cmp_eval_cfg,
+    ))
diff --git a/configs/eval_inference_ppl.py b/configs/eval_inference_ppl.py
new file mode 100644
index 00000000..0e384cf0
--- /dev/null
+++ b/configs/eval_inference_ppl.py
@@ -0,0 +1,62 @@
+from mmengine.config import read_base
+
+with read_base():
+    # Inference PPL datasets
+    from .datasets.inference_ppl.inference_ppl import inference_ppl_datasets
+
+    # Model configs
+    from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
+    from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
+    from .models.hf_llama.hf_llama2_7b import models as llama2_7b
+    from .models.hf_llama.hf_llama2_13b import models as llama2_13b
+
+
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
+
+
+# -------------Inference Stage ----------------------------------------
+
+datasets = [*inference_ppl_datasets]
+workdir = 'outputs/inference_ppl'
+
+models = [
+    *qwen1_5_7b,
+    *qwen1_5_14b,
+    *llama2_7b,
+    *llama2_13b,
+]
+
+
+# Set custom batch_size and num_gpus for faster loss calculation
+# A smaller batch_size gives more precise results, at the cost of efficiency
+model_cfg = dict(
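+    # Note (assumption): run_cfg with num_gpus=4 and num_procs=1 typically
+    # loads each model across four GPUs in a single process; adjust
+    # batch_size and run_cfg to match the available hardware.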
+    batch_size=8,
+    run_cfg=dict(num_gpus=4, num_procs=1)
+)
+
+for mdl in models:
+    mdl.update(model_cfg)
+
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalRunner,
+        task=dict(type=OpenICLInferTask),
+        max_num_workers=256,  # Maximum number of concurrent tasks
+    ),
+)
+
+
+# -------------Evaluation Stage ----------------------------------------
+eval = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalRunner,
+        task=dict(type=OpenICLEvalTask),
+        max_num_workers=256,
+    )
+)
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index f08f6844..936ba45f 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -53,6 +53,7 @@ from .humaneval import *  # noqa: F401, F403
 from .humanevalx import *  # noqa: F401, F403
 from .hungarian_math import *  # noqa: F401, F403
 from .IFEval.ifeval import IFEvalDataset, IFEvaluator  # noqa: F401, F403
+from .inference_ppl import InferencePPLDataset  # noqa: F401, F403
 from .infinitebench import *  # noqa: F401, F403
 from .iwslt2017 import *  # noqa: F401, F403
 from .jigsawmultilingual import *  # noqa: F401, F403
diff --git a/opencompass/datasets/inference_ppl.py b/opencompass/datasets/inference_ppl.py
new file mode 100644
index 00000000..251bb682
--- /dev/null
+++ b/opencompass/datasets/inference_ppl.py
@@ -0,0 +1,37 @@
+import os.path as osp
+from typing import Optional
+
+from datasets import load_dataset
+
+from opencompass.registry import LOAD_DATASET
+
+from .base import BaseDataset
+
+
+@LOAD_DATASET.register_module()
+class InferencePPLDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str,
+             name: Optional[str] = None,
+             samples: Optional[int] = None):
+
+        # Check if the file exists in the given path
+        supported_extensions = ['jsonl']
+        for ext in supported_extensions:
+            filename = osp.join(
+                path, f'{name}.{ext}')  # name refers to the data subset name
+
+            if osp.exists(filename):
+                break
+        else:
+            raise FileNotFoundError(f'{filename} not found.')
+
+        samples = 'test' if samples is None else f'test[:{samples}]'
+
+        data_files = {'test': filename}
+
+        dataset = load_dataset('json', data_files=data_files, split=samples)
+
+        # Filter out empty samples
+        dataset = dataset.filter(lambda example: len(example['text']) > 0)
+
+        return dataset
diff --git a/opencompass/models/base.py b/opencompass/models/base.py
index 74236134..9e983f39 100644
--- a/opencompass/models/base.py
+++ b/opencompass/models/base.py
@@ -85,6 +85,28 @@ class BaseModel:
                             ' ppl-based evaluation yet, try gen-based '
                             'instead.')
 
+    @abstractmethod
+    def get_ppl_tokenwise(
+            self,
+            inputs: List[str],
+            label: List[List[int]],
+            mask_length: Optional[List[int]] = None) -> List[float]:
+        """Get token-wise perplexity scores given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            label (List[List[int]]): A list of label lists; each label is a
+                (start, end, 1) tuple marking a labeled character span in the
+                corresponding input.
+            mask_length (Optional[List[int]]): A list of mask lengths. If
+                provided, the perplexity scores will be calculated with the
+                first mask_length[i] tokens masked out. It's okay to skip
+                its implementation if advanced features in PPLInferencer are
+                not needed.
+
+        Returns:
+            List[float]: A list of perplexity scores.
+        """
+        raise NotImplementedError(f'{self.__class__.__name__} does not support'
+                                  ' ppl-based evaluation yet, try gen-based '
+                                  'instead.')
+
     @abstractmethod
     def encode(self, prompt: str) -> torch.Tensor:
         """Encode prompt to tokens. Not necessary for most cases.
@@ -151,6 +173,20 @@ class BaseModel:
         inputs = self.parse_template(templates, mode='ppl')
         return self.get_ppl(inputs, mask_length)
 
+    def get_ppl_tokenwise_from_template(self,
+                                        templates: List[PromptType],
+                                        label: List[List[int]],
+                                        mask_length=None):
+        """Get token-wise perplexity given a list of templates.
+
+        Args:
+            templates (List[PromptType]): A list of templates.
+            label (List[List[int]]): A list of label lists; each label is a
+                (start, end, 1) tuple marking a labeled character span.
+            mask_length (List[int]): A list of mask lengths. If provided, the
+                perplexity will be calculated only on the unmasked tokens.
+        """
+        inputs = self.parse_template(templates, mode='ppl')
+        return self.get_ppl_tokenwise(inputs, label, mask_length)
+
     def generate_from_template(self, templates: List[PromptType],
                                max_out_len: int, **kwargs):
         """Generate completion from a list of templates.
diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 329ea2a3..23f3c830 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -226,6 +226,165 @@ class HuggingFacewithChatTemplate(BaseModel):
         self.model.eval()
         self.model.generation_config.do_sample = False
 
+
+    def get_ppl_tokenwise(self, inputs: List[str], label: List[List[int]], mask_length: Optional[List[int]] = None) -> List[float]:
+        """Get inference-ppl per token given a list of inputs and labels.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            label (List[List[int]]): A list of label lists; each label is a
+                (start, end, 1) tuple marking a labeled character span in the
+                corresponding input.
+            mask_length (Optional[List[int]]): A list of mask lengths. If
+                provided, the perplexity scores will be calculated with the
+                first mask_length[i] tokens masked out. It's okay to skip
+                its implementation if advanced features in PPLInferencer are
+                not needed.
+
+        Returns:
+            Tuple[List[float], List[int]]: The summed loss over labeled
+                tokens and the number of labeled tokens for each input.
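+
+        Note:
+            Each label marks a character span ``[start, end)`` in the original
+            text; the loss is accumulated only over tokens whose decoded
+            characters fall inside one of these spans.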
+        """
+        assert self.tokenizer.pad_token
+        import torch
+        import torch.nn.functional as F
+        pad_token_id = self.tokenizer.pad_token_id
+        messages = _convert_base_messages(inputs)
+
+        tokenize_kwargs = dict(
+            return_tensors='pt',
+            padding=True,
+            truncation=True,
+            add_special_tokens=True,
+            max_length=self.max_seq_len,
+        )
+
+        self.tokenizer.padding_side = 'right'
+        self.tokenizer.truncation_side = 'right'
+
+        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
+
+        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
+        outputs = self.model(**tokens)[0]
+
+        batch_size, seq_len, vocab_size = outputs.shape
+        shift_logits = outputs[:, :-1, :].contiguous().float()
+        shift_labels = tokens['input_ids'][:, 1:].contiguous()
+        loss = F.cross_entropy(
+            shift_logits.view(-1, vocab_size),
+            shift_labels.view(-1),
+            ignore_index=pad_token_id,
+            reduction='none').view(batch_size, seq_len - 1)
+        lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()
+
+        if mask_length is not None:
+            import numpy as np
+            mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
+            for i in range(len(mask)):
+                for j in range(mask_length[i] - 1, len(mask[i])):
+                    mask[i][j] = 1
+            loss = loss * mask
+            lens -= np.array(mask_length)
+
+        loss = loss.cpu().numpy()
+
+        decode_messages = [[self.tokenizer.decode([input_id]) for input_id in token] for token in tokens['input_ids']]
+        char_messages = [[ch for ch in message] for message in messages]
+
+        # Drop the first decoded token so that token positions align with the
+        # shifted loss
+        for i in range(len(decode_messages)):
+            decode_messages[i] = decode_messages[i][1:]
+
+        aggregated_label_list = [[] for _ in range(len(decode_messages))]
+
+        tag_list = [[] for _ in range(len(decode_messages))]
+
+        # Expand each (start, end, 1) label into the set of labeled character
+        # indices per sample
+        for tmp_index, label_list in enumerate(label):
+            for single_label in label_list:
+                left = single_label[0]
+                right = single_label[1]
+                for i in range(left, right):
+                    aggregated_label_list[tmp_index].append(i)
+
+        def align_sequences(seq1, seq2, sep_len):
+            """Align decoded tokens with the original characters.
+
+            seq1: decoded token sequence; one token may cover several
+                characters.
+            seq2: the original text as a list of single characters.
+            sep_len: number of consecutive byte-level (replacement-character)
+                tokens paired one-to-one with characters when a multi-byte
+                character is split across tokens.
+            """
+            i, j = 0, 0
+            matched_pairs = []
+            while i < len(seq1) and j < len(seq2):
+                word = seq1[i]
+                if len(word) == 0:
+                    matched_pairs.append((word, []))
+                    i += 1
+                    continue
+
+                if '\ufffd' in word:
+                    for _ in range(sep_len):
+                        matched_pairs.append((word, [j]))
+                        i += 1
+                        j += 1
+                    continue
+
+                char_sequence = ''
+                while j < len(seq2) and (char_sequence != word):
+                    char_sequence += seq2[j]
+                    if char_sequence == word:
+                        matched_pairs.append((word, [k for k in range(j - len(word) + 1, j + 1)]))
+                        j += 1
+                        break
+                    elif len(char_sequence) > len(word):
+                        if word == char_sequence[-len(word):]:
+                            matched_pairs.append((word, [k for k in range(j - len(word) + 1, j + 1)]))
+                            j += 1
+                            break
+                        else:
+                            j += 1
+                    else:
+                        j += 1
+                i += 1
+
+            return matched_pairs
+
+        # Tokenizer-dependent heuristic: how many consecutive byte-level
+        # tokens make up one multi-byte character (see align_sequences)
+        if 'qwen' in self.path or 'Qwen' in self.path:
+            sep_len = 2
+        elif 'Llama-3' in self.path:
+            sep_len = 2
+        elif 'Yi' in self.path:
+            sep_len = 3
+        elif 'Llama-2' in self.path:
+            sep_len = 3
+        elif 'deepseek' in self.path:
+            sep_len = 2
+        else:
+            sep_len = 3
+
+        matched_pairs_list = [align_sequences(decode_messages[i], char_messages[i], sep_len) for i in range(len(decode_messages))]
+
+        # Tag the token positions whose characters fall inside a labeled span
+        for match_index, matched_pairs in enumerate(matched_pairs_list):
+            for i, (word, indices) in enumerate(matched_pairs):
+                for j in indices:
+                    if j in aggregated_label_list[match_index]:
+                        tag_list[match_index].append(i)
+                        break
+
+        # Sum the loss over the tagged (labeled) token positions
+        inference_loss_list = []
+        token_len_list = []
+        for i in range(len(loss)):
+            inference_loss = 0
+            token_len = 0
+            for j in range(len(loss[i])):
+                if j in tag_list[i]:
+                    inference_loss += loss[i][j]
+                    token_len += 1
+            inference_loss_list.append(inference_loss)
+            token_len_list.append(token_len)
+
+        return inference_loss_list, token_len_list
+
     def _get_potential_stop_words(self, path: Optional[str]):
         from transformers import GenerationConfig
         potential_stop_words = []
diff --git a/opencompass/openicl/icl_evaluator/__init__.py b/opencompass/openicl/icl_evaluator/__init__.py
index 0c68ce0c..1fd1683b 100644
--- a/opencompass/openicl/icl_evaluator/__init__.py
+++ b/opencompass/openicl/icl_evaluator/__init__.py
@@ -6,6 +6,7 @@ from .icl_circular_evaluator import CircularEvaluator  # noqa
 from .icl_em_evaluator import EMEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa
 from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator  # noqa
+from .icl_misc_evaluator import AverageInferencePPLEvaluator  # noqa
 from .icl_misc_evaluator import AverageMinKEvaluator  # noqa
 from .icl_misc_evaluator import AveragePPLEvaluator  # noqa
 from .icl_plugin_evaluator import TEvalEvaluator  # noqa
diff --git a/opencompass/openicl/icl_evaluator/icl_misc_evaluator.py b/opencompass/openicl/icl_evaluator/icl_misc_evaluator.py
index ddeb377a..fbb12209 100644
--- a/opencompass/openicl/icl_evaluator/icl_misc_evaluator.py
+++ b/opencompass/openicl/icl_evaluator/icl_misc_evaluator.py
@@ -17,3 +17,11 @@ class AverageMinKEvaluator(BaseEvaluator):
     def score(self, mink):
         average_mink = sum(mink) / len(mink)
         return {'average_mink': average_mink}
+
+
+@ICL_EVALUATORS.register_module()
+class AverageInferencePPLEvaluator(BaseEvaluator):
+
+    def score(self, ppl, token_len):
+        # Token-weighted average: total labeled-token loss / labeled tokens
+        average_ppl = sum(ppl) / sum(token_len)
+        return {'average_ppl': average_ppl}
diff --git a/opencompass/openicl/icl_inferencer/__init__.py b/opencompass/openicl/icl_inferencer/__init__.py
index 677af4e4..0f034276 100644
--- a/opencompass/openicl/icl_inferencer/__init__.py
+++ b/opencompass/openicl/icl_inferencer/__init__.py
@@ -4,6 +4,8 @@ from .icl_base_inferencer import BaseInferencer  # noqa
 from .icl_chat_inferencer import ChatInferencer  # noqa
 from .icl_clp_inferencer import CLPInferencer  # noqa
 from .icl_gen_inferencer import GenInferencer  # noqa
+from .icl_inference_ppl_only_inferencer import \
+    InferencePPLOnlyInferencer  # noqa
 from .icl_ll_inferencer import LLInferencer  # noqa
 from .icl_mink_percent_inferencer import MinKPercentInferencer  # noqa
 from .icl_ppl_inferencer import PPLInferencer  # noqa
diff --git a/opencompass/openicl/icl_inferencer/icl_inference_ppl_only_inferencer.py b/opencompass/openicl/icl_inferencer/icl_inference_ppl_only_inferencer.py
new file mode 100644
index 00000000..3f6e0def
--- /dev/null
+++ b/opencompass/openicl/icl_inferencer/icl_inference_ppl_only_inferencer.py
@@ -0,0 +1,239 @@
+"""Inference-PPL Only Inferencer."""
+
+import os
+from typing import List, Optional
+
+import mmengine
+import torch
+from tqdm import tqdm
+
+from opencompass.models.base import BaseModel
+from opencompass.registry import ICL_INFERENCERS
+
+from ..icl_prompt_template import PromptTemplate
+from ..icl_retriever import BaseRetriever
+from ..utils import get_logger
+from .icl_base_inferencer import BaseInferencer, dump_results_dict
+
+logger = get_logger(__name__)
+
+
+@ICL_INFERENCERS.register_module()
+class InferencePPLOnlyInferencer(BaseInferencer):
+    """InferencePPLOnlyInferencer class to calculate inference-ppl only; no
+    choice is made. This inferencer is usually used along with
+    AverageInferencePPLEvaluator.
+
+    Attributes:
+        model (:obj:`BaseModel`, optional): The module to inference.
+        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
+            the LM.
+        batch_size (:obj:`int`, optional): Batch size for the
+            :obj:`DataLoader`.
+        output_json_filepath (:obj:`str`, optional): File path for output
+            `JSON` file.
+        output_json_filename (:obj:`str`, optional): File name for output
+            `JSON` file.
+        save_every (:obj:`int`, optional): Save intermediate results every
+            `save_every` iterations. Defaults to 1.
+    """
+
+    def __init__(
+            self,
+            model: BaseModel,
+            max_seq_len: Optional[int] = None,
+            batch_size: Optional[int] = 1,
+            output_json_filepath: Optional[str] = './icl_inference_output',
+            output_json_filename: Optional[str] = 'predictions',
+            save_every: Optional[int] = 1,
+            **kwargs) -> None:
+        super().__init__(
+            model=model,
+            max_seq_len=max_seq_len,
+            batch_size=batch_size,
+            output_json_filename=output_json_filename,
+            output_json_filepath=output_json_filepath,
+            **kwargs,
+        )
+
+        self.save_every = save_every
+
+    def inference(self,
+                  retriever: BaseRetriever,
+                  ice_template: Optional[PromptTemplate] = None,
+                  prompt_template: Optional[PromptTemplate] = None,
+                  output_json_filepath: Optional[str] = None,
+                  output_json_filename: Optional[str] = None) -> List:
+        # 1. Preparation for output logs
+        output_handler = InferencePPLOnlyInferencerOutputHandler()
+
+        if output_json_filepath is None:
+            output_json_filepath = self.output_json_filepath
+        if output_json_filename is None:
+            output_json_filename = self.output_json_filename
+
+        # 2. Get results of retrieval process
+        ice_idx_list = retriever.retrieve()
+
+        # 3. Generate prompts and labels for testing input
+        prompt_list, label_list = self.get_generation_prompt_list_and_label(
+            ice_idx_list,
+            retriever,
+            max_seq_len=self.max_seq_len,
+            ice_template=ice_template,
+            prompt_template=prompt_template)
+
+        prompt_list = [{
+            'prompt': prompt,
+            'label': label
+        } for prompt, label in zip(prompt_list, label_list)]
+
+        # 3.1 Ensure no gold-answer column is configured
+        ds_reader = retriever.dataset_reader
+
+        assert ds_reader.output_column is None, (
+            'InferencePPLOnlyInferencer supports `output_column=None` only.')
+
+        # Create tmp json file for saving intermediate results and future
+        # resuming
+        index = 0
+        tmp_json_filepath = os.path.join(output_json_filepath,
+                                         'tmp_' + output_json_filename)
+        if os.path.exists(tmp_json_filepath):
+            # TODO: move resume to output handler
+            try:
+                tmp_result_dict = mmengine.load(tmp_json_filepath)
+            except Exception:
+                pass
+            else:
+                output_handler.results_dict = tmp_result_dict
+                index = len(tmp_result_dict)
+
+        # 4. Wrap prompts with Dataloader
+        dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
+
+        # 5. Inference for prompts in each batch
+        logger.info('Starting inference process...')
+        for datum in tqdm(dataloader, disable=not self.is_main_process):
+            entry = [datum_single['prompt'] for datum_single in datum]
+            label = [datum_single['label'] for datum_single in datum]
+
+            # 5-1. Inference with local model
+            with torch.no_grad():
+                (inference_loss_list,
+                 token_len_list) = self.model.get_ppl_tokenwise_from_template(
+                     entry, label)
+
+            parsed_entries = self.model.parse_template(entry, mode='gen')
+            # 5-2. Save current output
+            for prompt, inference_loss, token_len in zip(
+                    parsed_entries, inference_loss_list, token_len_list):
+                output_handler.save_results(prompt, inference_loss, token_len,
+                                            index)
+                index = index + 1
+
+            # 5-3. Save intermediate results
+            if (self.save_every is not None and index % self.save_every == 0
+                    and self.is_main_process):
+                output_handler.write_to_json(output_json_filepath,
+                                             'tmp_' + output_json_filename)
+
+        # 6. Output
+        if self.is_main_process:
+            os.makedirs(output_json_filepath, exist_ok=True)
+            output_handler.write_to_json(output_json_filepath,
+                                         output_json_filename)
+            if os.path.exists(tmp_json_filepath):
+                os.remove(tmp_json_filepath)
+
+        return [
+            sample['ppl'] for sample in output_handler.results_dict.values()
+        ]
+
+    def get_generation_prompt_list_from_retriever_indices(
+            self,
+            ice_idx_list: List[List[int]],
+            retriever: BaseRetriever,
+            max_seq_len: Optional[int] = None,
+            ice_template: Optional[PromptTemplate] = None,
+            prompt_template: Optional[PromptTemplate] = None):
+        prompt_list = []
+        for idx, ice_idx in enumerate(ice_idx_list):
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+
+            prompt = retriever.generate_prompt_for_generate_task(
+                idx,
+                ice,
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+
+            if max_seq_len is not None:
+                prompt_token_num = self.model.get_token_len_from_template(
+                    prompt, mode='gen')
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt = retriever.generate_prompt_for_generate_task(
+                        idx,
+                        ice,
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = self.model.get_token_len_from_template(
+                        prompt, mode='gen')
+            prompt_list.append(prompt)
+        return prompt_list
+
+    def get_generation_prompt_list_and_label(
+            self,
+            ice_idx_list: List[List[int]],
+            retriever: BaseRetriever,
+            max_seq_len: Optional[int] = None,
+            ice_template: Optional[PromptTemplate] = None,
+            prompt_template: Optional[PromptTemplate] = None):
+        prompt_list = []
+        label_list = []
+        for idx, ice_idx in enumerate(ice_idx_list):
+            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
+
+            prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
+                idx,
+                ice,
+                ice_template=ice_template,
+                prompt_template=prompt_template)
+
+            if max_seq_len is not None:
+                prompt_token_num = self.model.get_token_len_from_template(
+                    prompt, mode='gen')
+                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
+                    ice_idx = ice_idx[:-1]
+                    ice = retriever.generate_ice(ice_idx,
+                                                 ice_template=ice_template)
+                    prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
+                        idx,
+                        ice,
+                        ice_template=ice_template,
+                        prompt_template=prompt_template)
+                    prompt_token_num = self.model.get_token_len_from_template(
+                        prompt, mode='gen')
+            prompt_list.append(prompt)
+            label_list.append(label)
+        return prompt_list, label_list
+
+
+class InferencePPLOnlyInferencerOutputHandler:
+    origin_prompt_dict = {}
+    output_dict = {}
+    results_dict = {}
+
+    def __init__(self) -> None:
+        self.results_dict = {}
+
+    def write_to_json(self, save_dir: str, filename: str):
+        """Dump the result to a json file."""
+        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))
+
+    def save_results(self, origin_prompt, ppl, token_len, idx):
+        self.results_dict[str(idx)] = {
+            'origin_prompt': origin_prompt,
+            'ppl': ppl,
+            'token_len': token_len,
+        }
diff --git a/opencompass/openicl/icl_retriever/icl_base_retriever.py b/opencompass/openicl/icl_retriever/icl_base_retriever.py
index 6445d9f9..30be06fe 100644
--- a/opencompass/openicl/icl_retriever/icl_base_retriever.py
+++ b/opencompass/openicl/icl_retriever/icl_base_retriever.py
@@ -207,6 +207,59 @@ class BaseRetriever:
         raise NotImplementedError(
             'Leaving prompt as empty is not supported')
 
+    def generate_prompt_and_label_for_generate_task(
+            self,
+            idx,
+            ice,
+            gen_field_replace_token='',
+            ice_template: Optional[PromptTemplate] = None,
+            prompt_template: Optional[PromptTemplate] = None):
+        """Generate the prompt and the label info for one test example in
+        generative evaluation with `prompt_template`. If `prompt_template` is
+        not provided, the `ice_template` will be used to generate the prompt.
+        The token represented by `gen_field_replace_token` will not be
+        replaced by the generated text, otherwise the answer would be leaked.
+        In addition to the prompt, the `label` field of the test example is
+        returned.
+
+        Args:
+            idx (`int`): The index of the test example.
+            ice (`str`): The in-context example for the test example.
+            gen_field_replace_token (`str`): The token of the answer in the
+                prompt. Defaults to ''.
+            ice_template (`Optional[PromptTemplate]`): The template for
+                in-context example. Defaults to None.
+            prompt_template (`Optional[PromptTemplate]`): The template for
+                prompt. Defaults to None.
+        """
+        if prompt_template is not None and ice_template is not None:
+            if prompt_template.ice_token is not None:
+                return prompt_template.generate_item(
+                    self.test_ds[idx],
+                    output_field=self.dataset_reader.output_column,
+                    output_field_replace_token=gen_field_replace_token,
+                    ice_field_replace_token=ice), self.test_ds[idx]['label']
+            else:
+                raise NotImplementedError(
+                    'ice_token of prompt_template is not provided')
+        elif ice_template is not None and prompt_template is None:
+            if ice_template.ice_token is not None:
+                return ice_template.generate_item(
+                    self.test_ds[idx],
+                    output_field=self.dataset_reader.output_column,
+                    output_field_replace_token=gen_field_replace_token,
+                    ice_field_replace_token=ice), self.test_ds[idx]['label']
+            else:
+                raise NotImplementedError(
+                    'ice_token of ice_template is not provided')
+        elif ice_template is None and prompt_template is not None:
+            return prompt_template.generate_item(
+                self.test_ds[idx],
+                output_field=self.dataset_reader.output_column,
+                output_field_replace_token=gen_field_replace_token,
+                ice_field_replace_token=ice), self.test_ds[idx]['label']
+        else:
+            raise NotImplementedError(
+                'Leaving prompt as empty is not supported')
+
     def generate_prompt_for_adv_generate_task(
             self,
             idx,