Merge pull request #1198 from chatchat-space/dev

pg similarity_search 转 similarity_search_with_score
This commit is contained in:
zqt996 2023-08-22 16:41:37 +08:00 committed by GitHub
commit cb29e1ed23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 682 additions and 148 deletions

View File

@ -1,4 +1,4 @@
from .model_config import *
from .server_config import *
VERSION = "v0.2.1"
VERSION = "v0.2.2-preview"

View File

@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
if __name__ == "__main__":
import argparse
@ -21,6 +23,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
dump_server_info()
create_tables()
print("database talbes created")

109
server/chat/github_chat.py Normal file
View File

@ -0,0 +1,109 @@
from langchain.document_loaders.github import GitHubIssuesLoader
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Literal
from server.chat.utils import History
from langchain.docstore.document import Document
import json
import os
from functools import lru_cache
from datetime import datetime
GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
@lru_cache(1)
def load_issues(tick: str):
'''
set tick to a periodic value to refresh cache
'''
loader = GitHubIssuesLoader(
repo="chatchat-space/langchain-chatglm",
access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
include_prs=True,
state="all",
)
docs = loader.load()
return docs
def
def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
include_prs: bool = Body(True, description="是否包含PR"),
state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
creator: str = Body(None, description="创建者"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "介绍一下本项目"},
{"role": "assistant",
"content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
),
stream: bool = Body(False, description="流式输出"),
):
if GITHUB_PERSONAL_ACCESS_TOKEN is None:
return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
async def chat_iterator(query: str,
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
media_type="text/event-stream")

View File

