"""Topk Retriever.""" import copy from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import numpy as np import torch import tqdm from torch.utils.data import DataLoader from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase from transformers.file_utils import PaddingStrategy from opencompass.openicl.icl_dataset_reader import DatasetEncoder from opencompass.openicl.icl_retriever import BaseRetriever from opencompass.openicl.utils.logging import get_logger from opencompass.registry import ICL_RETRIEVERS logger = get_logger(__name__) @ICL_RETRIEVERS.register_module() class TopkRetriever(BaseRetriever): """Base class for Topk In-context Learning Retriever, implemented with basic knn. SentenceTransformer is used to calculate embeddings. Faiss is used to do the nearest neighbor search. Args: dataset (`BaseDataset`): Any BaseDataset instances. Attributes of ``reader``, ``train`` and ``test`` will be used. ice_separator (`Optional[str]`): The separator between each in-context example template when origin `PromptTemplate` is provided. Defaults to '\n'. ice_eos_token (`Optional[str]`): The end of sentence token for in-context example template when origin `PromptTemplate` is provided. Defaults to '\n'. ice_num (`Optional[int]`): The number of in-context example template when origin `PromptTemplate` is provided. Defaults to 1. sentence_transformers_model_name (`Optional[str]`): The name of the sentence transformers model. Defaults to 'all-mpnet-base-v2'. tokenizer_name (`Optional[str]`): The name of the tokenizer. Defaults to 'gpt2-xl'. batch_size (`Optional[int]`): The batch size for the dataloader. Defaults to 1. """ model = None def __init__(self, dataset, ice_separator: Optional[str] = '\n', ice_eos_token: Optional[str] = '\n', ice_num: Optional[int] = 1, sentence_transformers_model_name: Optional[ str] = 'all-mpnet-base-v2', tokenizer_name: Optional[str] = 'gpt2-xl', batch_size: Optional[int] = 1) -> None: super().__init__(dataset, ice_separator, ice_eos_token, ice_num) from sentence_transformers import SentenceTransformer self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.batch_size = batch_size self.tokenizer_name = tokenizer_name gen_datalist = self.dataset_reader.generate_input_field_corpus( self.test_ds) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.padding_side = 'right' self.encode_dataset = DatasetEncoder(gen_datalist, tokenizer=self.tokenizer) co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device) self.dataloader = DataLoader(self.encode_dataset, batch_size=self.batch_size, collate_fn=co) self.model = SentenceTransformer(sentence_transformers_model_name) self.model = self.model.to(self.device) self.model.eval() self.index = self.create_index() def create_index(self): import faiss self.select_datalist = self.dataset_reader.generate_input_field_corpus( self.index_ds) encode_datalist = DatasetEncoder(self.select_datalist, tokenizer=self.tokenizer) co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device) dataloader = DataLoader(encode_datalist, batch_size=self.batch_size, collate_fn=co) index = faiss.IndexIDMap( faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension())) res_list = self.forward(dataloader, process_bar=True, information='Creating index for index set...') id_list = np.array([res['metadata']['id'] for res in res_list]) self.embed_list = np.stack([res['embed'] for res in res_list]) index.add_with_ids(self.embed_list, id_list) return index def knn_search(self, ice_num): res_list = self.forward(self.dataloader, process_bar=True, information='Embedding test set...') rtr_idx_list = [[] for _ in range(len(res_list))] logger.info('Retrieving data for test set...') for entry in tqdm.tqdm(res_list, disable=not self.is_main_process): idx = entry['metadata']['id'] embed = np.expand_dims(entry['embed'], axis=0) near_ids = self.index.search(embed, ice_num)[1][0].tolist() rtr_idx_list[idx] = near_ids return rtr_idx_list def forward(self, dataloader, process_bar=False, information=''): res_list = [] _dataloader = copy.deepcopy(dataloader) if process_bar: logger.info(information) _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process) for _, entry in enumerate(_dataloader): with torch.no_grad(): metadata = entry.pop('metadata') raw_text = self.tokenizer.batch_decode( entry['input_ids'], skip_special_tokens=True, verbose=False) res = self.model.encode(raw_text, show_progress_bar=False) res_list.extend([{ 'embed': r, 'metadata': m } for r, m in zip(res, metadata)]) return res_list def retrieve(self): """Retrieve the in-context example index for each test example.""" return self.knn_search(self.ice_num) class ListWrapper: def __init__(self, data: List[Any]): self.data = data def to(self, device): return self.data def ignore_pad_dict(features): res_dict = {} if 'metadata' in features[0]: res_dict['metadata'] = ListWrapper( [x.pop('metadata') for x in features]) return res_dict @dataclass class DataCollatorWithPaddingAndCuda: tokenizer: PreTrainedTokenizerBase device: object = None padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = 3000 pad_to_multiple_of: Optional[int] = None def __call__( self, features: List[Dict[str, Union[List[int], torch.Tensor]]] ) -> BatchEncoding: res_dict = ignore_pad_dict(features) has_labels = 'labels' in features[0] if has_labels: labels = [{'input_ids': x.pop('labels')} for x in features] labels = self.tokenizer.pad( labels, padding=True, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_attention_mask=True, return_tensors='pt', verbose=False) # print(features) batch = self.tokenizer.pad(features, padding=True, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_attention_mask=True, return_tensors='pt', verbose=False) if has_labels: batch['labels'] = labels.input_ids batch.update(res_dict) if self.device: batch = batch.to(self.device) return batch