mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
![]() |
"""Random Retriever."""
|
||
|
|
||
|
from typing import Optional
|
||
|
|
||
|
import numpy as np
|
||
|
from tqdm import trange
|
||
|
|
||
|
from opencompass.openicl.icl_retriever import BaseRetriever
|
||
|
from opencompass.openicl.utils.logging import get_logger
|
||
|
|
||
|
# Module-level logger, obtained once at import time from the project's
# logging helper and keyed by this module's dotted name.
logger = get_logger(__name__)
|
||
|
|
||
|
|
||
|
class RandomRetriever(BaseRetriever):
    """Random Retriever. Each in-context example of the test prompts is
    retrieved in a random way.

    For every item of the test split, ``ice_num`` indices are drawn without
    replacement from ``range(len(self.index_ds))``. Sampling is driven by a
    seeded generator, so two calls to :meth:`retrieve` on retrievers with the
    same ``seed`` and dataset sizes return the same indices.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Dataset object, forwarded unchanged to ``BaseRetriever``.
        ice_separator (str, optional): Forwarded unchanged to
            ``BaseRetriever``. Defaults to a newline character.
        ice_eos_token (str, optional): Forwarded unchanged to
            ``BaseRetriever``. Defaults to a newline character.
        ice_num (int, optional): Number of in-context example indices to
            sample per test item. Forwarded to ``BaseRetriever`` and read
            back as ``self.ice_num``. Defaults to 1.
        seed (int, optional): Seed for the sampling generator. ``None``
            lets NumPy pick a non-deterministic seed. Defaults to 43.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.seed = seed

    def retrieve(self):
        """Sample in-context example indices for every test item.

        Returns:
            list: One entry per item of ``self.test_ds``, in order. Each
            entry is the ``.tolist()`` form of the sampled indices, i.e. a
            list of ``self.ice_num`` distinct integers in
            ``[0, len(self.index_ds))`` when ``ice_num`` is an integer.

        Raises:
            ValueError: Raised by NumPy when ``ice_num`` is larger than
                ``len(self.index_ds)``, because sampling is done without
                replacement.
        """
        # NOTE(review): ``index_ds``, ``test_ds`` and ``is_main_process`` are
        # presumably provided by ``BaseRetriever`` -- confirm against the
        # base class.

        # Use a private generator instead of ``np.random.seed``: reseeding
        # the process-wide NumPy RNG here would silently alter the random
        # stream of every other component using ``np.random``. A
        # ``RandomState`` seeded with the same value yields the same draws
        # as the global legacy generator, so results are unchanged.
        rng = np.random.RandomState(self.seed)
        num_idx = len(self.index_ds)
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        # The progress bar is only shown on the main process.
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):
            idx_list = rng.choice(num_idx, self.ice_num,
                                  replace=False).tolist()
            rtr_idx_list.append(idx_list)
        return rtr_idx_list
|