diff --git a/api.py b/api.py
deleted file mode 100644
index 7f95be8d..00000000
--- a/api.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from configs.model_config import *
-from chains.local_doc_qa import LocalDocQA
-import os
-import nltk
-
-import uvicorn
-from fastapi import FastAPI, File, UploadFile
-from pydantic import BaseModel
-from starlette.responses import RedirectResponse
-
-app = FastAPI()
-
-global local_doc_qa, vs_path
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = False
-
-class Query(BaseModel):
- query: str
-
-@app.get('/')
-async def document():
- return RedirectResponse(url="/docs")
-
-@app.on_event("startup")
-async def get_local_doc_qa():
- global local_doc_qa
- local_doc_qa = LocalDocQA()
- local_doc_qa.init_cfg(llm_model=LLM_MODEL,
- embedding_model=EMBEDDING_MODEL,
- embedding_device=EMBEDDING_DEVICE,
- llm_history_len=LLM_HISTORY_LEN,
- top_k=VECTOR_SEARCH_TOP_K)
-
-
-@app.post("/file")
-async def upload_file(UserFile: UploadFile=File(...),):
- global vs_path
- response = {
- "msg": None,
- "status": 0
- }
- try:
- filepath = './content/' + UserFile.filename
- content = await UserFile.read()
- # print(UserFile.filename)
- with open(filepath, 'wb') as f:
- f.write(content)
- vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
- response = {
- 'msg': 'seccess' if len(files)>0 else 'fail',
- 'status': 1 if len(files)>0 else 0,
- 'loaded_files': files
- }
-
- except Exception as err:
- response["message"] = err
-
- return response
-
-@app.post("/qa")
-async def get_answer(query: str = ""):
- response = {
- "status": 0,
- "message": "",
- "answer": None
- }
- global vs_path
- history = []
- try:
- resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
- vs_path=vs_path,
- chat_history=history)
- if REPLY_WITH_SOURCE:
- response["answer"] = resp
- else:
- response['answer'] = resp["result"]
-
- response["message"] = 'successful'
- response["status"] = 1
-
- except Exception as err:
- response["message"] = err
-
- return response
-
-
-if __name__ == "__main__":
- uvicorn.run(
- app=app,
- host='0.0.0.0',
- port=8100,
- reload=True,
- )
-
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 6467c0b5..de1a9f6e 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -141,7 +141,6 @@ class LocalDocQA:
if streaming:
for result, history in self.llm._call(prompt=prompt,
history=chat_history):
- history[-1] = list(history[-1])
history[-1][0] = query
response = {"query": query,
"result": result,
@@ -150,7 +149,6 @@ class LocalDocQA:
else:
result, history = self.llm._call(prompt=prompt,
history=chat_history)
- history[-1] = list(history[-1])
history[-1][0] = query
response = {"query": query,
"result": result,
diff --git a/chains/test.ipynb b/chains/test.ipynb
deleted file mode 100644
index 5183fa29..00000000
--- a/chains/test.ipynb
+++ /dev/null
@@ -1,195 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from langchain.chains.question_answering import load_qa_chain\n",
- "from langchain.prompts import PromptTemplate\n",
- "from lib.embeds import MyEmbeddings\n",
- "from lib.faiss import FAISSVS\n",
- "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
- "from langchain.chains.llm import LLMChain\n",
- "from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
- "from lib.config import *\n",
- "from lib.utils import get_docs\n",
- "\n",
- "\n",
- "class LocalDocQA:\n",
- " def __init__(self, \n",
- " embedding_model=EMBEDDING_MODEL, \n",
- " embedding_device=EMBEDDING_DEVICE, \n",
- " llm_model=LLM_MODEL, \n",
- " llm_device=LLM_DEVICE, \n",
- " llm_history_len=LLM_HISTORY_LEN, \n",
- " top_k=VECTOR_SEARCH_TOP_K,\n",
- " vs_name = VS_NAME\n",
- " ) -> None:\n",
- " \n",
- " torch.cuda.empty_cache()\n",
- " torch.cuda.empty_cache()\n",
- "\n",
- " self.embedding_model = embedding_model\n",
- " self.llm_model = llm_model\n",
- " self.embedding_device = embedding_device\n",
- " self.llm_device = llm_device\n",
- " self.llm_history_len = llm_history_len\n",
- " self.top_k = top_k\n",
- " self.vs_name = vs_name\n",
- "\n",
- " self.llm = AlpacaGLM()\n",
- " self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
- "\n",
- " self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
- " self.load_vector_store(vs_name)\n",
- "\n",
- " self.prompt = PromptTemplate(\n",
- " template=PROMPT_TEMPLATE,\n",
- " input_variables=[\"context\", \"question\"]\n",
- " )\n",
- " self.search_params = {\n",
- " \"engine\": \"bing\",\n",
- " \"gl\": \"us\",\n",
- " \"hl\": \"en\",\n",
- " \"serpapi_api_key\": \"\"\n",
- " }\n",
- "\n",
- " def init_knowledge_vector_store(self, vs_name: str):\n",
- " \n",
- " docs = get_docs(KNOWLEDGE_PATH)\n",
- " vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
- " vs_path = VECTORSTORE_PATH + vs_name\n",
- " vector_store.save_local(vs_path)\n",
- "\n",
- " def add_knowledge_to_vector_store(self, vs_name: str):\n",
- " docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
- " new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
- " vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
- " vector_store.merge_from(new_vector_store)\n",
- " vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
- "\n",
- " def load_vector_store(self, vs_name: str):\n",
- " self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
- "\n",
- " # def get_search_based_answer(self, query):\n",
- " \n",
- " # search = SerpAPIWrapper(params=self.search_params)\n",
- " # docs = search.run(query)\n",
- " # search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
- " # answer = search_chain.run(input_documents=docs, question=query)\n",
- "\n",
- " # return answer\n",
- " \n",
- " def get_knowledge_based_answer(self, query):\n",
- " \n",
- " docs = self.vector_store.max_marginal_relevance_search(query)\n",
- " print(f'召回的文档和相似度分数:{docs}')\n",
- " # 这里 doc[1] 就是对应的score \n",
- " docs = [doc[0] for doc in docs]\n",
- " \n",
- " document_prompt = PromptTemplate(\n",
- " input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
- " )\n",
- " llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
- " combine_documents_chain = StuffDocumentsChain(\n",
- " llm_chain=llm_chain,\n",
- " document_variable_name=\"context\",\n",
- " document_prompt=document_prompt,\n",
- " )\n",
- " answer = combine_documents_chain.run(\n",
- " input_documents=docs, question=query\n",
- " )\n",
- "\n",
- " self.llm.history[-1][0] = query\n",
- " self.llm.history[-1][-1] = answer\n",
- " return answer, docs, self.llm.history"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "d4342213010c4ed2ad5b04694aa436d6",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "qa = LocalDocQA()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "召回的文档和相似度分数:[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n ’ t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h e m y r i a d o f r e f e r e n c e s i n XmlPackageOptions , O u t l i n i n g S t a t e D i r , e t c . , t h a t t h e HtaDotnet a n d ShellcodeLoader s o l u t i o n s w e r e o r i g i n a l l y u n d e r t h e f o l d e r p a t h G:\\\\WebBuilder\\\\Gift_HtaDotnet\\\\ . T h i s i s a l s o s u p p o r t e d b y t h e P D B p a t h s o f o l d e r b u i l t b i n a r i e s w i t h i n t h e b r o a d e r S t r i k e S u i t G i f t p a c k a g e . F r o m l o o k i n g a t D e b u g g e r W a t c h e s v a l u e s i n o t h e r p r o j e c t s , w e c a n s e e t h a t t h e m a l w a r e d e v e l o p e r w a s a c t i v e l y d e b u g g i n g t h e h i s t o r i c a l p r o g r a m s . S U O f i l e D e b u g g e r W a t c h e s WebBuilder/HtaDotNet/HtaDotnet.v11.suo result WebBuilder/ShellcodeLoader/.vs/L/v14/.suo (char)77 WebBuilder/ShellcodeLoader/L.suo (char)77 3 4 04/2022', metadata={'source': './KnowledgeStore/Stairwell-threat-report-The-origin-of-APT32-macros.pdf', 'page': 33}), 0.38091612), (Document(page_content='2 APTs and COVID-19: How advanced persistent threats use the coronavirus as a lureTable of contents Introduction: APT groups using COVID-19 .........................................................', metadata={'source': './KnowledgeStore/200407-MWB-COVID-White-Paper_Final.pdf', 'page': 1}), 0.44476452)]\n"
- ]
- }
- ],
- "source": [
- "query = r\"\"\"make a brief introduction of APT?\"\"\"\n",
- "ans, docs, _ = qa.get_knowledge_based_answer(query)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'\\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ans"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "chatgpt",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- },
- "orig_nbformat": 4
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index eb15c334..a0e95d96 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -74,14 +74,17 @@ class ChatGLM(LLM):
history: List[List[str]] = [],
stop: Optional[List[str]] = None) -> str:
if self.streaming:
- history = history + [[None, ""]]
- for stream_resp, history in self.model.stream_chat(
+ for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
self.tokenizer,
prompt,
- history=history[-self.history_len:] if self.history_len > 0 else [],
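+                    # note: the slice drops the final history entry; the current turn is (re)attached to history below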
+ history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature,
- ):
+ )):
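+            # first streamed chunk: append a new [prompt, response] turn to history;
+            # later chunks overwrite that same turn as the response grows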
+ if inum == 0:
+ history += [[prompt, stream_resp]]
+ else:
+ history[-1] = [prompt, stream_resp]
yield stream_resp, history
else:
diff --git a/webui.py b/webui.py
index e75a8318..aaeb7342 100644
--- a/webui.py
+++ b/webui.py
@@ -33,23 +33,23 @@ def get_answer(query, vs_path, history, mode):
if mode == "知识库问答":
if vs_path:
for resp, history in local_doc_qa.get_knowledge_based_answer(
- query=query, vs_path=vs_path, chat_history=history):
- # source = "".join([f""" 出处 {i + 1}
- # {doc.page_content}
- #
- # 所属文件:{doc.metadata["source"]}
- # """ for i, doc in enumerate(resp["source_documents"])])
- # history[-1][-1] += source
+ query=query, vs_path=vs_path, chat_history=history):
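+                # append each source document to the reply as a collapsible <details> block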
+ source = "\n\n"
+ source += "".join(
+ [f""" 出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}
\n"""
+ f"""{doc.page_content}\n"""
+ f""" """
+ for i, doc in
+ enumerate(resp["source_documents"])])
+ history[-1][-1] += source
yield history, ""
else:
- history = history + [[query, ""]]
- for resp in local_doc_qa.llm._call(query):
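+            # _call now streams (response, history) pairs and maintains history itself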
+ for resp, history in local_doc_qa.llm._call(query, history):
history[-1][-1] = resp + (
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
else:
- history = history + [[query, ""]]
- for resp in local_doc_qa.llm._call(query):
+ for resp, history in local_doc_qa.llm._call(query, history):
history[-1][-1] = resp
yield history, ""
@@ -269,9 +269,10 @@ with gr.Blocks(css=block_css) as demo:
outputs=chatbot
)
-demo.queue(concurrency_count=3
- ).launch(server_name='0.0.0.0',
- server_port=7860,
- show_api=False,
- share=False,
- inbrowser=False)
+(demo
+ .queue(concurrency_count=3)
+ .launch(server_name='0.0.0.0',
+ server_port=7860,
+ show_api=False,
+ share=False,
+ inbrowser=False))