"""Entry point that builds and serves the Langchain-Chatchat FastAPI app."""
import argparse
from typing import Literal, Optional

from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
import uvicorn

from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from server.api_server.chat_routes import chat_router
from server.api_server.kb_routes import kb_router
from server.api_server.openai_routes import openai_router
from server.api_server.server_routes import server_router
from server.api_server.tool_routes import tool_router
from server.chat.completion import completion
from server.utils import MakeFastAPIOffline


def create_app(run_mode: Optional[str] = None):
    """Build the FastAPI application with all routers and mounts attached.

    Args:
        run_mode: Accepted for caller compatibility; this function does not
            read it.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
    )
    # NOTE(review): presumably makes the interactive docs work without
    # external CDN assets -- confirm in server.utils.
    MakeFastAPIOffline(app)

    # Allow cross-origin requests from any origin when OPEN_CROSS_DOMAIN is
    # enabled in configs.server_config.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    @app.get("/", summary="swagger 文档", include_in_schema=False)
    async def document():
        # Send visitors of the root path to the Swagger UI.
        return RedirectResponse(url="/docs")

    app.include_router(chat_router)
    app.include_router(kb_router)
    app.include_router(tool_router)
    app.include_router(openai_router)
    app.include_router(server_router)

    # Miscellaneous endpoints.
    app.post(
        "/other/completion",
        tags=["Other"],
        summary="要求llm模型补全(通过LLMChain)",
    )(completion)

    # Media files are served statically from MEDIA_PATH.
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")

    return app


def run_api(host, port, **kwargs):
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Interface address passed to ``uvicorn.run``.
        port: TCP port passed to ``uvicorn.run``.
        **kwargs: Only ``ssl_keyfile`` and ``ssl_certfile`` are read. They are
            forwarded to uvicorn only when BOTH are truthy; if either one is
            missing the server is started without them.
    """
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")

    ssl_options = {}
    if ssl_keyfile and ssl_certfile:
        ssl_options = {
            "ssl_keyfile": ssl_keyfile,
            "ssl_certfile": ssl_certfile,
        }

    uvicorn.run(app, host=host, port=port, **ssl_options)


# Created at import time so that `run_api` (and any importer of this module)
# can reference `app` directly.
app = create_app()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='langchain-ChatGLM',
        description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                    ' | 基于本地知识库的 ChatGLM 问答',
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    # Parse the command line and start the server.
    args = parser.parse_args()

    run_api(
        host=args.host,
        port=args.port,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )