Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
100 lines
3.6 KiB
Python
"""Votek Retriever."""
|
||
|
|
||
|
import json
|
||
|
import os
|
||
|
import random
|
||
|
from collections import defaultdict
|
||
|
from typing import Optional
|
||
|
|
||
|
import numpy as np
|
||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||
|
|
||
|
from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever
|
||
|
|
||
|
|
||
|
class VotekRetriever(TopkRetriever):
|
||
|
"""Vote-k In-context Learning Retriever, subclass of `TopkRetriever`.
|
||
|
|
||
|
**WARNING**: This class has not been tested thoroughly. Please use it with
|
||
|
caution.
|
||
|
"""
|
||
|
|
||
|
def __init__(self,
|
||
|
dataset,
|
||
|
ice_separator: Optional[str] = '\n',
|
||
|
ice_eos_token: Optional[str] = '\n',
|
||
|
ice_num: Optional[int] = 1,
|
||
|
sentence_transformers_model_name: Optional[
|
||
|
str] = 'all-mpnet-base-v2',
|
||
|
tokenizer_name: Optional[str] = 'gpt2-xl',
|
||
|
batch_size: Optional[int] = 1,
|
||
|
votek_k: Optional[int] = 3) -> None:
|
||
|
super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
|
||
|
sentence_transformers_model_name, tokenizer_name,
|
||
|
batch_size)
|
||
|
self.votek_k = votek_k
|
||
|
|
||
|
def votek_select(self,
|
||
|
embeddings=None,
|
||
|
select_num=None,
|
||
|
k=None,
|
||
|
overlap_threshold=None,
|
||
|
vote_file=None):
|
||
|
n = len(embeddings)
|
||
|
if vote_file is not None and os.path.isfile(vote_file):
|
||
|
with open(vote_file) as f:
|
||
|
vote_stat = json.load(f)
|
||
|
else:
|
||
|
vote_stat = defaultdict(list)
|
||
|
|
||
|
for i in range(n):
|
||
|
cur_emb = embeddings[i].reshape(1, -1)
|
||
|
cur_scores = np.sum(cosine_similarity(embeddings, cur_emb),
|
||
|
axis=1)
|
||
|
sorted_indices = np.argsort(cur_scores).tolist()[-k - 1:-1]
|
||
|
for idx in sorted_indices:
|
||
|
if idx != i:
|
||
|
vote_stat[idx].append(i)
|
||
|
|
||
|
if vote_file is not None:
|
||
|
with open(vote_file, 'w', encoding='utf-8') as f:
|
||
|
json.dump(vote_stat, f)
|
||
|
votes = sorted(vote_stat.items(),
|
||
|
key=lambda x: len(x[1]),
|
||
|
reverse=True)
|
||
|
j = 0
|
||
|
selected_indices = []
|
||
|
while len(selected_indices) < select_num and j < len(votes):
|
||
|
candidate_set = set(votes[j][1])
|
||
|
flag = True
|
||
|
for pre in range(j):
|
||
|
cur_set = set(votes[pre][1])
|
||
|
if len(candidate_set.intersection(
|
||
|
cur_set)) >= overlap_threshold * len(candidate_set):
|
||
|
flag = False
|
||
|
break
|
||
|
if not flag:
|
||
|
j += 1
|
||
|
continue
|
||
|
selected_indices.append(int(votes[j][0]))
|
||
|
j += 1
|
||
|
if len(selected_indices) < select_num:
|
||
|
unselected_indices = []
|
||
|
cur_num = len(selected_indices)
|
||
|
for i in range(n):
|
||
|
if i not in selected_indices:
|
||
|
unselected_indices.append(i)
|
||
|
selected_indices += random.sample(unselected_indices,
|
||
|
select_num - cur_num)
|
||
|
return selected_indices
|
||
|
|
||
|
def vote_k_search(self):
|
||
|
vote_k_idxs = self.votek_select(embeddings=self.embed_list,
|
||
|
select_num=self.ice_num,
|
||
|
k=self.votek_k,
|
||
|
overlap_threshold=1)
|
||
|
return [vote_k_idxs[:] for _ in range(len(self.test_ds))]
|
||
|
|
||
|
def retrieve(self):
|
||
|
return self.vote_k_search()
|
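
A retriever like this is normally selected through a dataset's inference config rather than instantiated directly. Below is a minimal, illustrative sketch assuming the usual OpenCompass ICL config layout; the template string, the `{text}` field, and the chosen `ice_num`/`votek_k` values are placeholders, not part of this module.

# Hypothetical config sketch (not part of this module): wiring VotekRetriever
# into an OpenCompass in-context-learning inference config.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import VotekRetriever

example_infer_cfg = dict(
    # '</E>' marks where retrieved in-context examples are inserted;
    # '{text}' stands for a dataset-specific input column (placeholder).
    ice_template=dict(type=PromptTemplate,
                      template='</E>{text}',
                      ice_token='</E>'),
    # Select 8 in-context examples via vote-k, with 3 nearest-neighbour
    # votes cast per example.
    retriever=dict(type=VotekRetriever, ice_num=8, votek_k=3),
    inferencer=dict(type=GenInferencer),
)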