# --- configs/datasets/FewCLUE_chid/FewCLUE_chid_knowledge_gen_0a29a2.py ---
"""FewCLUE CHID (idiom cloze) evaluated with a local-knowledge retriever.

Each sample provides a passage with a blank (``content``) and seven
candidate idioms (``A``..``G``); the model answers with the letter of the
idiom that fills the blank.  ``KnowledgeRetriever`` looks every candidate
up in a local idiom knowledge base and the retrieved text is injected into
the prompt through the ``ice_template`` slot.
"""
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import KnowledgeRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CHIDDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess

chid_knowledge_reader_cfg = dict(
    input_columns=["content", "A", "B", "C", "D", "E", "F", "G"],
    output_column="answer",
)

chid_knowledge_infer_cfg = dict(
    # The "ice" slot is repurposed to carry retrieved knowledge instead of
    # in-context examples; {knowledge} is filled by KnowledgeRetriever.
    ice_template=dict(
        type=PromptTemplate,
        template='以下是参考内容:{knowledge},结合上述参考内容,考虑接下来的问题:'
    ),
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt=
                    "{content}\n请选择______处所填的词\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nF. {F}\nG. {G}\n请从“A”,“B”,“C”,“D”,“E”,“F”,“G”中进行选择。答:",
                ),
            ]
        ),
        ice_token=''
    ),
    retriever=dict(
        type=KnowledgeRetriever,
        knowledge_docs=[
            './data/knowledge/chengyu-01-of-02.txt',
            './data/knowledge/chengyu-02-of-02.txt',
        ],
        # Retrieve knowledge for each candidate idiom, not for the passage.
        retrieve_keys=['A', 'B', 'C', 'D', 'E', 'F', 'G'],
        ice_eos_token='\n'
    ),
    inferencer=dict(type=GenInferencer),
)

chid_knowledge_eval_cfg = dict(
    evaluator=dict(type=AccEvaluator),
    pred_role="BOT",
    pred_postprocessor=dict(type=first_capital_postprocess),
)

# dev and test share every config except the split name and data file;
# build both entries from one template so they cannot drift apart.
chid_knowledge_datasets = [
    dict(
        abbr=f"chid-{split}",
        type=CHIDDataset_V2,
        path=f"./data/FewCLUE/chid/{data_file}",
        reader_cfg=chid_knowledge_reader_cfg,
        infer_cfg=chid_knowledge_infer_cfg,
        eval_cfg=chid_knowledge_eval_cfg,
    )
    for split, data_file in [("dev", "dev_few_all.json"),
                             ("test", "test_public.json")]
]

# --- configs/eval_demo_knowledge.py ---
from mmengine.config import read_base

with read_base():
    from .datasets.FewCLUE_chid.FewCLUE_chid_knowledge_gen_0a29a2 import \
        chid_knowledge_datasets
    from .models.hf_opt_125m import opt125m
    from .models.hf_opt_350m import opt350m

datasets = [*chid_knowledge_datasets]
models = [opt125m, opt350m]

# --- opencompass/openicl/icl_retriever/__init__.py ---
# New export: carries `# noqa` like every sibling import in this file and
# belongs in the file's alphabetical order (before icl_sliding_k_retriever),
# not appended after icl_zero_retriever.
from .icl_knowledge_retriever import KnowledgeRetriever  # noqa
"""Local Knowledge Retriever: FAISS-backed knowledge store plus a Chinese
sentence splitter used to chunk the knowledge documents."""

from typing import List, Optional, Callable, Dict, Any

from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.registry import ICL_RETRIEVERS
from opencompass.utils import get_logger

logger = get_logger(__name__)

import os
import re
import numpy as np
import torch
from copy import deepcopy
from tqdm import tqdm
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader, CSVLoader, UnstructuredFileLoader
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from langchain.docstore.base import Docstore
# NOTE: the original also imported Document from langchain.schema, which this
# import immediately shadowed; the redundant import was dropped.
from langchain.docstore.document import Document

# Hits whose L2 distance exceeds this are discarded (a value <= 0 disables
# the filter).
VECTOR_SEARCH_SCORE_THRESHOLD = 500
# Character budget when merging a hit with its neighbouring chunks.
CHUNK_SIZE = 50


