From a6184b01beb8bf2b00dd6fdbcf8e5bc678d22af0 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Thu, 13 Apr 2023 23:01:52 +0800
Subject: [PATCH] Restructure the project architecture
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                               |   4 +-
 README_en.md                            |   2 +-
 chains/local_doc_qa.py                  | 104 ++++++++++++++++++++
 cli_demo.py                             |  33 +++++++
 configs/model_config.py                 |  31 ++++++
 knowledge_based_chatglm.py              | 124 ------------------------
 chatglm_llm.py => models/chatglm_llm.py |  18 ++--
 webui.py                                |   4 +-
 8 files changed, 181 insertions(+), 139 deletions(-)
 create mode 100644 chains/local_doc_qa.py
 create mode 100644 cli_demo.py
 create mode 100644 configs/model_config.py
 delete mode 100644 knowledge_based_chatglm.py
 rename chatglm_llm.py => models/chatglm_llm.py (80%)

diff --git a/README.md b/README.md
index b9819684..e9285bbc 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,8 @@
 
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+[TOC]
+
 ## 更新信息
 
 **[2023/04/07]**
@@ -76,7 +78,7 @@ Web UI 可以实现如下功能:
 3. 添加上传文件功能,通过下拉框选择已上传的文件,点击`loading`加载文件,过程中可随时更换加载的文件
 4. 底部添加`use via API`可对接到自己系统
-或执行 [knowledge_based_chatglm.py](knowledge_based_chatglm.py) 脚本体验**命令行交互**
+或执行 [cli_demo.py](cli_demo.py) 脚本体验**命令行交互**
 ```commandline
 python knowledge_based_chatglm.py
 ```
 
diff --git a/README_en.md b/README_en.md
index b9c350a6..e8b78a4f 100644
--- a/README_en.md
+++ b/README_en.md
@@ -68,7 +68,7 @@
 pip install -r requirements.txt
 ```
 Attention: With langchain.document_loaders.UnstructuredFileLoader used to connect with local knowledge file, you may need some other dependencies as mentioned in [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)
-### 2. Run [knowledge_based_chatglm.py](knowledge_based_chatglm.py) script
+### 2. Run [cli_demo.py](cli_demo.py) script
 ```commandline
 python knowledge_based_chatglm.py
 ```
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
new file mode 100644
index 00000000..94be74b3
--- /dev/null
+++ b/chains/local_doc_qa.py
@@ -0,0 +1,104 @@
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import UnstructuredFileLoader
+from models.chatglm_llm import ChatGLM
+import sentence_transformers
+import os
+from configs.model_config import *
+import datetime
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+
+class LocalDocQA:
+    llm: object = None
+    embeddings: object = None
+
+    def init_cfg(self,
+                 embedding_model: str = EMBEDDING_MODEL,
+                 embedding_device=EMBEDDING_DEVICE,
+                 llm_history_len: int = LLM_HISTORY_LEN,
+                 llm_model: str = LLM_MODEL,
+                 llm_device=LLM_DEVICE
+                 ):
+        self.llm = ChatGLM()
+        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
+                            llm_device=llm_device)
+        self.llm.history_len = llm_history_len
+
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
+        self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
+                                                                           device=embedding_device)
+
+    def init_knowledge_vector_store(self,
+                                    filepath: str):
+        if not os.path.exists(filepath):
+            print("路径不存在")
+            return None
+        elif os.path.isfile(filepath):
+            file = os.path.split(filepath)[-1]
+            try:
+                loader = UnstructuredFileLoader(filepath, mode="elements")
+                docs = loader.load()
+                print(f"{file} 已成功加载")
+            except:
+                print(f"{file} 未能成功加载")
+                return None
+        elif os.path.isdir(filepath):
+            docs = []
+            for file in os.listdir(filepath):
+                fullfilepath = os.path.join(filepath, file)
+                try:
+                    loader = UnstructuredFileLoader(fullfilepath, mode="elements")
+                    docs += loader.load()
+                    print(f"{file} 已成功加载")
+                except:
+                    print(f"{file} 未能成功加载")
+
+        vector_store = FAISS.from_documents(docs, self.embeddings)
+        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        vector_store.save_local(vs_path)
+        return vs_path
+
+    def get_knowledge_based_answer(self,
+                                   query,
+                                   vs_path,
+                                   chat_history=[],
+                                   top_k=VECTOR_SEARCH_TOP_K):
+        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
+        如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
+
+        已知内容:
+        {context}
+
+        问题:
+        {question}"""
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+        self.llm.history = chat_history
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        knowledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            retriever=vector_store.as_retriever(search_kwargs={"k": top_k}),
+            prompt=prompt
+        )
+        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+            input_variables=["page_content"], template="{page_content}"
+        )
+
+        knowledge_chain.return_source_documents = True
+
+        result = knowledge_chain({"query": query})
+        self.llm.history[-1][0] = query
+        return result, self.llm.history
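
`init_knowledge_vector_store()` above saves the FAISS index with `save_local()` and returns the folder path, so an index built once can be reloaded later without re-parsing the source documents. A minimal reload sketch, assuming the langchain version this repo already depends on; the folder name below is a hypothetical example of the `{name}_FAISS_{timestamp}` pattern that method produces:

```python
# Sketch only (not part of this patch): reload a previously saved index.
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
# Hypothetical folder of the form "./vector_store/{name}_FAISS_{timestamp}":
vector_store = FAISS.load_local("./vector_store/sample_FAISS_20230413_230152", embeddings)
docs = vector_store.similarity_search("示例问题", k=10)  # k mirrors VECTOR_SEARCH_TOP_K
print([doc.page_content for doc in docs])
```

Returning a path rather than a live `FAISS` object is what lets `cli_demo.py` below treat the knowledge base as reusable on-disk state.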
diff --git a/cli_demo.py b/cli_demo.py
new file mode 100644
index 00000000..0678e4e8
--- /dev/null
+++ b/cli_demo.py
@@ -0,0 +1,33 @@
+from configs.model_config import *
+import datetime
+from chains.local_doc_qa import LocalDocQA
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
+                          embedding_model=EMBEDDING_MODEL,
+                          embedding_device=EMBEDDING_DEVICE,
+                          llm_history_len=LLM_HISTORY_LEN)
+    vs_path = None
+    while not vs_path:
+        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
+        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+    history = []
+    while True:
+        query = input("Input your question 请输入问题:")
+        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                vs_path=vs_path,
+                                                                chat_history=history)
+        if REPLY_WITH_SOURCE:
+            print(resp)
+        else:
+            print(resp["result"])
diff --git a/configs/model_config.py b/configs/model_config.py
new file mode 100644
index 00000000..640c2f5c
--- /dev/null
+++ b/configs/model_config.py
@@ -0,0 +1,31 @@
+import torch.cuda
+import torch.backends
+
+
+embedding_model_dict = {
+    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+    "ernie-base": "nghuyong/ernie-3.0-base-zh",
+    "text2vec": "GanymedeNil/text2vec-large-chinese",
+    "local": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
+}
+
+# Embedding model name
+EMBEDDING_MODEL = "text2vec"
+
+# Embedding running device
+EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+# supported LLM models
+llm_model_dict = {
+    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
+    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
+    "chatglm-6b": "THUDM/chatglm-6b",
+    "local": "/Users/liuqian/Downloads/ChatGLM-6B/chatglm-6b"
+}
+
+# LLM model name
+LLM_MODEL = "chatglm-6b"
+
+# LLM running device
+LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
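
`configs/model_config.py` centralizes the `cuda`, then `mps`, then `cpu` device fallback that was previously duplicated in each script. An illustrative check (not part of this patch) of what the shared config resolves to on the current machine:

```python
# Print the devices and checkpoints the shared config resolves to,
# before any model weights are loaded.
from configs.model_config import (EMBEDDING_DEVICE, EMBEDDING_MODEL,
                                  LLM_DEVICE, LLM_MODEL)

print(f"embedding model: {EMBEDDING_MODEL} on {EMBEDDING_DEVICE}")  # cuda / mps / cpu
print(f"llm model: {LLM_MODEL} on {LLM_DEVICE}")
```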
diff --git a/knowledge_based_chatglm.py b/knowledge_based_chatglm.py
deleted file mode 100644
index 6a7601bf..00000000
--- a/knowledge_based_chatglm.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
-from langchain.document_loaders import UnstructuredFileLoader
-from chatglm_llm import ChatGLM
-import sentence_transformers
-import torch
-import os
-import readline
-
-
-# Global Parameters
-EMBEDDING_MODEL = "text2vec"
-VECTOR_SEARCH_TOP_K = 6
-LLM_MODEL = "chatglm-6b"
-LLM_HISTORY_LEN = 3
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = True
-
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
-}
-
-llm_model_dict = {
-    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
-    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
-    "chatglm-6b": "THUDM/chatglm-6b",
-}
-
-
-def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=6):
-    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
-    VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K
-
-    chatglm = ChatGLM()
-    chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
-    chatglm.history_len = LLM_HISTORY_LEN
-
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],)
-    embeddings.client = sentence_transformers.SentenceTransformer(embeddings.model_name,
-                                                                  device=DEVICE)
-
-
-def init_knowledge_vector_store(filepath:str):
-    if not os.path.exists(filepath):
-        print("路径不存在")
-        return None
-    elif os.path.isfile(filepath):
-        file = os.path.split(filepath)[-1]
-        try:
-            loader = UnstructuredFileLoader(filepath, mode="elements")
-            docs = loader.load()
-            print(f"{file} 已成功加载")
-        except:
-            print(f"{file} 未能成功加载")
-            return None
-    elif os.path.isdir(filepath):
-        docs = []
-        for file in os.listdir(filepath):
-            fullfilepath = os.path.join(filepath, file)
-            try:
-                loader = UnstructuredFileLoader(fullfilepath, mode="elements")
-                docs += loader.load()
-                print(f"{file} 已成功加载")
-            except:
-                print(f"{file} 未能成功加载")
-
-    vector_store = FAISS.from_documents(docs, embeddings)
-    return vector_store
-
-
-def get_knowledge_based_answer(query, vector_store, chat_history=[]):
-    global chatglm, embeddings
-
-    prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-已知内容:
-{context}
-
-问题:
-{question}"""
-    prompt = PromptTemplate(
-        template=prompt_template,
-        input_variables=["context", "question"]
-    )
-    chatglm.history = chat_history
-    knowledge_chain = RetrievalQA.from_llm(
-        llm=chatglm,
-        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
-        prompt=prompt
-    )
-    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-        input_variables=["page_content"], template="{page_content}"
-    )
-
-    knowledge_chain.return_source_documents = True
-
-    result = knowledge_chain({"query": query})
-    chatglm.history[-1][0] = query
-    return result, chatglm.history
-
-
-if __name__ == "__main__":
-    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
-    vector_store = None
-    while not vector_store:
-        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
-        vector_store = init_knowledge_vector_store(filepath)
-    history = []
-    while True:
-        query = input("Input your question 请输入问题:")
-        resp, history = get_knowledge_based_answer(query=query,
-                                                   vector_store=vector_store,
-                                                   chat_history=history)
-        if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
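
The module-level functions deleted above now live on the `LocalDocQA` class added in `chains/local_doc_qa.py`. A hypothetical caller migration sketch; the `docs/` path is a placeholder:

```python
# Before: init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
#         vector_store = init_knowledge_vector_store(filepath)
#         resp, history = get_knowledge_based_answer(query, vector_store, chat_history)
from chains.local_doc_qa import LocalDocQA

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()  # defaults now come from configs/model_config.py
vs_path = local_doc_qa.init_knowledge_vector_store("docs/")  # returns a saved FAISS path
resp, history = local_doc_qa.get_knowledge_based_answer("你的问题", vs_path, chat_history=[])
print(resp["result"])
```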
diff --git a/chatglm_llm.py b/models/chatglm_llm.py
similarity index 80%
rename from chatglm_llm.py
rename to models/chatglm_llm.py
index aceb984c..7cf3b24a 100644
--- a/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -3,8 +3,9 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
 import torch
+from configs.model_config import LLM_DEVICE
 
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+DEVICE = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
@@ -48,12 +49,14 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
-    def load_model(self, model_name_or_path: str = "THUDM/chatglm-6b"):
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b",
+                   llm_device=LLM_DEVICE):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
@@ -61,19 +64,12 @@
                 .half()
                 .cuda()
             )
-        elif torch.backends.mps.is_available():
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=True)
-                .float()
-                .to('mps')
-            )
         else:
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     trust_remote_code=True)
                 .float()
+                .to(llm_device)
             )
         self.model = self.model.eval()
diff --git a/webui.py b/webui.py
index b28d2bd6..0e19e1a0 100644
--- a/webui.py
+++ b/webui.py
@@ -1,7 +1,7 @@
 import gradio as gr
 import os
 import shutil
-import knowledge_based_chatglm as kb
+import cli_demo as kb
 
 
 def get_file_list():
@@ -108,7 +108,7 @@ with gr.Blocks(css="""
                                      value=file_list[0] if len(file_list) > 0 else None)
         with gr.Tab("upload"):
             file = gr.File(label="content file",
-                           file_types=['.txt', '.md', '.docx']
+                           file_types=['.txt', '.md', '.docx', '.pdf']
                            ).style(height=100)
             # 将上传的文件保存到content文件夹下,并更新下拉框
             file.upload(upload_file,
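
With the `mps`-specific branch removed, `load_model()` reaches any non-CUDA target through the new `llm_device` argument and the generic `.float().to(llm_device)` fallback. A minimal sketch, assuming an Apple-silicon machine where `mps` is available:

```python
# Sketch: load ChatGLM onto a non-CUDA device through the new llm_device argument.
from models.chatglm_llm import ChatGLM

llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b", llm_device="mps")
llm.history_len = 3
print(llm("你好"))  # langchain LLM objects are callable with a prompt string
```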