Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00) — 127 lines, 4.6 KiB, Python.
"""DPP Retriever."""
|
|
|
|
import math
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import tqdm
|
|
|
|
from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever
|
|
from opencompass.openicl.utils.logging import get_logger
|
|
|
|
# Module-level logger shared by the retriever class below.
logger = get_logger(__name__)
|
|
|
|
|
|
class DPPRetriever(TopkRetriever):
    """DPP In-context Learning Retriever, subclass of `TopkRetriever`. Two-
    stage DPP is used, where first stage is to get results of TopK to reduce
    candidate sets. Checkout https://arxiv.org/abs/2302.05698 for details.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Dataset to retrieve from (forwarded to ``TopkRetriever``).
        ice_separator (str, optional): Separator between in-context examples.
        ice_eos_token (str, optional): Token appended after the examples.
        ice_num (int, optional): Number of in-context examples to select.
        sentence_transformers_model_name (str, optional): Embedding model
            used to build the index.
        tokenizer_name (str, optional): Tokenizer used for length handling.
        batch_size (int, optional): Batch size when embedding the test set.
        candidate_num (int, optional): Size of the first-stage TopK candidate
            pool fed to the DPP stage; should be >= ``ice_num``.
        seed (int, optional): Random seed.
        scale_factor (float, optional): Temperature balancing relevance vs.
            diversity in the DPP kernel; smaller weights relevance more.
    """
    model = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 sentence_transformers_model_name: Optional[
                     str] = 'all-mpnet-base-v2',
                 tokenizer_name: Optional[str] = 'gpt2-xl',
                 batch_size: Optional[int] = 1,
                 candidate_num: Optional[int] = 1,
                 seed: Optional[int] = 1,
                 scale_factor: Optional[float] = 0.1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
                         sentence_transformers_model_name, tokenizer_name,
                         batch_size)
        self.candidate_num = candidate_num
        self.seed = seed
        self.scale_factor = scale_factor

    def dpp_search(self):
        """Select in-context examples for every test instance via 2-stage DPP.

        Returns:
            list[list[int]]: For each test example (positioned by its
            metadata ``id``), the dataset ids of the selected in-context
            examples, ordered by descending relevance score.
        """
        res_list = self.forward(self.dataloader,
                                process_bar=True,
                                information='Embedding test set...')
        rtr_idx_list = [[] for _ in range(len(res_list))]
        logger.info('Retrieving data for test set...')
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry['metadata']['id']

            # get TopK results
            embed = np.expand_dims(entry['embed'], axis=0)
            near_ids = np.array(
                self.index.search(embed, self.candidate_num)[1][0].tolist())

            # DPP stage
            near_reps, rel_scores, kernel_matrix = self.get_kernel(
                embed, near_ids.tolist())

            # MAP inference. fast_map_dpp returns a plain Python list; wrap
            # it in an ndarray so the argsort-based fancy indexing below
            # works (a list cannot be indexed with an integer array).
            samples_ids = np.array(fast_map_dpp(kernel_matrix, self.ice_num))

            # ordered by relevance score
            samples_scores = np.array([rel_scores[i] for i in samples_ids])
            samples_ids = samples_ids[(-samples_scores).argsort()].tolist()
            rtr_sub_list = [int(near_ids[i]) for i in samples_ids]

            rtr_idx_list[idx] = rtr_sub_list

        return rtr_idx_list

    def retrieve(self):
        """Entry point required by the retriever API; runs the DPP search."""
        return self.dpp_search()

    def get_kernel(self, embed, candidates):
        """Build the DPP kernel over the TopK candidate set.

        Args:
            embed (np.ndarray): Query embedding of shape (1, dim).
            candidates (list[int]): Candidate ids into the faiss index.

        Returns:
            tuple: ``(near_reps, rel_scores, kernel_matrix)`` — the
            L2-normalized candidate embeddings, the exponentiated
            query-candidate relevance scores, and the DPP kernel
            ``diag(rel) @ sim @ diag(rel)``.
        """
        near_reps = np.stack(
            [self.index.index.reconstruct(i) for i in candidates], axis=0)
        # normalize first
        embed = embed / np.linalg.norm(embed)
        near_reps = near_reps / np.linalg.norm(
            near_reps, keepdims=True, axis=1)

        # shift cosine similarity from [-1, 1] to [0, 1] so scores are
        # non-negative
        rel_scores = np.matmul(embed, near_reps.T)[0]
        rel_scores = (rel_scores + 1) / 2

        # to prevent overflow error in np.exp below
        rel_scores -= rel_scores.max()

        # to balance relevance and diversity
        rel_scores = np.exp(rel_scores / (2 * self.scale_factor))

        # to make kernel-matrix non-negative
        sim_matrix = np.matmul(near_reps, near_reps.T)
        sim_matrix = (sim_matrix + 1) / 2

        kernel_matrix = rel_scores[None] * sim_matrix * rel_scores[:, None]
        return near_reps, rel_scores, kernel_matrix
|
|
|
|
|
|
def fast_map_dpp(kernel_matrix, max_length):
    """fast implementation of the greedy algorithm reference:

    https://github.com/laming-chen/fast-map-dpp/blob/master/dpp_test.py
    paper: Fast Greedy MAP Inference for Determinantal Point Process to Improve
    Recommendation Diversity

    Args:
        kernel_matrix (np.ndarray): Symmetric PSD kernel of shape (n, n),
            with float dtype.
        max_length (int): Number of items to select; values above ``n`` are
            clamped to ``n``.

    Returns:
        list[int]: Indices of the selected items, in selection order.
    """
    item_size = kernel_matrix.shape[0]
    # Cannot select more items than exist; without this clamp the loop
    # would run past the pool and take sqrt of meaningless residuals.
    max_length = min(max_length, item_size)
    cis = np.zeros((max_length, item_size))
    di2s = np.copy(np.diag(kernel_matrix))
    selected_items = list()
    selected_item = np.argmax(di2s)
    selected_items.append(int(selected_item))
    while len(selected_items) < max_length:
        k = len(selected_items) - 1
        ci_optimal = cis[:k, selected_item]
        di_optimal = math.sqrt(di2s[selected_item])
        elements = kernel_matrix[selected_item, :]
        eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
        cis[k, :] = eis
        di2s -= np.square(eis)
        # Mask the chosen item so it can never be re-selected. This matches
        # the reference implementation; without it, numerically-zero residual
        # variances allow duplicate picks and a zero di_optimal (division by
        # zero) on the next iteration.
        di2s[selected_item] = -np.inf
        selected_item = np.argmax(di2s)
        selected_items.append(int(selected_item))
    return selected_items
|