mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
- 重构 api.py:
- 按模块划分为不同的 router
- 添加 openai 兼容的转发接口,项目默认使用该接口以实现模型负载均衡
- 添加 /tools 接口,可以获取/调用编写的 agent tools
- 移除所有 EmbeddingFuncAdapter,统一改用 get_Embeddings
- 待办:
- /chat/chat 接口改为 openai 兼容
- 添加 /chat/kb_chat 接口,openai 兼容
- 改变 ntlk/knowledge_base/logs 等数据目录位置
92 lines
3.0 KiB
Python
import argparse
from typing import Literal, Optional

import uvicorn
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse

from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from server.api_server.chat_routes import chat_router
from server.api_server.kb_routes import kb_router
from server.api_server.openai_routes import openai_router
from server.api_server.server_routes import server_router
from server.api_server.tool_routes import tool_router
from server.chat.completion import completion
from server.utils import MakeFastAPIOffline
def create_app(run_mode: Optional[str] = None):
    """Build and configure the Langchain-Chatchat FastAPI application.

    Args:
        run_mode: reserved run-mode flag; not read anywhere in this factory.
            # NOTE(review): kept for interface compatibility — confirm callers.

    Returns:
        A FastAPI app with offline docs, optional CORS, all feature routers,
        the root->/docs redirect, the completion endpoint, and the media mount.
    """
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
    )
    # Rewrite swagger/redoc asset URLs so the docs pages work offline.
    MakeFastAPIOffline(app)

    # Add CORS middleware to allow all origins.
    # Set OPEN_CROSS_DOMAIN=True in the server config to allow cross-domain.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    @app.get("/", summary="swagger 文档", include_in_schema=False)
    async def document():
        # Redirect the bare root URL to the interactive API docs.
        return RedirectResponse(url="/docs")

    # Mount the per-module routers (chat / knowledge base / tools / openai
    # compatible forwarding / server management).
    app.include_router(chat_router)
    app.include_router(kb_router)
    app.include_router(tool_router)
    app.include_router(openai_router)
    app.include_router(server_router)

    # Other endpoints: direct LLM completion (via LLMChain).
    app.post("/other/completion",
             tags=["Other"],
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    # Serve media files from MEDIA_PATH under /media.
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")

    return app
def run_api(host, port, **kwargs):
    """Serve the module-level `app` with uvicorn.

    Args:
        host: interface address to bind.
        port: TCP port to listen on.
        **kwargs: may contain `ssl_keyfile` and `ssl_certfile`; HTTPS is
            enabled only when BOTH are truthy (uvicorn requires the pair).
    """
    # Hoist the lookups and pass SSL options conditionally instead of
    # duplicating the uvicorn.run call in two branches.
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")
    ssl_kwargs = {}
    if ssl_keyfile and ssl_certfile:
        ssl_kwargs = {"ssl_keyfile": ssl_keyfile,
                      "ssl_certfile": ssl_certfile}
    uvicorn.run(app, host=host, port=port, **ssl_kwargs)
# Module-level app instance so ASGI servers (and run_api) can import it.
app = create_app()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)
    # Parse command-line arguments (removed dead `args_dict = vars(args)`,
    # which was never used).
    args = parser.parse_args()

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )