diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 433ec60d..bc523e33 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -76,25 +76,34 @@ LLM_MODEL_CONFIG = { }, } -# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。 +# 可以通过 loom/xinference/oneapi/fastchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。 +# - platform_name 可以任意填写,不要重复即可 +# - platform_type 可选:openai, xinference, oneapi, fastchat。以后可能根据平台类型做一些功能区分 +# - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。 + MODEL_PLATFORMS = [ - { - "platform_name": "openai-api", - "platform_type": "openai", - "llm_models": [ - "gpt-3.5-turbo", - ], - "embed_models": [], - "image_models": [], - "multimodal_models": [], - "api_base_url": "https://api.openai.com/v1", - "api_key": "sk-", - "api_proxy": "", - }, + # { + # "platform_name": "openai-api", + # "platform_type": "openai", + # "api_base_url": "https://api.openai.com/v1", + # "api_key": "sk-", + # "api_proxy": "", + # "api_concurrencies": 5, + # "llm_models": [ + # "gpt-3.5-turbo", + # ], + # "embed_models": [], + # "image_models": [], + # "multimodal_models": [], + # }, { "platform_name": "xinference", "platform_type": "xinference", + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + "api_concurrencies": 5, + # 注意:这里填写的是 xinference 部署的模型 UID,而非模型名称 "llm_models": [ "chatglm3-6b", ], @@ -107,40 +116,38 @@ MODEL_PLATFORMS = [ "multimodal_models": [ "qwen-vl", ], - "api_base_url": "http://127.0.0.1:9997/v1", - "api_key": "EMPTY", }, - { - "platform_name": "oneapi", - "platform_type": "oneapi", - "api_key": "", - "llm_models": [ - "qwen-turbo", - "qwen-plus", - "chatglm_turbo", - "chatglm_std", - ], - "embed_models": [], - "image_models": [], - "multimodal_models": [], - "api_base_url": "http://127.0.0.1:3000/v1", - "api_key": "sk-xxx", - }, + # { + # "platform_name": "oneapi", + # "platform_type": "oneapi", + # "api_base_url": "http://127.0.0.1:3000/v1", + # "api_key": "", + # "api_concurrencies": 5, + # "llm_models": [ + # "qwen-turbo", + # "qwen-plus", + # "chatglm_turbo", + # "chatglm_std", + # ], + # "embed_models": [], + # "image_models": [], + # "multimodal_models": [], + # }, - { - "platform_name": "loom", - "platform_type": "loom", - "api_key": "", - "llm_models": [ - "chatglm3-6b", - ], - "embed_models": [], - "image_models": [], - "multimodal_models": [], - "api_base_url": "http://127.0.0.1:7860/v1", - "api_key": "EMPTY", - }, + # { + # "platform_name": "loom", + # "platform_type": "loom", + # "api_base_url": "http://127.0.0.1:7860/v1", + # "api_key": "", + # "api_concurrencies": 5, + # "llm_models": [ + # "chatglm3-6b", + # ], + # "embed_models": [], + # "image_models": [], + # "multimodal_models": [], + # }, ] LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml") diff --git a/init_database.py b/init_database.py index 1ca0fa60..f2e2dc59 100644 --- a/init_database.py +++ b/init_database.py @@ -2,7 +2,7 @@ import sys sys.path.append(".") from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL +from configs.model_config import DEFAULT_EMBEDDING_MODEL from datetime import datetime diff --git a/requirements.txt b/requirements.txt index 97fa093b..d758b96a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,6 +41,38 @@ youtube-search==2.1.2 duckduckgo-search==3.9.9 metaphor-python==0.1.23 +httpx==0.26.0 +httpx_sse==0.4.0 
+watchdog==3.0.0 +pyjwt==2.8.0 +elasticsearch +numexpr>=2.8.8 +strsimpy>=0.2.1 +markdownify>=0.11.6 +tqdm>=4.66.1 +websockets>=12.0 +numpy>=1.26.3 +pandas~=2.1.4 +pydantic<2 +httpx[brotli,http2,socks]>=0.25.2 + +# optional document loaders + +# rapidocr_paddle[gpu]>=1.3.0.post5 +# jq>=1.6.0 +# html2text +# beautifulsoup4>=4.12.2 +# pysrt>=1.1.2 + +# Agent and Search Tools + +# arxiv>=2.1.0 +# youtube-search>=2.1.2 +# duckduckgo-search>=4.1.0 +# metaphor-python>=0.1.23 + +# WebUI requirements + streamlit==1.30.0 streamlit-option-menu==0.3.12 streamlit-antd-components==0.3.1 @@ -48,7 +80,3 @@ streamlit-chatbox==1.1.11 streamlit-modal==0.1.0 streamlit-aggrid==0.3.4.post3 -httpx==0.26.0 -httpx_sse==0.4.0 -watchdog==3.0.0 -pyjwt==2.8.0 diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py index 35ab19c7..d4173e5f 100644 --- a/server/agent/agent_factory/glm3_agent.py +++ b/server/agent/agent_factory/glm3_agent.py @@ -16,10 +16,8 @@ from langchain.output_parsers import OutputFixingParser from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool -from langchain.pydantic_v1 import Field +from server.pydantic_types import Field, typing, model_schema -from pydantic import typing -from pydantic.schema import model_schema logger = logging.getLogger(__name__) diff --git a/server/agent/tools_factory/arxiv.py b/server/agent/tools_factory/arxiv.py index 43129463..97eae161 100644 --- a/server/agent/tools_factory/arxiv.py +++ b/server/agent/tools_factory/arxiv.py @@ -1,6 +1,8 @@ # LangChain 的 ArxivQueryRun 工具 -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field from langchain.tools.arxiv.tool import ArxivQueryRun + + def arxiv(query: str): tool = ArxivQueryRun() return tool.run(tool_input=query) diff --git a/server/agent/tools_factory/audio_factory/aqa.py b/server/agent/tools_factory/audio_factory/aqa.py index 4d053ecc..337090d9 100644 --- a/server/agent/tools_factory/audio_factory/aqa.py +++ b/server/agent/tools_factory/audio_factory/aqa.py @@ -1,6 +1,6 @@ import base64 import os -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field def save_base64_audio(base64_audio, file_path): audio_data = base64.b64decode(base64_audio) diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py index c893e548..e66292f6 100644 --- a/server/agent/tools_factory/calculate.py +++ b/server/agent/tools_factory/calculate.py @@ -1,4 +1,4 @@ -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field def calculate(a: float, b: float, operator: str) -> float: if operator == "+": diff --git a/server/agent/tools_factory/search_internet.py b/server/agent/tools_factory/search_internet.py index 870156db..1e53644b 100644 --- a/server/agent/tools_factory/search_internet.py +++ b/server/agent/tools_factory/search_internet.py @@ -1,4 +1,4 @@ -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from configs import TOOL_CONFIG diff --git a/server/agent/tools_factory/search_local_knowledgebase.py b/server/agent/tools_factory/search_local_knowledgebase.py index ff9614e1..c3205709 100644 --- 
a/server/agent/tools_factory/search_local_knowledgebase.py +++ b/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,5 +1,5 @@ from urllib.parse import urlencode -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field from server.knowledge_base.kb_doc_api import search_docs from configs import TOOL_CONFIG diff --git a/server/agent/tools_factory/search_youtube.py b/server/agent/tools_factory/search_youtube.py index e7737eb4..3b5b939e 100644 --- a/server/agent/tools_factory/search_youtube.py +++ b/server/agent/tools_factory/search_youtube.py @@ -1,5 +1,7 @@ from langchain_community.tools import YouTubeSearchTool -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field + + def search_youtube(query: str): tool = YouTubeSearchTool() return tool.run(tool_input=query) diff --git a/server/agent/tools_factory/shell.py b/server/agent/tools_factory/shell.py index c8f7ddfe..ea902f78 100644 --- a/server/agent/tools_factory/shell.py +++ b/server/agent/tools_factory/shell.py @@ -1,6 +1,8 @@ # LangChain 的 Shell 工具 -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field from langchain_community.tools import ShellTool + + def shell(query: str): tool = ShellTool() return tool.run(tool_input=query) diff --git a/server/agent/tools_factory/text2image.py b/server/agent/tools_factory/text2image.py index c6983666..5e5800b5 100644 --- a/server/agent/tools_factory/text2image.py +++ b/server/agent/tools_factory/text2image.py @@ -6,9 +6,8 @@ from typing import List import uuid from langchain.agents import tool -from langchain.pydantic_v1 import Field +from server.pydantic_types import Field, FieldInfo import openai -from pydantic.fields import FieldInfo from configs.basic_config import MEDIA_PATH from server.utils import MsgType diff --git a/server/agent/tools_factory/vision_factory/vqa.py b/server/agent/tools_factory/vision_factory/vqa.py index bbf57a8b..f39aa593 100644 --- a/server/agent/tools_factory/vision_factory/vqa.py +++ b/server/agent/tools_factory/vision_factory/vqa.py @@ -4,7 +4,7 @@ Method Use cogagent to generate response for a given image and query. 
import base64 from io import BytesIO from PIL import Image, ImageDraw -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field from configs import TOOL_CONFIG import re from server.agent.container import container diff --git a/server/agent/tools_factory/weather_check.py b/server/agent/tools_factory/weather_check.py index db52860f..b15f90e6 100644 --- a/server/agent/tools_factory/weather_check.py +++ b/server/agent/tools_factory/weather_check.py @@ -1,7 +1,7 @@ """ 简单的单参数输入工具实现,用于查询现在天气的情况 """ -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field import requests def weather(location: str, api_key: str): diff --git a/server/agent/tools_factory/wolfram.py b/server/agent/tools_factory/wolfram.py index 45ef0f0a..785d07e7 100644 --- a/server/agent/tools_factory/wolfram.py +++ b/server/agent/tools_factory/wolfram.py @@ -1,7 +1,9 @@ # Langchain 自带的 Wolfram Alpha API 封装 from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper -from langchain.pydantic_v1 import BaseModel, Field +from server.pydantic_types import BaseModel, Field wolfram_alpha_appid = "your key" + + def wolfram(query: str): wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) ans = wolfram.run(query) diff --git a/server/api.py b/server/api.py deleted file mode 100644 index 7c02b17a..00000000 --- a/server/api.py +++ /dev/null @@ -1,225 +0,0 @@ -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(__file__))) - -from configs import VERSION, MEDIA_PATH -from configs.server_config import OPEN_CROSS_DOMAIN -import argparse -import uvicorn -from fastapi import Body -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from starlette.responses import RedirectResponse -from server.chat.chat import chat -from server.chat.completion import completion -from server.chat.feedback import chat_feedback -from server.knowledge_base.model.kb_document_model import DocumentWithVSId -from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, - get_server_configs, get_prompt_template) -from typing import List, Literal - - -async def document(): - return RedirectResponse(url="/docs") - - -def create_app(run_mode: str = None): - app = FastAPI( - title="Langchain-Chatchat API Server", - version=VERSION - ) - MakeFastAPIOffline(app) - # Add CORS middleware to allow all origins - # 在config.py中设置OPEN_DOMAIN=True,允许跨域 - # set OPEN_DOMAIN=True in config.py to allow cross-domain - if OPEN_CROSS_DOMAIN: - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - mount_app_routes(app, run_mode=run_mode) - return app - - -def mount_app_routes(app: FastAPI, run_mode: str = None): - app.get("/", - response_model=BaseResponse, - summary="swagger 文档")(document) - - # Tag: Chat - app.post("/chat/chat", - tags=["Chat"], - summary="与llm模型对话(通过LLMChain)", - )(chat) - - app.post("/chat/feedback", - tags=["Chat"], - summary="返回llm模型对话评分", - )(chat_feedback) - - # 知识库相关接口 - mount_knowledge_routes(app) - # 摘要相关接口 - mount_filename_summary_routes(app) - - # 服务器相关接口 - app.post("/server/configs", - tags=["Server State"], - summary="获取服务器原始配置信息", - )(get_server_configs) - - - @app.post("/server/get_prompt_template", - tags=["Server State"], - summary="获取服务区配置的 prompt 模板") - def get_server_prompt_template( - type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", 
description="模板类型,可选值:llm_chat,knowledge_base_chat"), - name: str = Body("default", description="模板名称"), - ) -> str: - return get_prompt_template(type=type, name=name) - - # 其它接口 - app.post("/other/completion", - tags=["Other"], - summary="要求llm模型补全(通过LLMChain)", - )(completion) - - # 媒体文件 - app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") - - -def mount_knowledge_routes(app: FastAPI): - from server.chat.file_chat import upload_temp_docs, file_chat - from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb - from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, - update_docs, download_doc, recreate_vector_store, - search_docs, update_info) - - app.post("/chat/file_chat", - tags=["Knowledge Base Management"], - summary="文件对话" - )(file_chat) - - app.get("/knowledge_base/list_knowledge_bases", - tags=["Knowledge Base Management"], - response_model=ListResponse, - summary="获取知识库列表")(list_kbs) - - app.post("/knowledge_base/create_knowledge_base", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="创建知识库" - )(create_kb) - - app.post("/knowledge_base/delete_knowledge_base", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库" - )(delete_kb) - - app.get("/knowledge_base/list_files", - tags=["Knowledge Base Management"], - response_model=ListResponse, - summary="获取知识库内的文件列表" - )(list_files) - - app.post("/knowledge_base/search_docs", - tags=["Knowledge Base Management"], - response_model=List[DocumentWithVSId], - summary="搜索知识库" - )(search_docs) - - app.post("/knowledge_base/upload_docs", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="上传文件到知识库,并/或进行向量化" - )(upload_docs) - - app.post("/knowledge_base/delete_docs", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库内指定文件" - )(delete_docs) - - app.post("/knowledge_base/update_info", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="更新知识库介绍" - )(update_info) - - app.post("/knowledge_base/update_docs", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="更新现有文件到知识库" - )(update_docs) - - app.get("/knowledge_base/download_doc", - tags=["Knowledge Base Management"], - summary="下载对应的知识文件")(download_doc) - - app.post("/knowledge_base/recreate_vector_store", - tags=["Knowledge Base Management"], - summary="根据content中文档重建向量库,流式输出处理进度。" - )(recreate_vector_store) - - app.post("/knowledge_base/upload_temp_docs", - tags=["Knowledge Base Management"], - summary="上传文件到临时目录,用于文件对话。" - )(upload_temp_docs) - - -def mount_filename_summary_routes(app: FastAPI): - from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, - summary_doc_ids_to_vector_store) - - app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store", - tags=["Knowledge kb_summary_api Management"], - summary="单个知识库根据文件名称摘要" - )(summary_file_to_vector_store) - app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store", - tags=["Knowledge kb_summary_api Management"], - summary="单个知识库根据doc_ids摘要", - response_model=BaseResponse, - )(summary_doc_ids_to_vector_store) - app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store", - tags=["Knowledge kb_summary_api Management"], - summary="重建单个知识库文件摘要" - )(recreate_summary_vector_store) - - - -def run_api(host, port, **kwargs): - if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): - uvicorn.run(app, - host=host, - 
port=port, - ssl_keyfile=kwargs.get("ssl_keyfile"), - ssl_certfile=kwargs.get("ssl_certfile"), - ) - else: - uvicorn.run(app, host=host, port=port) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(prog='langchain-ChatGLM', - description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain' - ' | 基于本地知识库的 ChatGLM 问答') - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=7861) - parser.add_argument("--ssl_keyfile", type=str) - parser.add_argument("--ssl_certfile", type=str) - # 初始化消息 - args = parser.parse_args() - args_dict = vars(args) - - app = create_app() - - run_api(host=args.host, - port=args.port, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ) diff --git a/server/api_server/api_schemas.py b/server/api_server/api_schemas.py new file mode 100644 index 00000000..280f69af --- /dev/null +++ b/server/api_server/api_schemas.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import re +from typing import Dict, List, Literal, Optional, Union + +from fastapi import UploadFile +from server.pydantic_types import BaseModel, Field, AnyUrl, root_validator +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionToolChoiceOptionParam, + ChatCompletionToolParam, + completion_create_params, +) + +from configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG + + +class OpenAIBaseInput(BaseModel): + user: Optional[str] = None + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict] = None + extra_query: Optional[Dict] = None + extra_body: Optional[Dict] = None + timeout: Optional[float] = None + + +class OpenAIChatInput(OpenAIBaseInput): + messages: List[ChatCompletionMessageParam] + model: str = DEFAULT_LLM_MODEL + frequency_penalty: Optional[float] = None + function_call: Optional[completion_create_params.FunctionCall] = None + functions: List[completion_create_params.Function] = None + logit_bias: Optional[Dict[str, int]] = None + logprobs: Optional[bool] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[float] = None + response_format: completion_create_params.ResponseFormat = None + seed: Optional[int] = None + stop: Union[Optional[str], List[str]] = None + stream: Optional[bool] = None + temperature: Optional[float] = TEMPERATURE + tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None + tools: List[ChatCompletionToolParam] = None + top_logprobs: Optional[int] = None + top_p: Optional[float] = None + + +class OpenAIEmbeddingsInput(OpenAIBaseInput): + input: Union[str, List[str]] + model: str + dimensions: Optional[int] = None + encoding_format: Optional[Literal["float", "base64"]] = None + + +class OpenAIImageBaseInput(OpenAIBaseInput): + model: str + n: int = 1 + response_format: Optional[Literal["url", "b64_json"]] = None + size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = "256x256" + + +class OpenAIImageGenerationsInput(OpenAIImageBaseInput): + prompt: str + quality: Literal["standard", "hd"] = None + style: Optional[Literal["vivid", "natural"]] = None + + +class OpenAIImageVariationsInput(OpenAIImageBaseInput): + image: Union[UploadFile, AnyUrl] + + +class OpenAIImageEditsInput(OpenAIImageVariationsInput): + prompt: str + mask: Union[UploadFile, 
AnyUrl] + + +class OpenAIAudioTranslationsInput(OpenAIBaseInput): + file: Union[UploadFile, AnyUrl] + model: str + prompt: Optional[str] = None + response_format: Optional[str] = None + temperature: float = TEMPERATURE + + +class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput): + language: Optional[str] = None + timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None + + +class OpenAIAudioSpeechInput(OpenAIBaseInput): + input: str + model: str + voice: str + response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None + speed: Optional[float] = None diff --git a/server/api_server/chat_routes.py b/server/api_server/chat_routes.py new file mode 100644 index 00000000..50a1a84c --- /dev/null +++ b/server/api_server/chat_routes.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Request + +from server.chat.chat import chat +from server.chat.feedback import chat_feedback +from server.chat.file_chat import file_chat + + +chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"]) + +chat_router.post("/chat", + summary="与llm模型对话(通过LLMChain)", + )(chat) + +chat_router.post("/feedback", + summary="返回llm模型对话评分", + )(chat_feedback) + +chat_router.post("/file_chat", + summary="文件对话" + )(file_chat) \ No newline at end of file diff --git a/server/api_server/kb_routes.py b/server/api_server/kb_routes.py new file mode 100644 index 00000000..c482bba9 --- /dev/null +++ b/server/api_server/kb_routes.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Request + +from server.chat.file_chat import upload_temp_docs +from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb +from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, + update_docs, download_doc, recreate_vector_store, + search_docs, update_info) +from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, + summary_doc_ids_to_vector_store) +from server.knowledge_base.model.kb_document_model import DocumentWithVSId +from server.utils import BaseResponse, ListResponse + + +kb_router = APIRouter(prefix="/knowledge_base", tags=["Knowledge Base Management"]) + + +kb_router.get("/list_knowledge_bases", + response_model=ListResponse, + summary="获取知识库列表")(list_kbs) + +kb_router.post("/create_knowledge_base", + response_model=BaseResponse, + summary="创建知识库" + )(create_kb) + +kb_router.post("/delete_knowledge_base", + response_model=BaseResponse, + summary="删除知识库" + )(delete_kb) + +kb_router.get("/list_files", + response_model=ListResponse, + summary="获取知识库内的文件列表" + )(list_files) + +kb_router.post("/search_docs", + response_model=List[DocumentWithVSId], + summary="搜索知识库" + )(search_docs) + +kb_router.post("/upload_docs", + response_model=BaseResponse, + summary="上传文件到知识库,并/或进行向量化" + )(upload_docs) + +kb_router.post("/delete_docs", + response_model=BaseResponse, + summary="删除知识库内指定文件" + )(delete_docs) + +kb_router.post("/update_info", + response_model=BaseResponse, + summary="更新知识库介绍" + )(update_info) + +kb_router.post("/update_docs", + response_model=BaseResponse, + summary="更新现有文件到知识库" + )(update_docs) + +kb_router.get("/download_doc", + summary="下载对应的知识文件")(download_doc) + +kb_router.post("/recreate_vector_store", + summary="根据content中文档重建向量库,流式输出处理进度。" + )(recreate_vector_store) + +kb_router.post("/upload_temp_docs", + summary="上传文件到临时目录,用于文件对话。" + )(upload_temp_docs) + + 
+summary_router = APIRouter(prefix="/kb_summary_api") +summary_router.post("/summary_file_to_vector_store", + summary="单个知识库根据文件名称摘要" + )(summary_file_to_vector_store) +summary_router.post("/summary_doc_ids_to_vector_store", + summary="单个知识库根据doc_ids摘要", + response_model=BaseResponse, + )(summary_doc_ids_to_vector_store) +summary_router.post("/recreate_summary_vector_store", + summary="重建单个知识库文件摘要" + )(recreate_summary_vector_store) + +kb_router.include_router(summary_router) diff --git a/server/api_server/openai_routes.py b/server/api_server/openai_routes.py new file mode 100644 index 00000000..ecc9f03b --- /dev/null +++ b/server/api_server/openai_routes.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Dict, Tuple, AsyncGenerator + +from fastapi import APIRouter, Request +from openai import AsyncClient +from sse_starlette.sse import EventSourceResponse + +from .api_schemas import * +from configs import logger +from server.utils import get_model_info, get_config_platforms, get_OpenAIClient + + +DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 +model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {} # key: (model_name, platform) +openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]) + + +@asynccontextmanager +async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: + ''' + 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 + ''' + max_semaphore = 0 + selected_platform = "" + model_infos = get_model_info(model_name=model_name, multiple=True) + for m, c in model_infos.items(): + key = (m, c["platform_name"]) + api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES) + if key not in model_semaphores: + model_semaphores[key] = asyncio.Semaphore(api_concurrencies) + semaphore = model_semaphores[key] + if semaphore._value >= api_concurrencies: + selected_platform = c["platform_name"] + break + elif semaphore._value > max_semaphore: + selected_platform = c["platform_name"] + + key = (m, selected_platform) + semaphore = model_semaphores[key] + try: + await semaphore.acquire() + yield get_OpenAIClient(platform_name=selected_platform, is_async=True) + except Exception: + logger.error(f"failed when request to {key}", exc_info=True) + finally: + semaphore.release() + + +async def openai_request(method, body): + ''' + helper function to make openai request + ''' + async def generator(): + async for chunk in await method(**params): + yield {"data": chunk.json()} + + params = body.dict(exclude_unset=True) + if hasattr(body, "stream") and body.stream: + return EventSourceResponse(generator()) + else: + return (await method(**params)).dict() + + +@openai_router.get("/models") +async def list_models() -> Dict: + ''' + 整合所有平台的模型列表。 + 由于 openai sdk 不支持重名模型,对于重名模型,只返回其中响应速度最快的一个。在请求其它接口时会自动按照模型忙闲状态进行调度。 + ''' + async def task(name: str): + try: + client = get_OpenAIClient(name, is_async=True) + models = await client.models.list() + models = models.dict(exclude=["data", "object"]) + for x in models: + models[x]["platform_name"] = name + return models + except Exception: + logger.error(f"failed request to platform: {name}", exc_info=True) + return {} + + result = {} + tasks = [asyncio.create_task(task(name)) for name in get_config_platforms()] + for t in asyncio.as_completed(tasks): + for n, v in (await t).items(): + if n not in result: + result[n] = v + + return result + + +@openai_router.post("/chat/completions") +async def create_chat_completions( + request: Request, + body: 
OpenAIChatInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.chat.completions.create, body)
+
+
+@openai_router.post("/completions")
+async def create_completions(
+    request: Request,
+    body: OpenAIChatInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.completions.create, body)
+
+
+@openai_router.post("/embeddings")
+async def create_embeddings(
+    request: Request,
+    body: OpenAIEmbeddingsInput,
+):
+    params = body.dict(exclude_unset=True)
+    client = get_OpenAIClient(model_name=body.model)
+    return (await client.embeddings.create(**params)).dict()
+
+
+@openai_router.post("/images/generations")
+async def create_image_generations(
+    request: Request,
+    body: OpenAIImageGenerationsInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.images.generate, body)
+
+
+@openai_router.post("/images/variations")
+async def create_image_variations(
+    request: Request,
+    body: OpenAIImageVariationsInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.images.create_variation, body)
+
+
+@openai_router.post("/images/edit")
+async def create_image_edit(
+    request: Request,
+    body: OpenAIImageEditsInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.images.edit, body)
+
+
+@openai_router.post("/audio/translations", summary="暂不支持", deprecated=True)
+async def create_audio_translations(
+    request: Request,
+    body: OpenAIAudioTranslationsInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.audio.translations.create, body)
+
+
+@openai_router.post("/audio/transcriptions", summary="暂不支持", deprecated=True)
+async def create_audio_transcriptions(
+    request: Request,
+    body: OpenAIAudioTranscriptionsInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.audio.transcriptions.create, body)
+
+
+@openai_router.post("/audio/speech", summary="暂不支持", deprecated=True)
+async def create_audio_speech(
+    request: Request,
+    body: OpenAIAudioSpeechInput,
+):
+    async with acquire_model_client(body.model) as client:
+        return await openai_request(client.audio.speech.create, body)
+
+
+@openai_router.post("/files", summary="暂不支持", deprecated=True)
+async def files():
+    ...
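
A minimal client-side sketch of the wrapped `/v1` endpoints above (not part of the patch). It assumes the API server is running on the default port 7861 and that `chatglm3-6b` is deployed on one of the configured platforms; with `stream=True` the route answers with an `EventSourceResponse`, which the openai SDK consumes transparently:

```python
import openai

# Assumes the Chatchat API server listens on the default port 7861 and
# "chatglm3-6b" matches a model configured in MODEL_PLATFORMS.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:7861/v1")

# Non-streaming: openai_request returns the upstream response as a plain dict.
resp = client.chat.completions.create(
    messages=[{"role": "user", "content": "你好"}],
    model="chatglm3-6b",
)
print(resp.choices[0].message.content)

# Streaming: the SDK reassembles the SSE events into ChatCompletionChunk objects.
for chunk in client.chat.completions.create(
    messages=[{"role": "user", "content": "你好"}],
    model="chatglm3-6b",
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```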
diff --git a/server/api_server/server_app.py b/server/api_server/server_app.py
new file mode 100644
index 00000000..5caf4e46
--- /dev/null
+++ b/server/api_server/server_app.py
@@ -0,0 +1,90 @@
+import argparse
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from starlette.responses import RedirectResponse
+import uvicorn
+
+from configs import VERSION, MEDIA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN
+from server.api_server.chat_routes import chat_router
+from server.api_server.kb_routes import kb_router
+from server.api_server.openai_routes import openai_router
+from server.api_server.server_routes import server_router
+from server.api_server.tool_routes import tool_router
+from server.chat.completion import completion
+from server.utils import MakeFastAPIOffline
+
+
+def create_app(run_mode: str = None):
+    app = FastAPI(
+        title="Langchain-Chatchat API Server",
+        version=VERSION
+    )
+    MakeFastAPIOffline(app)
+    # Add CORS middleware to allow all origins
+    # 在config.py中设置OPEN_DOMAIN=True,允许跨域
+    # set OPEN_DOMAIN=True in config.py to allow cross-domain
+    if OPEN_CROSS_DOMAIN:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+    @app.get("/", summary="swagger 文档", include_in_schema=False)
+    async def document():
+        return RedirectResponse(url="/docs")
+
+    app.include_router(chat_router)
+    app.include_router(kb_router)
+    app.include_router(tool_router)
+    app.include_router(openai_router)
+    app.include_router(server_router)
+
+    # 其它接口
+    app.post("/other/completion",
+             tags=["Other"],
+             summary="要求llm模型补全(通过LLMChain)",
+             )(completion)
+
+    # 媒体文件
+    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
+
+    return app
+
+
+def run_api(host, port, **kwargs):
+    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
+        uvicorn.run(app,
+                    host=host,
+                    port=port,
+                    ssl_keyfile=kwargs.get("ssl_keyfile"),
+                    ssl_certfile=kwargs.get("ssl_certfile"),
+                    )
+    else:
+        uvicorn.run(app, host=host, port=port)
+
+app = create_app()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
+                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
+                                                 ' | 基于本地知识库的 ChatGLM 问答')
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=7861)
+    parser.add_argument("--ssl_keyfile", type=str)
+    parser.add_argument("--ssl_certfile", type=str)
+    # 初始化消息
+    args = parser.parse_args()
+    args_dict = vars(args)
+
+    run_api(host=args.host,
+            port=args.port,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            )
diff --git a/server/api_server/server_routes.py b/server/api_server/server_routes.py
new file mode 100644
index 00000000..be215ef7
--- /dev/null
+++ b/server/api_server/server_routes.py
@@ -0,0 +1,23 @@
+from typing import Literal
+
+from fastapi import APIRouter, Body
+
+from server.utils import get_server_configs, get_prompt_template
+
+
+server_router = APIRouter(prefix="/server", tags=["Server State"])
+
+
+# 服务器相关接口
+server_router.post("/configs",
+                   summary="获取服务器原始配置信息",
+                   )(get_server_configs)
+
+
+@server_router.post("/get_prompt_template",
+                    summary="获取服务器配置的 prompt 模板")
+def get_server_prompt_template(
+    type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
+    name: str = Body("default", 
description="模板名称"), +) -> str: + return get_prompt_template(type=type, name=name) diff --git a/server/static/favicon.png b/server/api_server/static/favicon.png similarity index 100% rename from server/static/favicon.png rename to server/api_server/static/favicon.png diff --git a/server/static/redoc.standalone.js b/server/api_server/static/redoc.standalone.js similarity index 100% rename from server/static/redoc.standalone.js rename to server/api_server/static/redoc.standalone.js diff --git a/server/static/swagger-ui-bundle.js b/server/api_server/static/swagger-ui-bundle.js similarity index 100% rename from server/static/swagger-ui-bundle.js rename to server/api_server/static/swagger-ui-bundle.js diff --git a/server/static/swagger-ui.css b/server/api_server/static/swagger-ui.css similarity index 100% rename from server/static/swagger-ui.css rename to server/api_server/static/swagger-ui.css diff --git a/server/api_server/tool_routes.py b/server/api_server/tool_routes.py new file mode 100644 index 00000000..49441612 --- /dev/null +++ b/server/api_server/tool_routes.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Request, Body + +from configs import logger +from server.utils import BaseResponse + + +tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) + + +@tool_router.get("/", response_model=BaseResponse) +async def list_tools(): + import importlib + from server.agent.tools_factory import tools_registry + importlib.reload(tools_registry) + + data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools} + return {"data": data} + + +@tool_router.post("/call", response_model=BaseResponse) +async def call_tool( + name: str = Body(examples=["calculate"]), + kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]), +): + import importlib + from server.agent.tools_factory import tools_registry + importlib.reload(tools_registry) + + tool_names = {t.name: t for t in tools_registry.all_tools} + if tool := tool_names.get(name): + try: + result = await tool.ainvoke(kwargs) + return {"data": result} + except Exception: + msg = f"failed to call tool '{name}'" + logger.error(msg, exc_info=True) + return {"code": 500, "msg": msg} + else: + return {"code": 500, "msg": f"no tool named '{name}'"} diff --git a/server/chat/utils.py b/server/chat/utils.py index 2fd82674..87d61559 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -1,5 +1,5 @@ from functools import lru_cache -from pydantic import BaseModel, Field +from server.pydantic_types import BaseModel, Field from langchain.prompts.chat import ChatMessagePromptTemplate from configs import logger, log_verbose from typing import List, Tuple, Dict, Union diff --git a/server/document_loaders/FilteredCSVloader.py b/server/document_loaders/FilteredCSVloader.py index d9ca508b..07d71503 100644 --- a/server/document_loaders/FilteredCSVloader.py +++ b/server/document_loaders/FilteredCSVloader.py @@ -1,11 +1,11 @@ ## 指定制定列的csv文件加载器 -from langchain.document_loaders import CSVLoader +from langchain_community.document_loaders import CSVLoader import csv from io import TextIOWrapper from typing import Dict, List, Optional from langchain.docstore.document import Document -from langchain.document_loaders.helpers import detect_file_encodings +from langchain_community.document_loaders.helpers import detect_file_encodings class FilteredCSVLoader(CSVLoader): diff --git a/server/document_loaders/mydocloader.py 
b/server/document_loaders/mydocloader.py index 7f5462a2..d10dd49b 100644 --- a/server/document_loaders/mydocloader.py +++ b/server/document_loaders/mydocloader.py @@ -1,4 +1,4 @@ -from langchain.document_loaders.unstructured import UnstructuredFileLoader +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from typing import List import tqdm diff --git a/server/document_loaders/myimgloader.py b/server/document_loaders/myimgloader.py index ffedf8ff..9e0b0c72 100644 --- a/server/document_loaders/myimgloader.py +++ b/server/document_loaders/myimgloader.py @@ -1,5 +1,5 @@ from typing import List -from langchain.document_loaders.unstructured import UnstructuredFileLoader +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from server.document_loaders.ocr import get_ocr diff --git a/server/document_loaders/mypdfloader.py b/server/document_loaders/mypdfloader.py index b15364be..e0f11c5a 100644 --- a/server/document_loaders/mypdfloader.py +++ b/server/document_loaders/mypdfloader.py @@ -1,5 +1,5 @@ from typing import List -from langchain.document_loaders.unstructured import UnstructuredFileLoader +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader import cv2 from PIL import Image import numpy as np diff --git a/server/document_loaders/mypptloader.py b/server/document_loaders/mypptloader.py index f14d0728..309ffdcc 100644 --- a/server/document_loaders/mypptloader.py +++ b/server/document_loaders/mypptloader.py @@ -1,4 +1,4 @@ -from langchain.document_loaders.unstructured import UnstructuredFileLoader +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from typing import List import tqdm diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index e9a22b21..b5246da0 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -10,7 +10,6 @@ from server.knowledge_base.utils import (validate_kb_name, list_files_from_folde files2docs_in_thread, KnowledgeFile) from fastapi.responses import FileResponse from sse_starlette import EventSourceResponse -from pydantic import Json import json from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import get_file_detail @@ -120,7 +119,7 @@ def upload_docs( chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - docs: Json = Form({}, description="自定义的docs,需要转为json字符串"), + docs: str = Form("", description="自定义的docs,需要转为json字符串"), not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: """ @@ -133,6 +132,7 @@ def upload_docs( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + docs = json.loads(docs) if docs else {} failed_files = {} file_names = list(docs.keys()) @@ -221,7 +221,7 @@ def update_docs( chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), - docs: Json = Body({}, description="自定义的docs,需要转为json字符串"), + docs: str = Body("", description="自定义的docs,需要转为json字符串"), not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: """ @@ -236,6 +236,7 @@ def update_docs( failed_files = {} 
     kb_files = []
 
+    docs = json.loads(docs) if docs else {}
     # 生成需要加载docs的文件列表
     for file_name in file_names:
diff --git a/server/knowledge_base/kb_service/chromadb_kb_service.py b/server/knowledge_base/kb_service/chromadb_kb_service.py
index 5e1d746c..aa83f5e3 100644
--- a/server/knowledge_base/kb_service/chromadb_kb_service.py
+++ b/server/knowledge_base/kb_service/chromadb_kb_service.py
@@ -6,9 +6,9 @@ from chromadb.api.types import (GetResult, QueryResult)
 from langchain.docstore.document import Document
 
 from configs import SCORE_THRESHOLD
-from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
-                                                   KBService, SupportedVSType)
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
+from server.utils import get_Embeddings
 
 
 def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@@ -75,7 +75,7 @@ class ChromaKBService(KBService):
 
     def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
         Tuple[Document, float]]:
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
         embeddings = embed_func.embed_query(query)
         query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
         return _results_to_docs_and_scores(query_result)
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 0964fdfe..268f81a0 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -4,7 +4,8 @@ import shutil
 from configs import SCORE_THRESHOLD
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path, EmbeddingsFunAdapter
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
+from server.utils import get_Embeddings
 from langchain.docstore.document import Document
 from typing import List, Dict, Optional, Tuple
 
@@ -60,8 +61,8 @@ class FaissKBService(KBService):
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
-                  ) -> List[Document]:
-
+                  ) -> List[Tuple[Document, float]]:
+        embed_func = get_Embeddings(self.embed_model)
+        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            embeddings = vs.embeddings.embed_query(query)
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@@ -72,11 +74,10 @@ class FaissKBService(KBService):
                **kwargs,
                ) -> List[Dict]:
 
+        texts = [x.page_content for x in docs]
+        metadatas = [x.metadata for x in docs]
         with self.load_vector_store().acquire() as vs:
-            texts = [x.page_content for x in docs]
-            metadatas = [x.metadata for x in docs]
             embeddings = vs.embeddings.embed_documents(texts)
-
             ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings), metadatas=metadatas)
 
             if not kwargs.get("not_refresh_vs_cache"):
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index c7f3b762..f3dbda31 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -7,9 +7,10 @@ import os
 
 from configs import kbs_config
 from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
     score_threshold_process
 from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings
 
 
 class MilvusKBService(KBService):
@@ -49,7 +50,7 @@ class MilvusKBService(KBService):
         return SupportedVSType.MILVUS
 
     def _load_milvus(self):
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
+        self.milvus = Milvus(embedding_function=get_Embeddings(self.embed_model),
                              collection_name=self.kb_name,
                              connection_args=kbs_config.get("milvus"),
                              index_params=kbs_config.get("milvus_kwargs")["index_params"],
@@ -66,7 +67,7 @@ class MilvusKBService(KBService):
 
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_milvus()
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
         embeddings = embed_func.embed_query(query)
         docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
         return score_threshold_process(score_threshold, top_k, docs)
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 46efe7d8..a9d578d4 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -7,9 +7,10 @@ from sqlalchemy import text
 
 from configs import kbs_config
 
-from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService, \
     score_threshold_process
 from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings
 import shutil
 import sqlalchemy
 from sqlalchemy.engine.base import Engine
@@ -20,7 +21,7 @@ class PGKBService(KBService):
     engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)
 
     def _load_pg_vector(self):
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
+        self.pg_vector = PGVector(embedding_function=get_Embeddings(self.embed_model),
                                   collection_name=self.kb_name,
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection=PGKBService.engine,
@@ -59,7 +60,7 @@ class PGKBService(KBService):
         shutil.rmtree(self.kb_path)
 
     def do_search(self, query: str, top_k: int, score_threshold: float):
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
         embeddings = embed_func.embed_query(query)
         docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
         return score_threshold_process(score_threshold, top_k, docs)
diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py
index 753225a0..f0fd0fc7 100644
--- a/server/knowledge_base/kb_service/zilliz_kb_service.py
+++ b/server/knowledge_base/kb_service/zilliz_kb_service.py
@@ -3,9 +3,10 @@ from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
 from langchain.vectorstores import Zilliz
 from configs import kbs_config
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
     score_threshold_process
 from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings
 
 
 class ZillizKBService(KBService):
@@ -46,7 +47,7 @@ class 
ZillizKBService(KBService): def _load_zilliz(self): zilliz_args = kbs_config.get("zilliz") - self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model), + self.zilliz = Zilliz(embedding_function=get_Embeddings(self.embed_model), collection_name=self.kb_name, connection_args=zilliz_args) def do_init(self): @@ -59,7 +60,7 @@ class ZillizKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_zilliz() - embed_func = EmbeddingsFunAdapter(self.embed_model) + embed_func = get_Embeddings(self.embed_model) embeddings = embed_func.embed_query(query) docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) return score_threshold_process(score_threshold, top_k, docs) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 9e52162a..fc7edac3 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -12,7 +12,7 @@ from configs import ( ) import importlib from server.text_splitter import zh_title_enhance as func_zh_title_enhance -import langchain.document_loaders +import langchain_community.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter from pathlib import Path @@ -136,7 +136,7 @@ class JSONLinesLoader(JSONLoader): self._json_lines = True -langchain.document_loaders.JSONLinesLoader = JSONLinesLoader +langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader def get_LoaderClass(file_extension): @@ -155,13 +155,13 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): "RapidOCRDocLoader", "RapidOCRPPTLoader"]: document_loaders_module = importlib.import_module("server.document_loaders") else: - document_loaders_module = importlib.import_module("langchain.document_loaders") + document_loaders_module = importlib.import_module("langchain_community.document_loaders") DocumentLoader = getattr(document_loaders_module, loader_name) except Exception as e: msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) - document_loaders_module = importlib.import_module("langchain.document_loaders") + document_loaders_module = importlib.import_module("langchain_community.document_loaders") DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if loader_name == "UnstructuredFileLoader": diff --git a/server/pydantic_types.py b/server/pydantic_types.py new file mode 100644 index 00000000..78e754e6 --- /dev/null +++ b/server/pydantic_types.py @@ -0,0 +1,7 @@ +from langchain_core.pydantic_v1 import * +from pydantic.fields import FieldInfo +from pydantic.schema import model_schema + +# from pydantic.v1 import * +# from pydantic.v1.fields import FieldInfo +# from pydantic.v1.schema import model_schema diff --git a/server/utils.py b/server/utils.py index 61723f08..5365d5f0 100644 --- a/server/utils.py +++ b/server/utils.py @@ -3,11 +3,11 @@ from pathlib import Path import asyncio import os from concurrent.futures import ThreadPoolExecutor, as_completed -from langchain.pydantic_v1 import BaseModel, Field from langchain.embeddings.base import Embeddings from langchain_openai.chat_models import ChatOpenAI from langchain_openai.llms import OpenAI import httpx +import openai from typing import ( Optional, Callable, @@ -22,8 +22,10 @@ from typing import ( ) import logging -from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE -from 
server.minx_chat_openai import MinxChatOpenAI +from configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, + DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE) +from server.pydantic_types import BaseModel, Field +from server.minx_chat_openai import MinxChatOpenAI # TODO: still used? async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -40,6 +42,14 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): event.set() +def get_config_platforms() -> Dict[str, Dict]: + import importlib + from configs import model_config + importlib.reload(model_config) + + return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS} + + def get_config_models( model_name: str = None, model_type: Literal["llm", "embed", "image", "multimodal"] = None, @@ -88,40 +98,54 @@ def get_config_models( return result -def get_model_info(model_name: str, platform_name: str = None) -> Dict: +def get_model_info(model_name: str = None, platform_name: str = None, multiple: bool = False) -> Dict: ''' 获取配置的模型信息,主要是 api_base_url, api_key + 如果指定 multiple=True,则返回所有重名模型;否则仅返回第一个 ''' result = get_config_models(model_name=model_name, platform_name=platform_name) if len(result) > 0: - return list(result.values())[0] + if multiple: + return result + else: + return list(result.values())[0] else: return {} def get_ChatOpenAI( - model_name: str, + model_name: str = DEFAULT_LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = None, streaming: bool = True, callbacks: List[Callable] = [], verbose: bool = True, + local_wrap: bool = True, # use local wrapped api **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) - try: - model = ChatOpenAI( + params = dict( streaming=streaming, verbose=verbose, callbacks=callbacks, model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_api_key=model_info.get("api_key"), - openai_api_base=model_info.get("api_base_url"), - openai_proxy=model_info.get("api_proxy"), **kwargs - ) + ) + try: + if local_wrap: + params.update( + openai_api_base = f"{api_address()}/v1", + openai_api_key = "EMPTY", + ) + else: + params.update( + openai_api_base=model_info.get("api_base_url"), + openai_api_key=model_info.get("api_key"), + openai_proxy=model_info.get("api_proxy"), + ) + model = ChatOpenAI(**params) except Exception as e: logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True) model = None @@ -136,41 +160,100 @@ def get_OpenAI( echo: bool = True, callbacks: List[Callable] = [], verbose: bool = True, + local_wrap: bool = True, # use local wrapped api **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 model_info = get_model_info(model_name) - model = OpenAI( + params = dict( streaming=streaming, verbose=verbose, callbacks=callbacks, model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_api_key=model_info.get("api_key"), - openai_api_base=model_info.get("api_base_url"), - openai_proxy=model_info.get("api_proxy"), echo=echo, **kwargs ) + try: + if local_wrap: + params.update( + openai_api_base = f"{api_address()}/v1", + openai_api_key = "EMPTY", + ) + else: + params.update( + openai_api_base=model_info.get("api_base_url"), + openai_api_key=model_info.get("api_key"), + openai_proxy=model_info.get("api_proxy"), + ) + model = OpenAI(**params) + except Exception as e: + logger.error(f"failed to create OpenAI for model: {model_name}.", exc_info=True) + model = None return model -def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings: +def get_Embeddings( + embed_model: str = 
DEFAULT_EMBEDDING_MODEL,
+    local_wrap: bool = True,  # use local wrapped api
+) -> Embeddings:
     from langchain_community.embeddings.openai import OpenAIEmbeddings
     from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154
 
     model_info = get_model_info(model_name=embed_model)
+    params = dict(model=embed_model)
+    try:
+        if local_wrap:
+            params.update(
+                openai_api_base = f"{api_address()}/v1",
+                openai_api_key = "EMPTY",
+            )
+        else:
+            params.update(
+                openai_api_base=model_info.get("api_base_url"),
+                openai_api_key=model_info.get("api_key"),
+                openai_proxy=model_info.get("api_proxy"),
+            )
+        if model_info.get("platform_type") == "openai":
+            return OpenAIEmbeddings(**params)
+        else:
+            return LocalAIEmbeddings(**params)
+    except Exception as e:
+        logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
+
+
+def get_OpenAIClient(
+    platform_name: str = None,
+    model_name: str = None,
+    is_async: bool = True,
+) -> Union[openai.Client, openai.AsyncClient]:
+    '''
+    construct an openai Client for specified platform or model
+    '''
+    if platform_name is None:
+        platform_name = get_model_info(model_name=model_name, platform_name=platform_name)["platform_name"]
+    platform_info = get_config_platforms().get(platform_name)
+    assert platform_info, f"cannot find configured platform: {platform_name}"
     params = {
-        "model": embed_model,
-        "base_url": model_info.get("api_base_url"),
-        "api_key": model_info.get("api_key"),
-        "openai_proxy": model_info.get("api_proxy"),
+        "base_url": platform_info.get("api_base_url"),
+        "api_key": platform_info.get("api_key"),
     }
-    if model_info.get("platform_type") == "openai":
-        return OpenAIEmbeddings(**params)
+    httpx_params = {}
+    if api_proxy := platform_info.get("api_proxy"):
+        httpx_params = {
+            "proxies": api_proxy,
+        }
+
+    if is_async:
+        if httpx_params:
+            # 异步客户端必须搭配 httpx 的异步 transport
+            params["http_client"] = httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(local_address="0.0.0.0"), **httpx_params)
+        return openai.AsyncClient(**params)
     else:
-        return LocalAIEmbeddings(**params)
+        if httpx_params:
+            params["http_client"] = httpx.Client(transport=httpx.HTTPTransport(local_address="0.0.0.0"), **httpx_params)
+        return openai.Client(**params)
 
 
 class MsgType:
@@ -281,7 +364,7 @@ def iter_over_async(ait, loop=None):
 
 def MakeFastAPIOffline(
     app: FastAPI,
-    static_dir=Path(__file__).parent / "static",
+    static_dir=Path(__file__).parent / "api_server" / "static",
     static_url="/static-offline-docs",
     docs_url: Optional[str] = "/docs",
     redoc_url: Optional[str] = "/redoc",
diff --git a/startup.py b/startup.py
index b1db023d..85947d24 100644
--- a/startup.py
+++ b/startup.py
@@ -42,7 +42,7 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
 
 
 def run_api_server(started_event: mp.Event = None, run_mode: str = None):
-    from server.api import create_app
+    from server.api_server.server_app import create_app
     import uvicorn
     from server.utils import set_httpx_config
     set_httpx_config()
@@ -199,13 +199,6 @@ async def start_main_server():
         args.api_worker = True
         args.webui = False
 
-    elif args.llm_api:
-        args.openai_api = True
-        args.model_worker = True
-        args.api_worker = True
-        args.api = False
-        args.webui = False
-
     if args.lite:
         args.model_worker = False
         run_mode = "lite"
diff --git a/tests/api/test_openai_wrap.py b/tests/api/test_openai_wrap.py
new file mode 100644
index 00000000..342f8ced
--- /dev/null
+++ b/tests/api/test_openai_wrap.py
@@ -0,0 +1,29 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).parent.parent.parent))
+
+import openai
+
+from configs import 
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL +from server.utils import api_address + + +api_base_url = f"{api_address()}/v1" +client = openai.Client( + api_key="EMPTY", + base_url=api_base_url, +) + +def test_chat(): + resp = client.chat.completions.create( + messages=[{"role": "user", "content": "你是谁"}], + model=DEFAULT_LLM_MODEL, + ) + print(resp) + assert hasattr(resp, "choices") and len(resp.choices) > 0 + + +def test_embeddings(): + resp = client.embeddings.create(input="你是谁", model=DEFAULT_EMBEDDING_MODEL) + print(resp) + assert hasattr(resp, "data") and len(resp.data) > 0 diff --git a/tests/api/test_tools.py b/tests/api/test_tools.py new file mode 100644 index 00000000..810f4c68 --- /dev/null +++ b/tests/api/test_tools.py @@ -0,0 +1,30 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from pprint import pprint +import requests + +from server.utils import api_address + + +api_base_url = f"{api_address()}/tools" + +def test_tool_list(): + resp = requests.get(api_base_url) + assert resp.status_code == 200 + data = resp.json()["data"] + pprint(data) + assert "calculate" in data + + +def test_tool_call(): + data = { + "name": "calculate", + "kwargs": {"a":1,"b":2,"operator":"+"}, + } + resp = requests.post(f"{api_base_url}/call", json=data) + assert resp.status_code == 200 + data = resp.json()["data"] + pprint(data) + assert data == 3
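
The tests above only cover the non-streaming paths. A sketch of a streaming check in the same style — it assumes a running server, reuses `api_address` like test_openai_wrap.py, and exercises the `EventSourceResponse` branch of `openai_request`:

```python
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))

import openai

from configs import DEFAULT_LLM_MODEL
from server.utils import api_address

client = openai.Client(api_key="EMPTY", base_url=f"{api_address()}/v1")


def test_chat_stream():
    # stream=True makes the server return SSE chunks instead of a single response
    chunks = list(client.chat.completions.create(
        messages=[{"role": "user", "content": "你是谁"}],
        model=DEFAULT_LLM_MODEL,
        stream=True,
    ))
    assert len(chunks) > 0
    print("".join(c.choices[0].delta.content or "" for c in chunks))
```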