class RetrievedFAISS(FAISS, VectorStore):
    """FAISS vector store with score-threshold filtering and optional
    merging of a hit with adjacent chunks from the same source file.

    ``chunk_conent`` (upstream's misspelling of "content", kept for
    compatibility) gates the merge path; it is False here, so by default
    every hit is returned as a single chunk.
    """

    def __init__(self,
                 embedding_function: Callable,
                 index: Any,
                 docstore: Docstore,
                 index_to_docstore_id: Dict[int, str],
                 normalize_L2: bool = False,
                 ):
        super().__init__(embedding_function=embedding_function,
                         index=index,
                         docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id,
                         normalize_L2=normalize_L2)
        self.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
        self.chunk_size = CHUNK_SIZE
        self.chunk_conent = False  # (sic) disables context expansion by default

    def seperate_list(self, lines: List[int]) -> List[List[int]]:
        """Split a sorted index list into runs of consecutive ids that share
        one source document.

        (Method name keeps the upstream "seperate" typo: renaming it would
        break any external caller of this public attribute.)
        """
        results = []
        cur_line = [lines[0]]
        docs_source = self.index_to_docstore_source(lines[0])
        for i in range(1, len(lines)):
            # A run continues only while ids are consecutive AND come from
            # the same source file.
            if lines[i - 1] + 1 == lines[i] and \
                    self.index_to_docstore_source(lines[i]) == docs_source:
                cur_line.append(lines[i])
            else:
                results.append(cur_line)
                cur_line = [lines[i]]
                docs_source = self.index_to_docstore_source(lines[i])
        results.append(cur_line)
        return results

    def similarity_search_with_score_by_vector(
            self, embedding: List[float], k: int = 4) -> List[Document]:
        """Search the index for ``k`` nearest chunks and return them as
        Documents with an integer ``score`` in their metadata.

        When ``chunk_conent`` is enabled, each hit is grown with adjacent
        chunks from the same source (up to ``chunk_size`` characters) and
        overlapping runs are merged via :meth:`seperate_list`.
        """
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k)
        docs = []
        id_set = set()
        store_len = len(self.index_to_docstore_id)
        rearrange_id_list = False
        for j, i in enumerate(indices[0]):
            # -1 means FAISS returned fewer than k hits; also drop hits
            # whose distance exceeds the (positive) threshold.
            if i == -1 or 0 < self.score_threshold < scores[0][j]:
                continue
            if i in self.index_to_docstore_id:
                _id = self.index_to_docstore_id[i]
            else:
                continue
            doc = self.docstore.search(_id)
            if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
                # Expansion disabled (globally or per-document): return the
                # chunk as-is with its score.
                if not isinstance(doc, Document):
                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
                doc.metadata["score"] = int(scores[0][j])
                docs.append(doc)
                continue

            id_set.add(i)
            docs_len = len(doc.page_content)
            # Grow the hit symmetrically (or one-sided, per metadata) until
            # the chunk budget or a source boundary is reached.
            # NOTE: the original loop variable was `k`, which shadowed the
            # `k` parameter; renamed to `offset` (the parameter is only used
            # before this loop, so behaviour is unchanged).
            for offset in range(1, max(i, store_len - i)):
                break_flag = False
                if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
                    expand_range = [i + offset]
                elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
                    expand_range = [i - offset]
                else:
                    expand_range = [i + offset, i - offset]
                for neighbor in expand_range:  # was `l` (ambiguous name)
                    if neighbor not in id_set and 0 <= neighbor < len(self.index_to_docstore_id):
                        _id0 = self.index_to_docstore_id[neighbor]
                        doc0 = self.docstore.search(_id0)
                        if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != \
                                doc.metadata["source"]:
                            break_flag = True
                            break
                        elif doc0.metadata["source"] == doc.metadata["source"]:
                            docs_len += len(doc0.page_content)
                            id_set.add(neighbor)
                            rearrange_id_list = True
                if break_flag:
                    break
        if (not self.chunk_conent) or (not rearrange_id_list):
            return docs
        if len(id_set) == 0 and self.score_threshold > 0:
            return []
        id_list = sorted(list(id_set))
        id_lists = self.seperate_list(id_list)
        # Stitch each consecutive run back into one Document; the run's
        # score is the best (smallest) score among its original hits.
        for id_seq in id_lists:
            for idx in id_seq:  # was `id` (shadowed the builtin)
                if idx == id_seq[0]:
                    _id = self.index_to_docstore_id[idx]
                    doc = deepcopy(self.docstore.search(_id))
                else:
                    _id0 = self.index_to_docstore_id[idx]
                    doc0 = self.docstore.search(_id0)
                    doc.page_content += " " + doc0.page_content
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            doc_score = min([scores[0][idx] for idx in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
            doc.metadata["score"] = int(doc_score)
            docs.append(doc)
        return docs

    def list_docs(self):
        """Return the source path of every stored chunk (with duplicates)."""
        return list(v.metadata["source"] for v in self.docstore._dict.values())

    def index_to_docstore_source(self, i: int):
        """Map a FAISS index position to its chunk's source file path."""
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        return doc.metadata["source"]


class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter for Chinese text.

    Splits on Chinese/ASCII sentence punctuation, then repeatedly re-splits
    any piece longer than ``max_length`` on progressively weaker boundaries
    (commas, runs of whitespace, single spaces).
    """

    def __init__(
            self,
            max_length: int,
            **kwargs
    ):
        super().__init__(**kwargs)
        # Maximum length (in characters) of an emitted piece.
        self.max_length = max_length

    def split_text(
            self,
            text: str,
            is_pdf: bool = False,
    ) -> List[str]:
        """Split ``text`` into sentence-sized lines.

        All regexes below were converted to raw strings: the original
        ``'\\s'`` relied on an invalid escape sequence, which is a
        SyntaxWarning on modern Python; the pattern values are unchanged.
        """
        if is_pdf:
            # PDFs: collapse blank-line runs and stray whitespace first.
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        # Insert a newline after sentence-ending punctuation (optionally
        # followed by closing quotes), after "......" ellipses, etc.
        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        text = text.rstrip()
        lines = [i for i in text.split("\n") if i]
        # NOTE(review): the loop below rebinds `lines`/`sub_lines*` while
        # iterating the original lists and relies on `.index()` to locate
        # the current element; duplicate lines could make `.index()` find
        # the wrong occurrence. Kept as-is — verify against upstream before
        # changing.
        for cur_line in lines:
            if len(cur_line) > self.max_length:
                # Too long: re-split on commas/periods with closing quotes.
                sub_lines1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', cur_line).split("\n")
                for cur_s_line1 in sub_lines1:
                    if len(cur_s_line1) > self.max_length:
                        # Still too long: split on newline/space runs.
                        sub_lines2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', cur_s_line1).split("\n")
                        for cur_s_line2 in sub_lines2:
                            if len(cur_s_line2) > self.max_length:
                                # Last resort: split on single spaces.
                                cur_s_line3 = re.sub(r'( ["’”」』]{0,2})([^ ])', r'\1\n\2', cur_s_line2)
                                cur_s_idx2 = sub_lines2.index(cur_s_line2)
                                sub_lines2 = sub_lines2[:cur_s_idx2] + [i for i in cur_s_line3.split("\n") if i] + sub_lines2[cur_s_idx2 + 1:]
                        cur_s_idx1 = sub_lines1.index(cur_s_line1)
                        sub_lines1 = sub_lines1[:cur_s_idx1] + [i for i in sub_lines2 if i] + sub_lines1[cur_s_idx1 + 1:]

                cur_idx = lines.index(cur_line)
                lines = lines[:cur_idx] + [i for i in sub_lines1 if i] + lines[cur_idx + 1:]
        return lines
def load_knowledge(
        knowledge_doc: str,
        sentence_max_length: int
) -> List[Document]:
    """Load one knowledge file and split it into sentence-sized Documents.

    Args:
        knowledge_doc (`str`): Path to the knowledge document file
            (.txt, .csv, or anything the unstructured loader understands).
        sentence_max_length (`int`): Maximum sentence length in characters.
            (Fixed: the original docstring declared this as ``str``.)
    """
    text_splitter = ChineseTextSplitter(max_length=sentence_max_length)
    if knowledge_doc.lower().endswith(".txt"):
        loader = TextLoader(knowledge_doc, autodetect_encoding=True)
        docs = loader.load_and_split(text_splitter)
    elif knowledge_doc.lower().endswith(".csv"):
        # CSV rows are loaded as-is, one Document per row, no re-splitting.
        loader = CSVLoader(knowledge_doc)
        docs = loader.load()
    else:
        loader = UnstructuredFileLoader(knowledge_doc, mode="elements")
        docs = loader.load_and_split(text_splitter=text_splitter)
    return docs


class LocalKnowledgeBase:
    """Local Knowledge Base.

    Builds (or loads) a FAISS vector store over the given knowledge
    documents and answers top-k similarity queries against it.

    Args:
        embedding_path (`Optional[str]`): The path or name of the
            pre-trained embedding model used for encoding text.
        topk (`int`): The number of most similar knowledge
            documents to retrieve for a given query.
        knowledge_docs (`List`): Files containing the knowledge base,
            supporting txt, csv formats.
        sentence_max_length (`int`): Maximum length of a sentence
            in terms of tokens for processing.
        vector_store_path (`Optional[str]`): Path to save or load
            pre-computed document vectors. Defaults to a ``vector_store``
            directory next to the knowledge documents.
        device (`Optional[str]`): The device (CPU or GPU) to
            run the embedding model on.
    """

    def __init__(
            self,
            embedding_path: str,
            topk: int,
            knowledge_docs: List[str],
            sentence_max_length: int,
            # Fixed annotation: the original `str or os.PathLike` evaluated
            # to plain `str` at runtime and contradicted the None default.
            vector_store_path: Optional[str] = None,
            device: Optional[str] = None,
    ) -> None:
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_path,
            model_kwargs={'device': device}
        )
        self.topk = topk

        # Load and chunk every knowledge file into one flat Document list.
        docs = sum([load_knowledge(knowledge_doc=cur_doc, sentence_max_length=sentence_max_length) for cur_doc in knowledge_docs], [])

        if vector_store_path is None:
            # Default store location: a "vector_store" dir under the common
            # parent of the knowledge files.  NOTE(review): commonprefix is
            # character-based and '/' is hardcoded — verify on Windows paths.
            vector_store_path = os.path.join(
                os.path.commonprefix(knowledge_docs).rsplit('/', 1)[0],
                "vector_store")

        if os.path.isdir(vector_store_path) and "index.faiss" in os.listdir(vector_store_path):
            logger.info(f'Loading from existing vector store ({vector_store_path})...')
            self.vector_store = RetrievedFAISS.load_local(vector_store_path, self.embeddings)
            # NOTE(review): this re-adds every document on each run, so a
            # pre-existing store accumulates duplicates — confirm intended.
            self.vector_store.add_documents(docs)
        else:
            logger.info(f'Constructing vector store ({vector_store_path})...')
            self.vector_store = RetrievedFAISS.from_documents(docs, self.embeddings)

        # Free GPU memory held by the embedding pass, then persist the store
        # (both branches, so a loaded-and-updated store is saved too).
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        self.vector_store.save_local(vector_store_path)

        # Fixed: was an f-string with no placeholders.
        logger.info('Vector store is ready.')

    def retrieve_one(self, query: str, separator: str = ' ') -> str:
        """Retrieve the most relevant knowledge documents based on a query."""
        related_docs_with_score = self.vector_store.similarity_search_with_score(
            query,
            k=self.topk)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        return separator.join([cur_doc.page_content for cur_doc in related_docs_with_score])
@ICL_RETRIEVERS.register_module()
class KnowledgeRetriever(BaseRetriever):
    """Local Knowledge Retriever.

    The retriever returns related local knowledge for all queries.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        knowledge_docs (`List`): Files containing the knowledge base,
            supporting txt, csv formats.
        retrieve_keys (`List`): Keys of the test sample that require
            indexing of relevant knowledge.
        embedding_path (`Optional[str]`): The path or name of the
            pre-trained embedding model used for encoding text.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to ''.
        topk (`int`): Number of knowledge chunks retrieved per key.
            Defaults to 3 (previously hard-coded).
        sentence_max_length (`int`): Maximum chunk length in characters
            when building the knowledge base. Defaults to 100
            (previously hard-coded).
    """

    def __init__(self,
                 dataset,
                 knowledge_docs: List,
                 retrieve_keys: List,
                 embedding_path: Optional[str] = 'GanymedeNil/text2vec-large-chinese',
                 ice_eos_token: Optional[str] = '',
                 topk: int = 3,
                 sentence_max_length: int = 100) -> None:
        super().__init__(dataset, '', ice_eos_token, 0)
        # Populated lazily by retrieve(); generate_ice asserts on it.
        self.knowledge_ds = None
        self.retrieve_keys = retrieve_keys

        self.local_knowledge_base = LocalKnowledgeBase(
            embedding_path=embedding_path,
            knowledge_docs=knowledge_docs,
            topk=topk,
            sentence_max_length=sentence_max_length)

    def retrieve(self) -> List[List]:
        """Build the per-sample knowledge and return sequential indices.

        Each test sample i maps to itself ([i]); the retrieved knowledge
        for sample i (joined over ``retrieve_keys``) is stored in
        ``self.knowledge_ds[i]`` for generate_ice to consume.
        """
        logger.info('Retrieving data for test set...')
        rtr_idx_list = [[i] for i in range(len(self.test_ds))]
        self.knowledge_ds = [
            {'knowledge': '; '.join([
                self.local_knowledge_base.retrieve_one(cur_d[option_key])
                for option_key in self.retrieve_keys
            ])} for cur_d in tqdm(self.test_ds)]
        return rtr_idx_list

    def generate_ice(self,
                     idx_list: List[int],
                     ice_template: Optional[PromptTemplate] = None) -> str:
        """Generate the knowledge-related example for one test example.

        Args:
            idx_list (`List[int]`): The index of knowledge-related examples
                for the test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                knowledge-related example. Defaults to None.
        """
        assert self.knowledge_ds is not None, (
            'knowledge_ds must be set first in retrieve method')

        if ice_template is None:
            assert len(
                idx_list
            ) == 0, 'You have not specified ice_template while retrieving examples from train set! Please either specify ice_template or use `ZeroRetriever`.'  # noqa

        generated_ice_list = []
        for idx in idx_list:
            generated_ice_list.append(
                ice_template.generate_ice_item(
                    self.knowledge_ds[idx],
                    ''))

        generated_ice = self.ice_separator.join(
            generated_ice_list) + self.ice_eos_token
        return generated_ice