diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index de1a9f6e..84f820c9 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -1,16 +1,13 @@
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
-import sentence_transformers
-import os
 from configs.model_config import *
 import datetime
-from typing import List
 from textsplitter import ChineseTextSplitter
+from typing import List, Tuple
 from langchain.docstore.document import Document
+import numpy as np
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -48,10 +45,70 @@ def get_docs_with_score(docs_with_score):
         docs.append(doc)
     return docs
 
+
+def seperate_list(ls: List[int]) -> List[List[int]]:
+    lists = []
+    ls1 = [ls[0]]
+    for i in range(1, len(ls)):
+        if ls[i - 1] + 1 == ls[i]:
+            ls1.append(ls[i])
+        else:
+            lists.append(ls1)
+            ls1 = [ls[i]]
+    lists.append(ls1)
+    return lists
+
+
+
+def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+) -> List[Tuple[Document, float]]:
+    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+    docs = []
+    id_set = set()
+    for j, i in enumerate(indices[0]):
+        if i == -1:
+            # This happens when not enough docs are returned.
+            continue
+        _id = self.index_to_docstore_id[i]
+        doc = self.docstore.search(_id)
+        id_set.add(i)
+        docs_len = len(doc.page_content)
+        for k in range(1, max(i, len(docs) - i)):
+            for l in [i + k, i - k]:
+                if 0 <= l < len(self.index_to_docstore_id):
+                    _id0 = self.index_to_docstore_id[l]
+                    doc0 = self.docstore.search(_id0)
+                    if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break
+                    elif doc0.metadata["source"] == doc.metadata["source"]:
+                        docs_len += len(doc0.page_content)
+                        id_set.add(l)
+    id_list = sorted(list(id_set))
+    id_lists = seperate_list(id_list)
+    for id_seq in id_lists:
+        for id in id_seq:
+            if id == id_seq[0]:
+                _id = self.index_to_docstore_id[id]
+                doc = self.docstore.search(_id)
+            else:
+                _id0 = self.index_to_docstore_id[id]
+                doc0 = self.docstore.search(_id0)
+                doc.page_content += doc0.page_content
+        if not isinstance(doc, Document):
+            raise ValueError(f"Could not find document for id {_id}, got {doc}")
+        docs.append((doc, scores[0][j]))
+    return docs
+
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
     top_k: int = VECTOR_SEARCH_TOP_K
+    chunk_size: int = CHUNK_SIZE
 
     def init_cfg(self,
                  embedding_model: str = EMBEDDING_MODEL,
@@ -133,6 +190,8 @@ class LocalDocQA:
                                    streaming=True):
         self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
+        FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
 
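Note on the hunks above: the overriding similarity_search_with_score_by_vector no longer returns each matched chunk as-is. For every FAISS hit it walks outward through neighboring chunk indices, absorbing adjacent chunks that share the same "source" metadata until the combined text would exceed chunk_size characters; seperate_list then groups the collected indices into consecutive runs, and each run is concatenated back into a single Document. (One point worth a second look in review: the inner loop reuses the names k and l, shadowing the top-k parameter, and max(i, len(docs) - i) is computed from the still-empty result list, so the expansion window is narrower than the index bound suggests.) A minimal, self-contained sketch of the grouping-and-merging step, with toy ids invented for illustration:

    from typing import List

    def seperate_list(ls: List[int]) -> List[List[int]]:
        # Split a sorted id list into runs of consecutive ids,
        # e.g. [3, 4, 5, 9, 10] -> [[3, 4, 5], [9, 10]].
        lists = []
        run = [ls[0]]
        for i in range(1, len(ls)):
            if ls[i - 1] + 1 == ls[i]:
                run.append(ls[i])
            else:
                lists.append(run)
                run = [ls[i]]
        lists.append(run)
        return lists

    # Chunks 3-5 come from one source file and 9-10 from another, so five
    # expanded hits collapse into two merged context passages.
    assert seperate_list([3, 4, 5, 9, 10]) == [[3, 4, 5], [9, 10]]
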
diff --git a/configs/model_config.py b/configs/model_config.py
index b0093bc7..1afc5379 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -39,4 +39,7 @@ UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "con
 
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """基于以下已知信息,简洁和专业的来回答用户的问题,问题是"{question}"。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。已知内容如下:
-{context} """
\ No newline at end of file
+{context} """
+
+# 匹配后单段上下文长度
+CHUNK_SIZE = 500
\ No newline at end of file
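
Note on CHUNK_SIZE: the comment 匹配后单段上下文长度 reads roughly "per-passage context length after matching"; the value is copied onto the vector store as chunk_size and caps how much neighboring text each hit may absorb. A sketch of how the pieces fit together at query time, mirroring what get_knowledge_based_answer now does; the store path and query string are placeholders, and embeddings, similarity_search_with_score_by_vector, CHUNK_SIZE and VECTOR_SEARCH_TOP_K are assumed to be in scope from the patched modules:

    from langchain.vectorstores import FAISS

    vector_store = FAISS.load_local("/path/to/vector_store", embeddings)  # placeholder path

    # Monkey-patch the class method, then give this store its character budget.
    FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
    vector_store.chunk_size = CHUNK_SIZE  # 500 characters per merged passage

    # similarity_search_with_score embeds the query and delegates to the
    # (now patched) similarity_search_with_score_by_vector under the hood.
    related_docs_with_score = vector_store.similarity_search_with_score(
        "示例问题", k=VECTOR_SEARCH_TOP_K)  # "示例问题" = placeholder query

Because the assignment patches the FAISS class rather than a single instance, every store loaded afterwards in the same process inherits the expanded-context behavior, which is why chunk_size must be set on each store before it is queried.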