diff --git a/api.py b/api.py
deleted file mode 100644
index 7f95be8d..00000000
--- a/api.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from configs.model_config import *
-from chains.local_doc_qa import LocalDocQA
-import os
-import nltk
-
-import uvicorn
-from fastapi import FastAPI, File, UploadFile
-from pydantic import BaseModel
-from starlette.responses import RedirectResponse
-
-app = FastAPI()
-
-global local_doc_qa, vs_path
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = False
-
-class Query(BaseModel):
-    query: str
-
-@app.get('/')
-async def document():
-    return RedirectResponse(url="/docs")
-
-@app.on_event("startup")
-async def get_local_doc_qa():
-    global local_doc_qa
-    local_doc_qa = LocalDocQA()
-    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
-                          embedding_model=EMBEDDING_MODEL,
-                          embedding_device=EMBEDDING_DEVICE,
-                          llm_history_len=LLM_HISTORY_LEN,
-                          top_k=VECTOR_SEARCH_TOP_K)
-
-
-@app.post("/file")
-async def upload_file(UserFile: UploadFile=File(...),):
-    global vs_path
-    response = {
-        "msg": None,
-        "status": 0
-    }
-    try:
-        filepath = './content/' + UserFile.filename
-        content = await UserFile.read()
-        # print(UserFile.filename)
-        with open(filepath, 'wb') as f:
-            f.write(content)
-        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
-        response = {
-            'msg': 'seccess' if len(files)>0 else 'fail',
-            'status': 1 if len(files)>0 else 0,
-            'loaded_files': files
-        }
-
-    except Exception as err:
-        response["message"] = err
-
-    return response
-
-@app.post("/qa")
-async def get_answer(query: str = ""):
-    response = {
-        "status": 0,
-        "message": "",
-        "answer": None
-    }
-    global vs_path
-    history = []
-    try:
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
-        if REPLY_WITH_SOURCE:
-            response["answer"] = resp
-        else:
-            response['answer'] = resp["result"]
-
-        response["message"] = 'successful'
-        response["status"] = 1
-
-    except Exception as err:
-        response["message"] = err
-
-    return response
-
-
-if __name__ == "__main__":
-    uvicorn.run(
-        app=app,
-        host='0.0.0.0',
-        port=8100,
-        reload=True,
-    )
-
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 6467c0b5..de1a9f6e 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -141,7 +141,6 @@ class LocalDocQA:
         if streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
