mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-25 08:13:30 +08:00
1. make huggingfaceembeddings hashable 2. unify embeddings' loading method for all kbservice 3. make ApiRequest skip empty content when streaming json to avoid dict KeyError
141 lines
4.8 KiB
Python
141 lines
4.8 KiB
Python
import nltk
|
||
import sys
|
||
import os
|
||
# Add the parent of this file's directory to sys.path.
# NOTE(review): presumably that is the project root holding the `configs` and
# `server` packages imported below — confirm.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||
|
||
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
||
import argparse
|
||
import uvicorn
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from starlette.responses import RedirectResponse
|
||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||
search_engine_chat)
|
||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||
update_doc, recreate_vector_store)
|
||
from server.utils import BaseResponse, ListResponse
|
||
|
||
# Put NLTK_DATA_PATH (from configs.model_config) ahead of the search paths
# NLTK already has configured, so it is consulted first.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||
|
||
|
||
async def document():
    """Return a redirect response that sends the caller to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
|
||
|
||
|
||
def create_app():
    """Build the FastAPI application and register every HTTP endpoint.

    When ``OPEN_CROSS_DOMAIN`` (imported from ``configs.model_config``) is
    truthy, a CORS middleware accepting any origin, method and header (with
    credentials allowed) is installed before the routes are registered.

    Returns:
        The configured ``FastAPI`` instance.
    """
    api = FastAPI()

    # Cross-origin access is opt-in through the OPEN_CROSS_DOMAIN setting.
    if OPEN_CROSS_DOMAIN:
        cors_options = dict(
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        api.add_middleware(CORSMiddleware, **cors_options)

    chat_tag = "Chat"
    kb_tag = "Knowledge Base Management"

    # One row per endpoint: (registrar, path, handler, route options).
    # Rows are applied top to bottom, so this is also the registration order.
    # NOTE: "/knowledge_base/update_doc" (handler ``update_doc``) is
    # deliberately absent; that route is currently disabled.
    route_table = [
        # Root: redirect to the Swagger documentation.
        (api.get, "/", document,
         dict(response_model=BaseResponse,
              summary="swagger 文档")),

        # Tag: Chat
        (api.post, "/chat/fastchat", openai_chat,
         dict(tags=[chat_tag],
              summary="与llm模型对话(直接与fastchat api对话)")),
        (api.post, "/chat/chat", chat,
         dict(tags=[chat_tag],
              summary="与llm模型对话(通过LLMChain)")),
        (api.post, "/chat/knowledge_base_chat", knowledge_base_chat,
         dict(tags=[chat_tag],
              summary="与知识库对话")),
        (api.post, "/chat/search_engine_chat", search_engine_chat,
         dict(tags=[chat_tag],
              summary="与搜索引擎对话")),

        # Tag: Knowledge Base Management
        (api.get, "/knowledge_base/list_knowledge_bases", list_kbs,
         dict(tags=[kb_tag],
              response_model=ListResponse,
              summary="获取知识库列表")),
        (api.post, "/knowledge_base/create_knowledge_base", create_kb,
         dict(tags=[kb_tag],
              response_model=BaseResponse,
              summary="创建知识库")),
        (api.delete, "/knowledge_base/delete_knowledge_base", delete_kb,
         dict(tags=[kb_tag],
              response_model=BaseResponse,
              summary="删除知识库")),
        (api.get, "/knowledge_base/list_docs", list_docs,
         dict(tags=[kb_tag],
              response_model=ListResponse,
              summary="获取知识库内的文件列表")),
        (api.post, "/knowledge_base/upload_doc", upload_doc,
         dict(tags=[kb_tag],
              response_model=BaseResponse,
              summary="上传文件到知识库")),
        (api.delete, "/knowledge_base/delete_doc", delete_doc,
         dict(tags=[kb_tag],
              response_model=BaseResponse,
              summary="删除知识库内的文件")),
        (api.post, "/knowledge_base/recreate_vector_store", recreate_vector_store,
         dict(tags=[kb_tag],
              summary="根据content中文档重建向量库,流式输出处理进度。")),
    ]

    for register, path, handler, options in route_table:
        register(path, **options)(handler)

    return api
|
||
|
||
|
||
# Application instance built once at import time; run_api() hands this object
# to uvicorn.run.
app = create_app()
|
||
|
||
|
||
def run_api(host, port, **kwargs):
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Address passed to ``uvicorn.run`` as ``host``.
        port: Port passed to ``uvicorn.run`` as ``port``.
        **kwargs: Only ``ssl_keyfile`` and ``ssl_certfile`` are read. They are
            forwarded to uvicorn only when both are truthy; otherwise the
            server is started without either of them.
    """
    keyfile = kwargs.get("ssl_keyfile")
    certfile = kwargs.get("ssl_certfile")

    # SSL settings are all-or-nothing: a lone key or certificate is ignored.
    ssl_options = {}
    if keyfile and certfile:
        ssl_options = {"ssl_keyfile": keyfile, "ssl_certfile": certfile}

    uvicorn.run(app, host=host, port=port, **ssl_options)
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the listen address and the optional SSL
    # files, then start the API server through run_api().
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="address to listen on (default: %(default)s)")
    parser.add_argument("--port", type=int, default=7861,
                        help="port to listen on (default: %(default)s)")
    # SSL is only enabled by run_api() when both files are supplied.
    parser.add_argument("--ssl_keyfile", type=str,
                        help="SSL key file; used only together with --ssl_certfile")
    parser.add_argument("--ssl_certfile", type=str,
                        help="SSL certificate file; used only together with --ssl_keyfile")

    args = parser.parse_args()
    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
|