@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
):
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
):
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs(
knowledge_base_name: str
):
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@ -41,13 +41,14 @@ async def list_docs(
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names)
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
):
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read() # 读取上传文件的内容
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
kb.add_doc(kb_file)
try:
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content)
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
更新知识库文档
'''
@ -113,14 +127,17 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
async def download_doc(
@ -137,18 +154,20 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
@ -163,24 +182,35 @@ async def recreate_vector_store(
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb):
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
kb.add_doc(kb_file)
except Exception as e:
print(e)
async def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
})
return StreamingResponse(output(kb), media_type="text/event-stream")
return StreamingResponse(output(), media_type="text/event-stream")

View File

@ -71,36 +71,37 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings)
self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
else:
status = False
return status
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
"""
从知识库删除文件
"""
self.do_delete_doc(kb_file)
self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile):
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
使用content中的文件更新向量库
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file)
return self.add_doc(kb_file)
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,

View File

@ -41,7 +41,23 @@ def load_vector_store(
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
@ -50,6 +66,7 @@ def refresh_vs_cache(kb_name: str):
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
@ -74,8 +91,10 @@ class FaissKBService(KBService):
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name)
def do_drop_kb(self):
self.clear_vs()
shutil.rmtree(self.kb_path)
def do_search(self,
@ -93,38 +112,40 @@ class FaissKBService(KBService):
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(
docs, embeddings, normalize_L2=True) # docs 为Document列表
torch_gc()
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete_doc(self,
kb_file: KnowledgeFile):
embeddings = self._load_embeddings()
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store.add_documents(docs)
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
else:
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):

View File

@ -46,7 +46,7 @@ class PGKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)
return self.pg_vector.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile):
"""

View File

@ -43,7 +43,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -67,7 +71,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.update_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -81,7 +89,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:

View File

@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field("success", description="HTTP status message")
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
class Config:
schema_extra = {

View File

@ -252,35 +252,36 @@ def run_webui(q: Queue, run_seq: int = 5):
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--all-webui",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
help="run fastchat controller/openai_api servers",
help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
dest="model_worker",
)
parser.add_argument(
@ -315,13 +316,39 @@ def parse_args() -> argparse.ArgumentParser:
return args
if __name__ == "__main__":
def dump_server_info(after_start=False):
import platform
import time
import langchain
import fastchat
from configs.server_config import api_address, webui_address
print("\n\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
if __name__ == "__main__":
import time
mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
@ -343,6 +370,7 @@ if __name__ == "__main__":
args.api = False
args.webui = False
dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
@ -403,27 +431,7 @@ if __name__ == "__main__":
no = queue.get()
if no == len(processes):
time.sleep(0.5)
print("\n\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print("请确认llm_model_dict中配置的api_base_url与上面地址一致。")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
dump_server_info(True)
break
else:
queue.put(no)

204
tests/api/test_kb_api.py Normal file
View File

@ -0,0 +1,204 @@
from doctest import testfile
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
from pprint import pprint
api_base_url = api_address()
kb = "kb_for_api_test"
test_files = {
"README.MD": str(root_path / "README.MD"),
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
}
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists():
return
url = api_base_url + api
print("\n测试知识库存在,需要删除")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
url = api_base_url + api
print(f"\n尝试用空名称创建知识库:")
r = requests.post(url, json={"knowledge_base_name": " "})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
print(f"\n创建新知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"已新增知识库 {kb}"
print(f"\n尝试创建同名知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
url = api_base_url + api
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb in data["data"]
def test_upload_doc(api="/knowledge_base/upload_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n上传知识文件: {name}")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"文件 {name} 已存在。"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 覆盖")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
def test_list_docs(api="/knowledge_base/list_docs"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list)
for name in test_files:
assert name in data["data"]
def test_search_docs(api="/knowledge_base/search_docs"):
url = api_base_url + api
query = "介绍一下langchain-chatchat项目"
print("\n检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_doc(api="/knowledge_base/update_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n更新知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功更新文件 {name}"
def test_delete_doc(api="/knowledge_base/delete_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n删除知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"{name} 文件删除成功"
url = api_base_url + "/knowledge_base/search_docs"
query = "介绍一下langchain-chatchat项目"
print("\n尝试检索删除后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == 0
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
url = api_base_url + api
print("\n重建知识库:")
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
for chunk in r.iter_content(None):
data = json.loads(chunk)
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
url = api_base_url + "/knowledge_base/search_docs"
query = "本项目支持哪些文件格式?"
print("\n尝试检索重建后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
url = api_base_url + api
print("\n删除知识库")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]

View File

@ -0,0 +1,108 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.server_config import API_SERVER, api_address
from pprint import pprint
api_base_url = api_address()
def dump_input(d, title):
print("\n")
print("=" * 30 + title + " input " + "="*30)
pprint(d)
def dump_output(r, title):
print("\n")
print("=" * 30 + title + " output" + "="*30)
for line in r.iter_content(None, decode_unicode=True):
print(line, end="", flush=True)
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
data = {
"query": "请用100字左右的文字介绍自己",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}"
data2 = {
"stream": True,
"messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
}
dump_input(data2, api)
response = requests.post(url, headers=headers, json=data2, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_chat_chat(api="/chat/chat"):
url = f"{api_base_url}{api}"
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
url = f"{api_base_url}{api}"
data = {
"query": "如何提问以获得高质量答案",
"knowledge_base_name": "samples",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + api + " output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if first:
for doc in data["docs"]:
print(doc)
first = False
print(data["answer"], end="", flush=True)
assert response.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]:
dump_input(data, api)
response = requests.post(url, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200

View File

@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest):
vector_store_type=vs_type,
embed_model=embed_model,
)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
st.experimental_rerun()
@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True,
disabled=len(files) == 0,
):
for f in files:
ret = api.upload_kb_doc(f, kb)
if ret["code"] == 200:
st.toast(ret["msg"], icon="")
else:
st.toast(ret["msg"], icon="")
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = []
st.divider()
@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest):
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.experimental_rerun()
st.divider()
@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
type="primary",
):
with st.spinner("向量库重构中"):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
print(d)
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun()
if cols[2].button(
@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
):
ret = api.delete_knowledge_base(kb)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
time.sleep(1)
st.experimental_rerun()

View File

@ -229,18 +229,18 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "errorMsg": str(e)}
yield {"code": 500, "msg": str(e)}
# 对话相关操作
@ -394,7 +394,7 @@ class ApiRequest:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "errorMsg": errorMsg or str(e)}
return {"code": 500, "msg": errorMsg or str(e)}
def list_knowledge_bases(
self,
@ -496,6 +496,7 @@ class ApiRequest:
knowledge_base_name: str,
filename: str = None,
override: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -529,7 +530,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name, "override": override},
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)},
)
return self._check_httpx_json_response(response)
@ -539,6 +544,7 @@ class ApiRequest:
knowledge_base_name: str,
doc_name: str,
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -551,6 +557,7 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
@ -568,6 +575,7 @@ class ApiRequest:
self,
knowledge_base_name: str,
file_name: str,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -583,7 +591,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
)
return self._check_httpx_json_response(response)
@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
return error message if error occured when requests API
'''
if isinstance(data, dict) and key in data:
if isinstance(data, dict):
if key in data:
return data[key]
if "code" in data and data["code"] != 200:
return data["msg"]
return ""
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
'''
return error message if error occured when requests API
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""