Yi-Kai Zhang 2025-05-28 11:38:58 +02:00 committed by GitHub
commit b631c1a5ea
4 changed files with 421 additions and 0 deletions


@@ -0,0 +1,66 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import KnowledgeRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CHIDDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess
chid_knowledge_reader_cfg = dict(
input_columns=["content", "A", "B", "C", "D", "E", "F", "G"],
output_column="answer",
)
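# The KnowledgeRetriever below fills the `</E>` ice_token with an in-context block
# rendered from the ice_template, so each prompt is prefixed with the retrieved
# idiom knowledge before the multiple-choice question.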
chid_knowledge_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template='以下是参考内容:{knowledge},结合上述参考内容,考虑接下来的问题:'
),
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role="HUMAN",
prompt=
"</E>{content}\n请选择______处所填的词\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nF. {F}\nG. {G}\n请从“A”“B”“C”“D”“E”“F”“G”中进行选择。答",
),
]
),
ice_token='</E>'
),
retriever=dict(
type=KnowledgeRetriever,
knowledge_docs=[
'./data/knowledge/chengyu-01-of-02.txt',
'./data/knowledge/chengyu-02-of-02.txt',
],
retrieve_keys=['A', 'B', 'C', 'D', 'E', 'F', 'G'],
ice_eos_token='\n'
),
inferencer=dict(type=GenInferencer),
)
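# Accuracy is computed on the first capital letter (A-G) that
# `first_capital_postprocess` extracts from each model output.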
chid_knowledge_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),
pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess),
)
chid_knowledge_datasets = [
dict(
abbr="chid-dev",
type=CHIDDataset_V2,
path="./data/FewCLUE/chid/dev_few_all.json",
reader_cfg=chid_knowledge_reader_cfg,
infer_cfg=chid_knowledge_infer_cfg,
eval_cfg=chid_knowledge_eval_cfg,
),
dict(
abbr="chid-test",
type=CHIDDataset_V2,
path="./data/FewCLUE/chid/test_public.json",
reader_cfg=chid_knowledge_reader_cfg,
infer_cfg=chid_knowledge_infer_cfg,
eval_cfg=chid_knowledge_eval_cfg,
),
]


@@ -0,0 +1,9 @@
from mmengine.config import read_base
with read_base():
from .datasets.FewCLUE_chid.FewCLUE_chid_knowledge_gen_0a29a2 import chid_knowledge_datasets
from .models.hf_opt_125m import opt125m
from .models.hf_opt_350m import opt350m
datasets = [*chid_knowledge_datasets]
models = [opt125m, opt350m]
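# To launch this evaluation, point the OpenCompass entry script at this config,
# e.g. (the config filename below is hypothetical):
#   python run.py configs/eval_chid_knowledge.py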


@@ -8,3 +8,4 @@ from .icl_sliding_k_retriever import SlidingWindowRetriever # noqa
from .icl_topk_retriever import TopkRetriever # noqa
from .icl_votek_retriever import VotekRetriever # noqa
from .icl_zero_retriever import ZeroRetriever # noqa
from .icl_knowledge_retriever import KnowledgeRetriever  # noqa


@@ -0,0 +1,345 @@
"""Local Knowledge Retriever."""
from typing import List, Optional, Callable, Dict, Any
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.registry import ICL_RETRIEVERS
from opencompass.utils import get_logger
logger = get_logger(__name__)
import os
import re
import numpy as np
import torch
from copy import deepcopy
from tqdm import tqdm
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader, CSVLoader, UnstructuredFileLoader
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
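# Hits whose FAISS distance score exceeds this threshold are discarded (a value
# of 0 disables the filter); CHUNK_SIZE caps the combined character length when
# adjacent chunks from the same source file are merged.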
VECTOR_SEARCH_SCORE_THRESHOLD = 500
CHUNK_SIZE = 50
class RetrievedFAISS(FAISS, VectorStore):
def __init__(self,
embedding_function: Callable,
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
normalize_L2: bool = False,
):
super().__init__(embedding_function=embedding_function,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2)
self.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
self.chunk_size = CHUNK_SIZE
        self.chunk_content = False

    def separate_list(self, lines: List[int]) -> List[List[int]]:
results = []
cur_line = [lines[0]]
docs_source = self.index_to_docstore_source(lines[0])
for i in range(1, len(lines)):
if lines[i - 1] + 1 == lines[i] and self.index_to_docstore_source(lines[i]) == docs_source:
cur_line.append(lines[i])
else:
results.append(cur_line)
cur_line = [lines[i]]
docs_source = self.index_to_docstore_source(lines[i])
results.append(cur_line)
return results
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Document]:
faiss = dependable_faiss_import()
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
docs = []
id_set = set()
store_len = len(self.index_to_docstore_id)
rearrange_id_list = False
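        # For each of the top-k hits: when context expansion is disabled, just attach
        # the score and keep the chunk; otherwise grow the hit with neighbouring
        # chunks from the same source until `chunk_size` characters are reached, and
        # merge runs of consecutive indices into single documents further below.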
for j, i in enumerate(indices[0]):
if i == -1 or 0 < self.score_threshold < scores[0][j]:
continue
if i in self.index_to_docstore_id:
_id = self.index_to_docstore_id[i]
else:
continue
doc = self.docstore.search(_id)
            if (not self.chunk_content) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j])
docs.append(doc)
continue
id_set.add(i)
docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)):
break_flag = False
if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
expand_range = [i + k]
elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
expand_range = [i - k]
else:
expand_range = [i + k, i - k]
for l in expand_range:
if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != \
doc.metadata["source"]:
break_flag = True
break
elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content)
id_set.add(l)
rearrange_id_list = True
if break_flag:
break
        if (not self.chunk_content) or (not rearrange_id_list):
return docs
if len(id_set) == 0 and self.score_threshold > 0:
return []
id_list = sorted(list(id_set))
        id_lists = self.separate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = deepcopy(self.docstore.search(_id))
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0)
doc.page_content += " " + doc0.page_content
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
doc.metadata["score"] = int(doc_score)
docs.append(doc)
return docs
def list_docs(self):
return list(v.metadata["source"] for v in self.docstore._dict.values())
    def index_to_docstore_source(self, i: int) -> str:
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
return doc.metadata["source"]
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(
self,
max_length: int,
**kwargs
):
super().__init__(**kwargs)
self.max_length = max_length
def split_text(
self,
text: str,
is_pdf: bool = False,
) -> List[str]:
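        """Split Chinese text into chunks of at most ``max_length`` characters.

        The text is first cut on sentence-ending punctuation and ellipses; any
        line that is still too long is re-split on commas, runs of whitespace
        and, finally, single spaces.
        """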
if is_pdf:
text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r'([;.!?。!?\?])([^”’])', r"\1\n\2", text)
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'(…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis (……)
text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r'\1\n\2', text)
text = text.rstrip()
lines = [i for i in text.split("\n") if i]
for cur_line in lines:
if len(cur_line) > self.max_length:
sub_lines1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r'\1\n\2', cur_line).split("\n")
for cur_s_line1 in sub_lines1:
if len(cur_s_line1) > self.max_length:
sub_lines2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', cur_s_line1).split("\n")
for cur_s_line2 in sub_lines2:
if len(cur_s_line2) > self.max_length:
cur_s_line3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', cur_s_line2)
cur_s_idx2 = sub_lines2.index(cur_s_line2)
sub_lines2 = sub_lines2[:cur_s_idx2] + [i for i in cur_s_line3.split("\n") if i] + sub_lines2[cur_s_idx2 + 1:]
cur_s_idx1 = sub_lines1.index(cur_s_line1)
sub_lines1 = sub_lines1[:cur_s_idx1] + [i for i in sub_lines2 if i] + sub_lines1[cur_s_idx1 + 1:]
cur_idx = lines.index(cur_line)
lines = lines[:cur_idx] + [i for i in sub_lines1 if i] + lines[cur_idx + 1:]
return lines
def load_knowledge(
knowledge_doc: str,
sentence_max_length: int
) -> List[Document]:
"""
Load and split knowledge documents from .txt or .csv formats.
knowledge_doc (`str`): Path to the knowledge document file.
sentence_max_length (`str`): Maximum length of a sentence in terms of tokens.
"""
text_splitter = ChineseTextSplitter(max_length=sentence_max_length)
if knowledge_doc.lower().endswith(".txt"):
loader = TextLoader(knowledge_doc, autodetect_encoding=True)
docs = loader.load_and_split(text_splitter)
elif knowledge_doc.lower().endswith(".csv"):
loader = CSVLoader(knowledge_doc)
docs = loader.load()
else:
loader = UnstructuredFileLoader(knowledge_doc, mode="elements")
docs = loader.load_and_split(text_splitter=text_splitter)
return docs
class LocalKnowledgeBase:
"""Local Knowledge Base.
Args:
embedding_path (`Optional[str]`): The path or name of the
pre-trained embedding model used for encoding text.
topk (`int`): The number of most similar knowledge
documents to retrieve for a given query.
knowledge_docs (`List`): Files containing the knowledge base,
supporting txt, csv formats.
        sentence_max_length (`int`): Maximum length of a chunk in
            characters when splitting the knowledge documents.
vector_store_path (`str or os.PathLike`): Path to save or load
pre-computed document vectors.
device (`Optional[str]`): The device (CPU or GPU) to
run the embedding model on.
"""
def __init__(
self,
embedding_path: str,
topk: int,
knowledge_docs: List[str],
sentence_max_length: int,
        vector_store_path: Optional[Union[str, os.PathLike]] = None,
device: Optional[str] = None,
) -> None:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_path,
model_kwargs={'device': device}
)
self.topk = topk
docs = sum([load_knowledge(knowledge_doc=cur_doc, sentence_max_length=sentence_max_length) for cur_doc in knowledge_docs], [])
if vector_store_path is None:
vector_store_path = os.path.join(
os.path.commonprefix(knowledge_docs).rsplit('/', 1)[0],
"vector_store")
if os.path.isdir(vector_store_path) and "index.faiss" in os.listdir(vector_store_path):
logger.info(f'Loading from existing vector store ({vector_store_path})...')
self.vector_store = RetrievedFAISS.load_local(vector_store_path, self.embeddings)
self.vector_store.add_documents(docs)
else:
logger.info(f'Constructing vector store ({vector_store_path})...')
self.vector_store = RetrievedFAISS.from_documents(docs, self.embeddings)
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
self.vector_store.save_local(vector_store_path)
        logger.info('Vector store is ready.')
def retrieve_one(self, query: str, separator: str = ' ') -> str:
"""Retrieve the most relevant knowledge documents based on a query."""
related_docs_with_score = self.vector_store.similarity_search_with_score(
query,
k=self.topk)
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
return separator.join([cur_doc.page_content for cur_doc in related_docs_with_score])
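# Illustrative standalone usage (values taken from the dataset config in this commit):
#   kb = LocalKnowledgeBase(
#       embedding_path='GanymedeNil/text2vec-large-chinese',
#       topk=3,
#       knowledge_docs=['./data/knowledge/chengyu-01-of-02.txt',
#                       './data/knowledge/chengyu-02-of-02.txt'],
#       sentence_max_length=100)
#   kb.retrieve_one('一叶知秋')  # top-3 related chunks joined by ' '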
@ICL_RETRIEVERS.register_module()
class KnowledgeRetriever(BaseRetriever):
"""Local Knowledge Retriever. The retriever returns related local knowledge for all queries.
Args:
dataset (`BaseDataset`): Any BaseDataset instances.
Attributes of ``reader``, ``train`` and ``test`` will be used.
knowledge_docs (`List`): Files containing the knowledge base,
supporting txt, csv formats.
        retrieve_keys (`List`): Keys of each test sample whose values are
            used as queries against the knowledge base.
embedding_path (`Optional[str]`): The path or name of the
pre-trained embedding model used for encoding text.
ice_eos_token (`Optional[str]`): The end of sentence token for
in-context example template when origin `PromptTemplate` is
provided. Defaults to ''.
"""
def __init__(self,
dataset,
knowledge_docs: List,
retrieve_keys: List,
embedding_path: Optional[str] = 'GanymedeNil/text2vec-large-chinese',
ice_eos_token: Optional[str] = '') -> None:
super().__init__(dataset, '', ice_eos_token, 0)
self.knowledge_ds = None
self.retrieve_keys = retrieve_keys
self.local_knowledge_base = LocalKnowledgeBase(
embedding_path=embedding_path,
knowledge_docs=knowledge_docs,
topk=3,
sentence_max_length=100)
def retrieve(self) -> List[List]:
"""Construct the knowledge base associated with test each sample and retrieve the sequential indices."""
logger.info('Retrieving data for test set...')
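        # Each test sample simply "retrieves" itself (identity indices); the actual
        # knowledge strings are cached in ``self.knowledge_ds`` and rendered later by
        # ``generate_ice`` through the ice_template's ``{knowledge}`` placeholder.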
rtr_idx_list = [[i] for i in range(len(self.test_ds))]
self.knowledge_ds = [
{'knowledge': '; '.join([
self.local_knowledge_base.retrieve_one(cur_d[option_key])
for option_key in self.retrieve_keys
])} for cur_d in tqdm(self.test_ds)]
return rtr_idx_list
def generate_ice(self,
idx_list: List[int],
ice_template: Optional[PromptTemplate] = None) -> str:
"""Generate the knowledge-related example for one test example.
Args:
            idx_list (`List[int]`): Indices of the knowledge-related
                examples for the test example.
ice_template (`Optional[PromptTemplate]`): The template for
knowledge-related example. Defaults to None.
"""
        assert self.knowledge_ds is not None, (
            '`retrieve()` must be called first to build `knowledge_ds`')
if ice_template is None:
            assert len(
                idx_list
            ) == 0, 'You have not specified ice_template while retrieving knowledge for the test set! Please either specify ice_template or use `ZeroRetriever`.'  # noqa
generated_ice_list = []
for idx in idx_list:
generated_ice_list.append(
ice_template.generate_ice_item(
self.knowledge_ds[idx],
''))
generated_ice = self.ice_separator.join(
generated_ice_list) + self.ice_eos_token
return generated_ice