OpenCompass/opencompass/openicl/icl_retriever/icl_knowledge_retriever.py
Kaiaicy e754804c68 [Feat] Support Knowledge-based Retriever
* [Feature]: Add KnowledgeRetriever with config demo
2023-09-03 16:15:09 +08:00

346 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Local Knowledge Retriever."""
from typing import List, Optional, Callable, Dict, Any
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.registry import ICL_RETRIEVERS
from opencompass.utils import get_logger
logger = get_logger(__name__)
import os
import re
import numpy as np
import torch
from copy import deepcopy
from tqdm import tqdm
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader, CSVLoader, UnstructuredFileLoader
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
VECTOR_SEARCH_SCORE_THRESHOLD = 500
CHUNK_SIZE = 50
class RetrievedFAISS(FAISS, VectorStore):
def __init__(self,
embedding_function: Callable,
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
normalize_L2: bool = False,
):
super().__init__(embedding_function=embedding_function,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2)
self.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
self.chunk_size = CHUNK_SIZE
self.chunk_conent = False
def seperate_list(self, lines: List[int]) -> List[List[int]]:
results = []
cur_line = [lines[0]]
docs_source = self.index_to_docstore_source(lines[0])
for i in range(1, len(lines)):
if lines[i - 1] + 1 == lines[i] and self.index_to_docstore_source(lines[i]) == docs_source:
cur_line.append(lines[i])
else:
results.append(cur_line)
cur_line = [lines[i]]
docs_source = self.index_to_docstore_source(lines[i])
results.append(cur_line)
return results
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Document]:
faiss = dependable_faiss_import()
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
docs = []
id_set = set()
store_len = len(self.index_to_docstore_id)
rearrange_id_list = False
for j, i in enumerate(indices[0]):
if i == -1 or 0 < self.score_threshold < scores[0][j]:
continue
if i in self.index_to_docstore_id:
_id = self.index_to_docstore_id[i]
else:
continue
doc = self.docstore.search(_id)
if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j])
docs.append(doc)
continue
id_set.add(i)
docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)):
break_flag = False
if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
expand_range = [i + k]
elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
expand_range = [i - k]
else:
expand_range = [i + k, i - k]
for l in expand_range:
if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != \
doc.metadata["source"]:
break_flag = True
break
elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content)
id_set.add(l)
rearrange_id_list = True
if break_flag:
break
if (not self.chunk_conent) or (not rearrange_id_list):
return docs
if len(id_set) == 0 and self.score_threshold > 0:
return []
id_list = sorted(list(id_set))
id_lists = self.seperate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = deepcopy(self.docstore.search(_id))
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0)
doc.page_content += " " + doc0.page_content
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
doc.metadata["score"] = int(doc_score)
docs.append(doc)
return docs
def list_docs(self):
return list(v.metadata["source"] for v in self.docstore._dict.values())
def index_to_docstore_source(self,i:int):
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
return doc.metadata["source"]
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(
self,
max_length: int,
**kwargs
):
super().__init__(**kwargs)
self.max_length = max_length
def split_text(
self,
text: str,
is_pdf: bool = False,
) -> List[str]:
if is_pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub('\s', " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r'([;.!?。!?\?])([^”’])', r"\1\n\2", text)
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'(\{2})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r'\1\n\2', text)
text = text.rstrip()
lines = [i for i in text.split("\n") if i]
for cur_line in lines:
if len(cur_line) > self.max_length:
sub_lines1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r'\1\n\2', cur_line).split("\n")
for cur_s_line1 in sub_lines1:
if len(cur_s_line1) > self.max_length:
sub_lines2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', cur_s_line1).split("\n")
for cur_s_line2 in sub_lines2:
if len(cur_s_line2) > self.max_length:
cur_s_line3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', cur_s_line2)
cur_s_idx2 = sub_lines2.index(cur_s_line2)
sub_lines2 = sub_lines2[:cur_s_idx2] + [i for i in cur_s_line3.split("\n") if i] + sub_lines2[cur_s_idx2 + 1:]
cur_s_idx1 = sub_lines1.index(cur_s_line1)
sub_lines1 = sub_lines1[:cur_s_idx1] + [i for i in sub_lines2 if i] + sub_lines1[cur_s_idx1 + 1:]
cur_idx = lines.index(cur_line)
lines = lines[:cur_idx] + [i for i in sub_lines1 if i] + lines[cur_idx + 1:]
return lines
def load_knowledge(
knowledge_doc: str,
sentence_max_length: int
) -> List[Document]:
"""
Load and split knowledge documents from .txt or .csv formats.
knowledge_doc (`str`): Path to the knowledge document file.
sentence_max_length (`str`): Maximum length of a sentence in terms of tokens.
"""
text_splitter = ChineseTextSplitter(max_length=sentence_max_length)
if knowledge_doc.lower().endswith(".txt"):
loader = TextLoader(knowledge_doc, autodetect_encoding=True)
docs = loader.load_and_split(text_splitter)
elif knowledge_doc.lower().endswith(".csv"):
loader = CSVLoader(knowledge_doc)
docs = loader.load()
else:
loader = UnstructuredFileLoader(knowledge_doc, mode="elements")
docs = loader.load_and_split(text_splitter=text_splitter)
return docs
class LocalKnowledgeBase:
"""Local Knowledge Base.
Args:
embedding_path (`Optional[str]`): The path or name of the
pre-trained embedding model used for encoding text.
topk (`int`): The number of most similar knowledge
documents to retrieve for a given query.
knowledge_docs (`List`): Files containing the knowledge base,
supporting txt, csv formats.
sentence_max_length (`int`): Maximum length of a sentence
in terms of tokens for processing.
vector_store_path (`str or os.PathLike`): Path to save or load
pre-computed document vectors.
device (`Optional[str]`): The device (CPU or GPU) to
run the embedding model on.
"""
def __init__(
self,
embedding_path: str,
topk: int,
knowledge_docs: List[str],
sentence_max_length: int,
vector_store_path: str or os.PathLike = None,
device: Optional[str] = None,
) -> None:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_path,
model_kwargs={'device': device}
)
self.topk = topk
docs = sum([load_knowledge(knowledge_doc=cur_doc, sentence_max_length=sentence_max_length) for cur_doc in knowledge_docs], [])
if vector_store_path is None:
vector_store_path = os.path.join(
os.path.commonprefix(knowledge_docs).rsplit('/', 1)[0],
"vector_store")
if os.path.isdir(vector_store_path) and "index.faiss" in os.listdir(vector_store_path):
logger.info(f'Loading from existing vector store ({vector_store_path})...')
self.vector_store = RetrievedFAISS.load_local(vector_store_path, self.embeddings)
self.vector_store.add_documents(docs)
else:
logger.info(f'Constructing vector store ({vector_store_path})...')
self.vector_store = RetrievedFAISS.from_documents(docs, self.embeddings)
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
self.vector_store.save_local(vector_store_path)
logger.info(f'Vector store is ready.')
def retrieve_one(self, query: str, separator: str = ' ') -> str:
"""Retrieve the most relevant knowledge documents based on a query."""
related_docs_with_score = self.vector_store.similarity_search_with_score(
query,
k=self.topk)
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
return separator.join([cur_doc.page_content for cur_doc in related_docs_with_score])
@ICL_RETRIEVERS.register_module()
class KnowledgeRetriever(BaseRetriever):
"""Local Knowledge Retriever. The retriever returns related local knowledge for all queries.
Args:
dataset (`BaseDataset`): Any BaseDataset instances.
Attributes of ``reader``, ``train`` and ``test`` will be used.
knowledge_docs (`List`): Files containing the knowledge base,
supporting txt, csv formats.
retrieve_keys (`List`): Keys of the test sample that require
indexing of relevant knowledge.
embedding_path (`Optional[str]`): The path or name of the
pre-trained embedding model used for encoding text.
ice_eos_token (`Optional[str]`): The end of sentence token for
in-context example template when origin `PromptTemplate` is
provided. Defaults to ''.
"""
def __init__(self,
dataset,
knowledge_docs: List,
retrieve_keys: List,
embedding_path: Optional[str] = 'GanymedeNil/text2vec-large-chinese',
ice_eos_token: Optional[str] = '') -> None:
super().__init__(dataset, '', ice_eos_token, 0)
self.knowledge_ds = None
self.retrieve_keys = retrieve_keys
self.local_knowledge_base = LocalKnowledgeBase(
embedding_path=embedding_path,
knowledge_docs=knowledge_docs,
topk=3,
sentence_max_length=100)
def retrieve(self) -> List[List]:
"""Construct the knowledge base associated with test each sample and retrieve the sequential indices."""
logger.info('Retrieving data for test set...')
rtr_idx_list = [[i] for i in range(len(self.test_ds))]
self.knowledge_ds = [
{'knowledge': '; '.join([
self.local_knowledge_base.retrieve_one(cur_d[option_key])
for option_key in self.retrieve_keys
])} for cur_d in tqdm(self.test_ds)]
return rtr_idx_list
def generate_ice(self,
idx_list: List[int],
ice_template: Optional[PromptTemplate] = None) -> str:
"""Generate the knowledge-related example for one test example.
Args:
idx_list (`List[int]`): The index of knowledge-related examples for the
test example.
ice_template (`Optional[PromptTemplate]`): The template for
knowledge-related example. Defaults to None.
"""
assert self.knowledge_ds is not None, (
'knowledge_ds must be set first in retrieve method')
if ice_template is None:
assert len(
idx_list
) == 0, 'You have not specified ice_template while retrieving examples from train set! Please either specify ice_template or use `ZeroRetriever`.' # noqa
generated_ice_list = []
for idx in idx_list:
generated_ice_list.append(
ice_template.generate_ice_item(
self.knowledge_ds[idx],
''))
generated_ice = self.ice_separator.join(
generated_ice_list) + self.ice_eos_token
return generated_ice