-                history[-1] = list(history[-1])
                 history[-1][0] = query
                 response = {"query": query,
                             "result": result,
@@ -150,7 +149,6 @@ class LocalDocQA:
         else:
             result, history = self.llm._call(prompt=prompt,
                                              history=chat_history)
-            history[-1] = list(history[-1])
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
diff --git a/chains/test.ipynb b/chains/test.ipynb
deleted file mode 100644
index 5183fa29..00000000
--- a/chains/test.ipynb
+++ /dev/null
@@ -1,195 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from langchain.chains.question_answering import load_qa_chain\n",
-    "from langchain.prompts import PromptTemplate\n",
-    "from lib.embeds import MyEmbeddings\n",
-    "from lib.faiss import FAISSVS\n",
-    "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
-    "from langchain.chains.llm import LLMChain\n",
-    "from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
-    "from lib.config import *\n",
-    "from lib.utils import get_docs\n",
-    "\n",
-    "\n",
-    "class LocalDocQA:\n",
-    "    def __init__(self, \n",
-    "                 embedding_model=EMBEDDING_MODEL, \n",
-    "                 embedding_device=EMBEDDING_DEVICE, \n",
-    "                 llm_model=LLM_MODEL, \n",
-    "                 llm_device=LLM_DEVICE, \n",
-    "                 llm_history_len=LLM_HISTORY_LEN, \n",
-    "                 top_k=VECTOR_SEARCH_TOP_K,\n",
-    "                 vs_name = VS_NAME\n",
-    "                 ) -> None:\n",
-    "        \n",
-    "        torch.cuda.empty_cache()\n",
-    "        torch.cuda.empty_cache()\n",
-    "\n",
-    "        self.embedding_model = embedding_model\n",
-    "        self.llm_model = llm_model\n",
-    "        self.embedding_device = embedding_device\n",
-    "        self.llm_device = llm_device\n",
-    "        self.llm_history_len = llm_history_len\n",
-    "        self.top_k = top_k\n",
-    "        self.vs_name = vs_name\n",
-    "\n",
-    "        self.llm = AlpacaGLM()\n",
-    "        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
-    "\n",
-    "        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
-    "        self.load_vector_store(vs_name)\n",
-    "\n",
-    "        self.prompt = PromptTemplate(\n",
-    "            template=PROMPT_TEMPLATE,\n",
-    "            input_variables=[\"context\", \"question\"]\n",
-    "        )\n",
-    "        self.search_params = {\n",
-    "            \"engine\": \"bing\",\n",
-    "            \"gl\": \"us\",\n",
-    "            \"hl\": \"en\",\n",
-    "            \"serpapi_api_key\": \"\"\n",
-    "        }\n",
-    "\n",
-    "    def init_knowledge_vector_store(self, vs_name: str):\n",
-    "        \n",
-    "        docs = get_docs(KNOWLEDGE_PATH)\n",
-    "        vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
-    "        vs_path = VECTORSTORE_PATH + vs_name\n",
-    "        vector_store.save_local(vs_path)\n",
-    "\n",
-    "    def add_knowledge_to_vector_store(self, vs_name: str):\n",
-    "        docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
-    "        new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
-    "        vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
-    "        vector_store.merge_from(new_vector_store)\n",
-    "        vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
-    "\n",
-    "    def load_vector_store(self, vs_name: str):\n",
-    "        self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
-    "\n",
-    "    # def get_search_based_answer(self, query):\n",
-    "        \n",
-    "    #     search = SerpAPIWrapper(params=self.search_params)\n",
-    "    #     docs = search.run(query)\n",
-    "    #     search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
-    "    #     answer = search_chain.run(input_documents=docs, question=query)\n",
-    "\n",
-    "    #     return answer\n",
-    "    \n",
-    "    def get_knowledge_based_answer(self, query):\n",
-    "        \n",
-    "        docs = self.vector_store.max_marginal_relevance_search(query)\n",
-    "        print(f'召回的文档和相似度分数:{docs}')\n",
-    "        # 这里 doc[1] 就是对应的score \n",
-    "        docs = [doc[0] for doc in docs]\n",
-    "        \n",
-    "        document_prompt = PromptTemplate(\n",
-    "            input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
-    "        )\n",
-    "        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
-    "        combine_documents_chain = StuffDocumentsChain(\n",
-    "            llm_chain=llm_chain,\n",
-    "            document_variable_name=\"context\",\n",
-    "            document_prompt=document_prompt,\n",
-    "        )\n",
-    "        answer = combine_documents_chain.run(\n",
-    "            input_documents=docs, question=query\n",
-    "        )\n",
-    "\n",
-    "        self.llm.history[-1][0] = query\n",
-    "        self.llm.history[-1][-1] = answer\n",
-    "        return answer, docs, self.llm.history"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
"application/vnd.jupyter.widget-view+json": { - "model_id": "d4342213010c4ed2ad5b04694aa436d6", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/3 [00:00 str: if self.streaming: - history = history + [[None, ""]] - for stream_resp, history in self.model.stream_chat( + for inum, (stream_resp, _) in enumerate(self.model.stream_chat( self.tokenizer, prompt, - history=history[-self.history_len:] if self.history_len > 0 else [], + history=history[-self.history_len:-1] if self.history_len > 0 else [], max_length=self.max_token, temperature=self.temperature, - ): + )): + if inum == 0: + history += [[prompt, stream_resp]] + else: + history[-1] = [prompt, stream_resp] yield stream_resp, history else: diff --git a/webui.py b/webui.py index e75a8318..aaeb7342 100644 --- a/webui.py +++ b/webui.py @@ -33,23 +33,23 @@ def get_answer(query, vs_path, history, mode): if mode == "知识库问答": if vs_path: for resp, history in local_doc_qa.get_knowledge_based_answer( - query=query, vs_path=vs_path, chat_history=history): - # source = "".join([f"""
出处 {i + 1} - # {doc.page_content} - # - # 所属文件:{doc.metadata["source"]} - #
""" for i, doc in enumerate(resp["source_documents"])]) - # history[-1][-1] += source + query=query, vs_path=vs_path, chat_history=history): + source = "\n\n" + source += "".join( + [f"""
出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" + f"""{doc.page_content}\n""" + f"""
""" + for i, doc in + enumerate(resp["source_documents"])]) + history[-1][-1] += source yield history, "" else: - history = history + [[query, ""]] - for resp in local_doc_qa.llm._call(query): + for resp, history in local_doc_qa.llm._call(query, history): history[-1][-1] = resp + ( "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") yield history, "" else: - history = history + [[query, ""]] - for resp in local_doc_qa.llm._call(query): + for resp, history in local_doc_qa.llm._call(query, history): history[-1][-1] = resp yield history, "" @@ -269,9 +269,10 @@ with gr.Blocks(css=block_css) as demo: outputs=chatbot ) -demo.queue(concurrency_count=3 - ).launch(server_name='0.0.0.0', - server_port=7860, - show_api=False, - share=False, - inbrowser=False) +(demo + .queue(concurrency_count=3) + .launch(server_name='0.0.0.0', + server_port=7860, + show_api=False, + share=False, + inbrowser=False))