2024-03-25 15:31:49 +08:00

92 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
from typing import Literal
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
import uvicorn
from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from chatchat_server.api_server.chat_routes import chat_router
from chatchat_server.api_server.kb_routes import kb_router
from chatchat_server.api_server.openai_routes import openai_router
from chatchat_server.api_server.server_routes import server_router
from chatchat_server.api_server.tool_routes import tool_router
from chatchat_server.chat.completion import completion
from chatchat_server.utils import MakeFastAPIOffline
def create_app(run_mode: str=None):
app = FastAPI(
title="Langchain-Chatchat API Server",
version=VERSION
)
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", summary="swagger 文档", include_in_schema=False)
async def document():
return RedirectResponse(url="/docs")
app.include_router(chat_router)
app.include_router(kb_router)
app.include_router(tool_router)
app.include_router(openai_router)
app.include_router(server_router)
# 其它接口
app.post("/other/completion",
tags=["Other"],
summary="要求llm模型补全(通过LLMChain)",
)(completion)
# 媒体文件
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
return app
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
else:
uvicorn.run(app, host=host, port=port)
app = create_app()
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
' 基于本地知识库的 ChatGLM 问答')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
run_api(host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)