Merge e754804c68 into 408f5caff4
This commit is contained in: commit b631c1a5ea
66  configs/datasets/FewCLUE_chid/FewCLUE_chid_knowledge_gen_0a29a2.py  Normal file
@@ -0,0 +1,66 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import KnowledgeRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CHIDDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess

chid_knowledge_reader_cfg = dict(
    input_columns=["content", "A", "B", "C", "D", "E", "F", "G"],
    output_column="answer",
)

chid_knowledge_infer_cfg = dict(
    ice_template=dict(
        type=PromptTemplate,
        template='以下是参考内容:{knowledge},结合上述参考内容,考虑接下来的问题:'
    ),
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt=
                    "</E>{content}\n请选择______处所填的词\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nF. {F}\nG. {G}\n请从“A”,“B”,“C”,“D”,“E”,“F”,“G”中进行选择。答:",
                ),
            ]
        ),
        ice_token='</E>'
    ),
    retriever=dict(
        type=KnowledgeRetriever,
        knowledge_docs=[
            './data/knowledge/chengyu-01-of-02.txt',
            './data/knowledge/chengyu-02-of-02.txt',
        ],
        retrieve_keys=['A', 'B', 'C', 'D', 'E', 'F', 'G'],
        ice_eos_token='\n'
    ),
    inferencer=dict(type=GenInferencer),
)

chid_knowledge_eval_cfg = dict(
    evaluator=dict(type=AccEvaluator),
    pred_role="BOT",
    pred_postprocessor=dict(type=first_capital_postprocess),
)

chid_knowledge_datasets = [
    dict(
        abbr="chid-dev",
        type=CHIDDataset_V2,
        path="./data/FewCLUE/chid/dev_few_all.json",
        reader_cfg=chid_knowledge_reader_cfg,
        infer_cfg=chid_knowledge_infer_cfg,
        eval_cfg=chid_knowledge_eval_cfg,
    ),
    dict(
        abbr="chid-test",
        type=CHIDDataset_V2,
        path="./data/FewCLUE/chid/test_public.json",
        reader_cfg=chid_knowledge_reader_cfg,
        infer_cfg=chid_knowledge_infer_cfg,
        eval_cfg=chid_knowledge_eval_cfg,
    ),
]
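For context, the ice_template above renders, per test sample, roughly "The following is reference material: {knowledge}. With this reference in mind, consider the next question:", and the HUMAN prompt asks the model to pick the idiom that fills the blank and to answer with a single letter from A to G. A minimal sketch (illustration only, not part of the diff) of how the retrieved knowledge is stitched in, where retrieved_knowledge stands for a hypothetical chunk returned by the retriever:

# Illustration only: the framework substitutes the rendered ice text for the
# '</E>' ice_token at the head of the prompt; shown here with plain string ops.
retrieved_knowledge = '一诺千金:形容说话极有信用。'  # hypothetical retrieved chunk
ice = ('以下是参考内容:{knowledge},结合上述参考内容,考虑接下来的问题:'
       .format(knowledge=retrieved_knowledge) + '\n')  # ice_eos_token='\n'
question = '</E>{content}\n请选择______处所填的词 ...'  # HUMAN prompt above, truncated
final_prompt = question.replace('</E>', ice)  # the ice text replaces the '</E>' token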
9  configs/eval_demo_knowledge.py  Normal file
@@ -0,0 +1,9 @@
from mmengine.config import read_base

with read_base():
    from .datasets.FewCLUE_chid.FewCLUE_chid_knowledge_gen_0a29a2 import chid_knowledge_datasets
    from .models.hf_opt_125m import opt125m
    from .models.hf_opt_350m import opt350m

datasets = [*chid_knowledge_datasets]
models = [opt125m, opt350m]
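As with the other demo configs, this file would presumably be run through the standard OpenCompass entry point, for example (the work directory name is an arbitrary choice):

python run.py configs/eval_demo_knowledge.py -w outputs/knowledge_demo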
opencompass/openicl/icl_retriever/__init__.py
@@ -8,3 +8,4 @@ from .icl_sliding_k_retriever import SlidingWindowRetriever  # noqa
from .icl_topk_retriever import TopkRetriever  # noqa
from .icl_votek_retriever import VotekRetriever  # noqa
from .icl_zero_retriever import ZeroRetriever  # noqa
from .icl_knowledge_retriever import KnowledgeRetriever  # noqa
345  opencompass/openicl/icl_retriever/icl_knowledge_retriever.py  Normal file
@@ -0,0 +1,345 @@
"""Local Knowledge Retriever."""

import os
import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.document_loaders import TextLoader, CSVLoader, UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from tqdm import tqdm

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.registry import ICL_RETRIEVERS
from opencompass.utils import get_logger

logger = get_logger(__name__)

# Hits whose L2 distance exceeds this threshold are discarded (0 disables the filter).
VECTOR_SEARCH_SCORE_THRESHOLD = 500
# Maximum character length of a merged context chunk.
CHUNK_SIZE = 50


class RetrievedFAISS(FAISS, VectorStore):
    """FAISS vector store with score filtering and optional context expansion."""

    def __init__(self,
                 embedding_function: Callable,
                 index: Any,
                 docstore: Docstore,
                 index_to_docstore_id: Dict[int, str],
                 normalize_L2: bool = False):
        super().__init__(embedding_function=embedding_function,
                         index=index,
                         docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id,
                         normalize_L2=normalize_L2)
        # Hits scoring above this L2 distance are dropped (0 disables filtering).
        self.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
        # Maximum character length of an expanded context chunk.
        self.chunk_size = CHUNK_SIZE
        # Whether to expand each hit with neighbouring chunks from the same source.
        self.chunk_content = False

    def separate_list(self, lines: List[int]) -> List[List[int]]:
        """Group consecutive indices that belong to the same source document."""
        results = []
        cur_line = [lines[0]]
        docs_source = self.index_to_docstore_source(lines[0])
        for i in range(1, len(lines)):
            if lines[i - 1] + 1 == lines[i] and self.index_to_docstore_source(lines[i]) == docs_source:
                cur_line.append(lines[i])
            else:
                results.append(cur_line)
                cur_line = [lines[i]]
                docs_source = self.index_to_docstore_source(lines[i])
        results.append(cur_line)
        return results

    def similarity_search_with_score_by_vector(
            self, embedding: List[float], k: int = 4) -> List[Document]:
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k)
        docs = []
        id_set = set()
        store_len = len(self.index_to_docstore_id)
        rearrange_id_list = False
        for j, i in enumerate(indices[0]):
            if i == -1 or 0 < self.score_threshold < scores[0][j]:
                # Skip padding results and hits beyond the score threshold.
                continue
            if i in self.index_to_docstore_id:
                _id = self.index_to_docstore_id[i]
            else:
                continue
            doc = self.docstore.search(_id)
            if (not self.chunk_content) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
                # No context expansion requested: return the chunk as-is.
                if not isinstance(doc, Document):
                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
                doc.metadata["score"] = int(scores[0][j])
                docs.append(doc)
                continue

            # Context expansion: collect neighbouring chunk ids from the same source
            # until the merged text would exceed ``chunk_size`` characters.
            id_set.add(i)
            docs_len = len(doc.page_content)
            for step in range(1, max(i, store_len - i)):
                break_flag = False
                if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
                    expand_range = [i + step]
                elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
                    expand_range = [i - step]
                else:
                    expand_range = [i + step, i - step]
                for neighbor in expand_range:
                    if neighbor not in id_set and 0 <= neighbor < len(self.index_to_docstore_id):
                        _id0 = self.index_to_docstore_id[neighbor]
                        doc0 = self.docstore.search(_id0)
                        if docs_len + len(doc0.page_content) > self.chunk_size or \
                                doc0.metadata["source"] != doc.metadata["source"]:
                            break_flag = True
                            break
                        elif doc0.metadata["source"] == doc.metadata["source"]:
                            docs_len += len(doc0.page_content)
                            id_set.add(neighbor)
                            rearrange_id_list = True
                if break_flag:
                    break
        if (not self.chunk_content) or (not rearrange_id_list):
            return docs
        if len(id_set) == 0 and self.score_threshold > 0:
            return []
        # Merge consecutive chunks from the same source into single documents.
        id_list = sorted(list(id_set))
        id_lists = self.separate_list(id_list)
        for id_seq in id_lists:
            for id in id_seq:
                if id == id_seq[0]:
                    _id = self.index_to_docstore_id[id]
                    doc = deepcopy(self.docstore.search(_id))
                else:
                    _id0 = self.index_to_docstore_id[id]
                    doc0 = self.docstore.search(_id0)
                    doc.page_content += " " + doc0.page_content
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
            doc.metadata["score"] = int(doc_score)
            docs.append(doc)
        return docs

    def list_docs(self):
        return list(v.metadata["source"] for v in self.docstore._dict.values())

    def index_to_docstore_source(self, i: int):
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        return doc.metadata["source"]


class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter for Chinese text with a maximum sentence length."""

    def __init__(self, max_length: int, **kwargs):
        super().__init__(**kwargs)
        self.max_length = max_length

    def split_text(self, text: str, is_pdf: bool = False) -> List[str]:
        if is_pdf:
            # Collapse layout artefacts that are common in PDF extractions.
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        # Break on sentence-ending punctuation (keeping trailing quotation marks).
        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        text = text.rstrip()
        lines = [i for i in text.split("\n") if i]
        # Re-split any line still longer than ``max_length`` on progressively weaker
        # delimiters: commas, then whitespace runs, then single spaces.
        for cur_line in lines:
            if len(cur_line) > self.max_length:
                sub_lines1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', cur_line).split("\n")
                for cur_s_line1 in sub_lines1:
                    if len(cur_s_line1) > self.max_length:
                        sub_lines2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', cur_s_line1).split("\n")
                        for cur_s_line2 in sub_lines2:
                            if len(cur_s_line2) > self.max_length:
                                cur_s_line3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', cur_s_line2)
                                cur_s_idx2 = sub_lines2.index(cur_s_line2)
                                sub_lines2 = sub_lines2[:cur_s_idx2] + [i for i in cur_s_line3.split("\n") if i] + sub_lines2[cur_s_idx2 + 1:]
                        cur_s_idx1 = sub_lines1.index(cur_s_line1)
                        sub_lines1 = sub_lines1[:cur_s_idx1] + [i for i in sub_lines2 if i] + sub_lines1[cur_s_idx1 + 1:]

                cur_idx = lines.index(cur_line)
                lines = lines[:cur_idx] + [i for i in sub_lines1 if i] + lines[cur_idx + 1:]
        return lines


def load_knowledge(knowledge_doc: str,
                   sentence_max_length: int) -> List[Document]:
    """Load and split a knowledge document (.txt, .csv or other formats).

    Args:
        knowledge_doc (`str`): Path to the knowledge document file.
        sentence_max_length (`int`): Maximum length of a sentence in characters.
    """
    text_splitter = ChineseTextSplitter(max_length=sentence_max_length)
    if knowledge_doc.lower().endswith(".txt"):
        loader = TextLoader(knowledge_doc, autodetect_encoding=True)
        docs = loader.load_and_split(text_splitter)
    elif knowledge_doc.lower().endswith(".csv"):
        loader = CSVLoader(knowledge_doc)
        docs = loader.load()
    else:
        loader = UnstructuredFileLoader(knowledge_doc, mode="elements")
        docs = loader.load_and_split(text_splitter=text_splitter)
    return docs


class LocalKnowledgeBase:
    """Local Knowledge Base.

    Args:
        embedding_path (`str`): The path or name of the pre-trained
            embedding model used for encoding text.
        topk (`int`): The number of most similar knowledge documents to
            retrieve for a given query.
        knowledge_docs (`List`): Files containing the knowledge base,
            supporting txt and csv formats.
        sentence_max_length (`int`): Maximum length of a sentence in
            characters for processing.
        vector_store_path (`Optional[str]`): Path to save or load
            pre-computed document vectors. Defaults to a ``vector_store``
            directory next to the knowledge files.
        device (`Optional[str]`): The device (CPU or GPU) to run the
            embedding model on.
    """

    def __init__(self,
                 embedding_path: str,
                 topk: int,
                 knowledge_docs: List[str],
                 sentence_max_length: int,
                 vector_store_path: Optional[str] = None,
                 device: Optional[str] = None) -> None:
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_path,
            model_kwargs={'device': device})
        self.topk = topk

        # Load and split every knowledge file into sentence-sized documents.
        docs = sum([
            load_knowledge(knowledge_doc=cur_doc,
                           sentence_max_length=sentence_max_length)
            for cur_doc in knowledge_docs
        ], [])

        if vector_store_path is None:
            # Default: a "vector_store" directory next to the knowledge files.
            vector_store_path = os.path.join(
                os.path.commonprefix(knowledge_docs).rsplit('/', 1)[0],
                "vector_store")

        if os.path.isdir(vector_store_path) and "index.faiss" in os.listdir(vector_store_path):
            logger.info(f'Loading from existing vector store ({vector_store_path})...')
            self.vector_store = RetrievedFAISS.load_local(vector_store_path, self.embeddings)
            self.vector_store.add_documents(docs)
        else:
            logger.info(f'Constructing vector store ({vector_store_path})...')
            self.vector_store = RetrievedFAISS.from_documents(docs, self.embeddings)

        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        self.vector_store.save_local(vector_store_path)

        logger.info('Vector store is ready.')

    def retrieve_one(self, query: str, separator: str = ' ') -> str:
        """Retrieve the most relevant knowledge chunks for a query and join them."""
        related_docs_with_score = self.vector_store.similarity_search_with_score(
            query, k=self.topk)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        return separator.join(
            [cur_doc.page_content for cur_doc in related_docs_with_score])


@ICL_RETRIEVERS.register_module()
class KnowledgeRetriever(BaseRetriever):
    """Local Knowledge Retriever. Returns related local knowledge for every query.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        knowledge_docs (`List`): Files containing the knowledge base,
            supporting txt and csv formats.
        retrieve_keys (`List`): Keys of the test sample that require
            indexing of relevant knowledge.
        embedding_path (`Optional[str]`): The path or name of the
            pre-trained embedding model used for encoding text.
        ice_eos_token (`Optional[str]`): The end of sentence token for the
            in-context example template when an origin `PromptTemplate` is
            provided. Defaults to ''.
    """

    def __init__(self,
                 dataset,
                 knowledge_docs: List,
                 retrieve_keys: List,
                 embedding_path: Optional[str] = 'GanymedeNil/text2vec-large-chinese',
                 ice_eos_token: Optional[str] = '') -> None:
        super().__init__(dataset, '', ice_eos_token, 0)
        self.knowledge_ds = None
        self.retrieve_keys = retrieve_keys

        self.local_knowledge_base = LocalKnowledgeBase(
            embedding_path=embedding_path,
            knowledge_docs=knowledge_docs,
            topk=3,
            sentence_max_length=100)

    def retrieve(self) -> List[List]:
        """Construct the knowledge entry for each test sample and return the
        sequential indices."""

        logger.info('Retrieving data for test set...')
        rtr_idx_list = [[i] for i in range(len(self.test_ds))]
        # For every test sample, look up knowledge for each option key and join
        # the retrieved chunks into a single ``knowledge`` field.
        self.knowledge_ds = [
            {'knowledge': '; '.join([
                self.local_knowledge_base.retrieve_one(cur_d[option_key])
                for option_key in self.retrieve_keys
            ])} for cur_d in tqdm(self.test_ds)]
        return rtr_idx_list

    def generate_ice(self,
                     idx_list: List[int],
                     ice_template: Optional[PromptTemplate] = None) -> str:
        """Generate the knowledge-related example for one test example.

        Args:
            idx_list (`List[int]`): The indices of knowledge-related examples
                for the test example.
            ice_template (`Optional[PromptTemplate]`): The template for the
                knowledge-related example. Defaults to None.
        """
        assert self.knowledge_ds is not None, (
            'knowledge_ds must be set first in retrieve method')

        if ice_template is None:
            assert len(
                idx_list
            ) == 0, 'You have not specified ice_template while retrieving knowledge examples! Please either specify ice_template or use `ZeroRetriever`.'  # noqa

        generated_ice_list = []
        for idx in idx_list:
            generated_ice_list.append(
                ice_template.generate_ice_item(self.knowledge_ds[idx], ''))

        generated_ice = self.ice_separator.join(
            generated_ice_list) + self.ice_eos_token
        return generated_ice
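For reviewers, a minimal sketch (illustration only, not part of the diff) of exercising LocalKnowledgeBase on its own, assuming the chengyu knowledge file referenced in the dataset config and the default text2vec-large-chinese embedding model are available locally; building the vector store downloads the embedding model and may take a while:

from opencompass.openicl.icl_retriever.icl_knowledge_retriever import LocalKnowledgeBase

# Build (or load) the FAISS store from one knowledge file.
kb = LocalKnowledgeBase(
    embedding_path='GanymedeNil/text2vec-large-chinese',
    topk=3,
    knowledge_docs=['./data/knowledge/chengyu-01-of-02.txt'],
    sentence_max_length=100,
)
# Returns the top-3 related chunks joined by the separator (default ' ').
print(kb.retrieve_one('一诺千金'))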