liunux4odoo c7b91bfaf1 update api/kb_doc_api and knowledge base management ui:
1. add update_doc to api which can update vector store from existing
   content file
2. add parameter `delete_content` to delete_doc api. user can decide
   whether to delete local content file when deleting doc.
3. fix bug in ApiRequest.upload_doc
4. support listing docs that exist in local folder but not in db
2023-08-09 16:52:04 +08:00

141 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import nltk
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
# Put the project's NLTK_DATA_PATH entry at the front of NLTK's data path
# list, ahead of the entries that were already there.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
async def document():
    """Return a redirect response that points the client at ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
def create_app():
    """Build the FastAPI application and register all of its endpoints.

    Routes are registered in a fixed order: the root redirect, the "Chat"
    endpoints, then the "Knowledge Base Management" endpoints.

    Returns:
        The configured ``FastAPI`` instance.
    """
    api = FastAPI()

    # Cross-origin access is opt-in: when OPEN_CROSS_DOMAIN (imported from
    # configs.model_config) is truthy, add a CORS middleware that allows
    # every origin, method and header.
    if OPEN_CROSS_DOMAIN:
        api.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Root path: handled by `document`, which redirects to "/docs".
    api.get("/", response_model=BaseResponse, summary="swagger 文档")(document)

    # Tag: Chat -- all POST endpoints, none declares a response model.
    chat_routes = (
        ("/chat/fastchat", "与llm模型对话(直接与fastchat api对话)", openai_chat),
        ("/chat/chat", "与llm模型对话(通过LLMChain)", chat),
        ("/chat/knowledge_base_chat", "与知识库对话", knowledge_base_chat),
        ("/chat/search_engine_chat", "与搜索引擎对话", search_engine_chat),
    )
    for path, summary, handler in chat_routes:
        # A fresh tag list is built per route, as in separate literal calls.
        api.post(path, tags=["Chat"], summary=summary)(handler)

    # Tag: Knowledge Base Management.
    # Each row: (registrar, path, extra keyword arguments, summary, handler).
    # The last route intentionally carries no response_model.
    kb_routes = (
        (api.get, "/knowledge_base/list_knowledge_bases",
         {"response_model": ListResponse}, "获取知识库列表", list_kbs),
        (api.post, "/knowledge_base/create_knowledge_base",
         {"response_model": BaseResponse}, "创建知识库", create_kb),
        (api.delete, "/knowledge_base/delete_knowledge_base",
         {"response_model": BaseResponse}, "删除知识库", delete_kb),
        (api.get, "/knowledge_base/list_docs",
         {"response_model": ListResponse}, "获取知识库内的文件列表", list_docs),
        (api.post, "/knowledge_base/upload_doc",
         {"response_model": BaseResponse}, "上传文件到知识库", upload_doc),
        (api.delete, "/knowledge_base/delete_doc",
         {"response_model": BaseResponse}, "删除知识库内的文件", delete_doc),
        (api.post, "/knowledge_base/update_doc",
         {"response_model": BaseResponse}, "更新现有文件到知识库", update_doc),
        (api.post, "/knowledge_base/recreate_vector_store",
         {}, "根据content中文档重建向量库流式输出处理进度。", recreate_vector_store),
    )
    for register, path, extra, summary, handler in kb_routes:
        register(
            path,
            tags=["Knowledge Base Management"],
            summary=summary,
            **extra,
        )(handler)

    return api
# Module-level application instance, created when this module is loaded;
# run_api() passes this object to uvicorn.run().
app = create_app()
def run_api(host, port, **kwargs):
    """Serve the module-level ``app`` with uvicorn on ``host``:``port``.

    Args:
        host: Address passed to ``uvicorn.run`` as ``host``.
        port: Port passed to ``uvicorn.run`` as ``port``.
        **kwargs: Only ``ssl_keyfile`` and ``ssl_certfile`` are read. They are
            forwarded to uvicorn only when BOTH are truthy; if either one is
            missing or empty the server is started without any SSL options.
    """
    keyfile = kwargs.get("ssl_keyfile")
    certfile = kwargs.get("ssl_certfile")

    if not (keyfile and certfile):
        # Incomplete TLS configuration: start without SSL options.
        uvicorn.run(app, host=host, port=port)
        return

    uvicorn.run(
        app,
        host=host,
        port=port,
        ssl_keyfile=keyfile,
        ssl_certfile=certfile,
    )
if __name__ == "__main__":
    # Script entry point: read the listen address and the optional TLS
    # key/certificate paths from the command line, then start the server.
    parser = argparse.ArgumentParser(
        prog='langchain-ChatGLM',
        description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                    ' 基于本地知识库的 ChatGLM 问答',
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    # Both SSL options default to None; run_api() only uses them when both
    # are provided.
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    args = parser.parse_args()

    run_api(
        host=args.host,
        port=args.port,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )