Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-19 21:37:20 +08:00
- Pin pydantic to v1 and unify all pydantic import paths across the project, preparing for a later v2 upgrade
- Refactor api.py:
  - split the endpoints into per-module routers
  - add OpenAI-compatible forwarding endpoints; the project now uses them by default to get model load balancing (see the usage sketch after this list)
  - add /tools endpoints for listing and calling the project's agent tools
- Remove all uses of EmbeddingsFunAdapter in favor of get_Embeddings
- TODO:
  - make the /chat/chat endpoint OpenAI-compatible
  - add an OpenAI-compatible /chat/kb_chat endpoint
  - relocate data directories such as nltk/knowledge_base/logs
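As a quick orientation, here is a minimal sketch of how a client is expected to talk to the new gateway, assuming the API server's default port 7861 and the placeholder "EMPTY" key used throughout this commit; the model name must match an entry in MODEL_PLATFORMS:

    # Sketch: call the OpenAI-compatible gateway added in this commit.
    # Assumptions: default port 7861; "chatglm3-6b" is configured in MODEL_PLATFORMS.
    import openai

    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:7861/v1")
    resp = client.chat.completions.create(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp.choices[0].message.content)

If several platforms declare the same model name, the gateway dispatches each request to the least-busy platform, so clients never need to know which backend served them.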
This commit is contained in:
parent 65466007ae
commit d0846f88cc
@@ -76,25 +76,34 @@ LLM_MODEL_CONFIG = {
    },
}

-# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
+# 可以通过 loom/xinference/oneapi/fastchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
# - platform_name 可以任意填写,不要重复即可
# - platform_type 可选:openai, xinference, oneapi, fastchat。以后可能根据平台类型做一些功能区分
# - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。

MODEL_PLATFORMS = [
    {
        "platform_name": "openai-api",
        "platform_type": "openai",
        "llm_models": [
            "gpt-3.5-turbo",
        ],
        "embed_models": [],
        "image_models": [],
        "multimodal_models": [],
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "sk-",
        "api_proxy": "",
    },
    # {
    #     "platform_name": "openai-api",
    #     "platform_type": "openai",
    #     "api_base_url": "https://api.openai.com/v1",
    #     "api_key": "sk-",
    #     "api_proxy": "",
    #     "api_concurrencies": 5,
    #     "llm_models": [
    #         "gpt-3.5-turbo",
    #     ],
    #     "embed_models": [],
    #     "image_models": [],
    #     "multimodal_models": [],
    # },

    {
        "platform_name": "xinference",
        "platform_type": "xinference",
        "api_base_url": "http://127.0.0.1:9997/v1",
        "api_key": "EMPTY",
        "api_concurrencies": 5,
        # 注意:这里填写的是 xinference 部署的模型 UID,而非模型名称
        "llm_models": [
            "chatglm3-6b",
        ],
@@ -107,40 +116,38 @@ MODEL_PLATFORMS = [
        "multimodal_models": [
            "qwen-vl",
        ],
        "api_base_url": "http://127.0.0.1:9997/v1",
        "api_key": "EMPTY",
    },

    {
        "platform_name": "oneapi",
        "platform_type": "oneapi",
        "api_key": "",
        "llm_models": [
            "qwen-turbo",
            "qwen-plus",
            "chatglm_turbo",
            "chatglm_std",
        ],
        "embed_models": [],
        "image_models": [],
        "multimodal_models": [],
        "api_base_url": "http://127.0.0.1:3000/v1",
        "api_key": "sk-xxx",
    },
    # {
    #     "platform_name": "oneapi",
    #     "platform_type": "oneapi",
    #     "api_base_url": "http://127.0.0.1:3000/v1",
    #     "api_key": "",
    #     "api_concurrencies": 5,
    #     "llm_models": [
    #         "qwen-turbo",
    #         "qwen-plus",
    #         "chatglm_turbo",
    #         "chatglm_std",
    #     ],
    #     "embed_models": [],
    #     "image_models": [],
    #     "multimodal_models": [],
    # },

    {
        "platform_name": "loom",
        "platform_type": "loom",
        "api_key": "",
        "llm_models": [
            "chatglm3-6b",
        ],
        "embed_models": [],
        "image_models": [],
        "multimodal_models": [],
        "api_base_url": "http://127.0.0.1:7860/v1",
        "api_key": "EMPTY",
    },
    # {
    #     "platform_name": "loom",
    #     "platform_type": "loom",
    #     "api_base_url": "http://127.0.0.1:7860/v1",
    #     "api_key": "",
    #     "api_concurrencies": 5,
    #     "llm_models": [
    #         "chatglm3-6b",
    #     ],
    #     "embed_models": [],
    #     "image_models": [],
    #     "multimodal_models": [],
    # },
]

LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
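Because different platforms may load models under the same name, the server must resolve a model name to concrete platform entries at request time. The real lookup lives in get_config_models/get_model_info in server/utils.py; the following is only a simplified sketch of the idea, assuming the MODEL_PLATFORMS shape above:

    # Sketch: find every platform entry able to serve a given model name.
    from typing import Dict, List

    MODEL_TYPE_KEYS = ("llm_models", "embed_models", "image_models", "multimodal_models")

    def find_platforms_for_model(model_name: str, platforms: List[Dict]) -> List[Dict]:
        """Return all platform entries listing `model_name` in any model list."""
        hits = []
        for p in platforms:
            if any(model_name in p.get(key, []) for key in MODEL_TYPE_KEYS):
                hits.append(p)
        return hits

    # With two platforms both serving "chatglm3-6b", the load balancer can then
    # pick whichever one currently has the most free concurrency slots.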
@@ -2,7 +2,7 @@ import sys
sys.path.append(".")
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                           folder2db, prune_db_docs, prune_folder_files)
-from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL
+from configs.model_config import DEFAULT_EMBEDDING_MODEL
from datetime import datetime
@@ -41,6 +41,38 @@ youtube-search==2.1.2
duckduckgo-search==3.9.9
metaphor-python==0.1.23

+httpx==0.26.0
+httpx_sse==0.4.0
+watchdog==3.0.0
+pyjwt==2.8.0
+elasticsearch
+numexpr>=2.8.8
+strsimpy>=0.2.1
+markdownify>=0.11.6
+tqdm>=4.66.1
+websockets>=12.0
+numpy>=1.26.3
+pandas~=2.1.4
+pydantic<2
+httpx[brotli,http2,socks]>=0.25.2

+# optional document loaders

+# rapidocr_paddle[gpu]>=1.3.0.post5
+# jq>=1.6.0
+# html2text
+# beautifulsoup4>=4.12.2
+# pysrt>=1.1.2

+# Agent and Search Tools

+# arxiv>=2.1.0
+# youtube-search>=2.1.2
+# duckduckgo-search>=4.1.0
+# metaphor-python>=0.1.23

# WebUI requirements

streamlit==1.30.0
streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1
@@ -48,7 +80,3 @@ streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3

-httpx==0.26.0
-httpx_sse==0.4.0
-watchdog==3.0.0
-pyjwt==2.8.0
@@ -16,10 +16,8 @@ from langchain.output_parsers import OutputFixingParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
-from langchain.pydantic_v1 import Field
+from server.pydantic_types import Field, typing, model_schema

-from pydantic import typing
-from pydantic.schema import model_schema

logger = logging.getLogger(__name__)
@@ -1,6 +1,8 @@
# LangChain 的 ArxivQueryRun 工具
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun


def arxiv(query: str):
    tool = ArxivQueryRun()
    return tool.run(tool_input=query)
@@ -1,6 +1,6 @@
import base64
import os
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field

def save_base64_audio(base64_audio, file_path):
    audio_data = base64.b64decode(base64_audio)
@@ -1,4 +1,4 @@
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field

def calculate(a: float, b: float, operator: str) -> float:
    if operator == "+":
@@ -1,4 +1,4 @@
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import TOOL_CONFIG
@@ -1,5 +1,5 @@
from urllib.parse import urlencode
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field

from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG
@@ -1,5 +1,7 @@
from langchain_community.tools import YouTubeSearchTool
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field


def search_youtube(query: str):
    tool = YouTubeSearchTool()
    return tool.run(tool_input=query)
@@ -1,6 +1,8 @@
# LangChain 的 Shell 工具
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
from langchain_community.tools import ShellTool


def shell(query: str):
    tool = ShellTool()
    return tool.run(tool_input=query)
@@ -6,9 +6,8 @@ from typing import List
import uuid

from langchain.agents import tool
-from langchain.pydantic_v1 import Field
+from server.pydantic_types import Field, FieldInfo
import openai
-from pydantic.fields import FieldInfo

from configs.basic_config import MEDIA_PATH
from server.utils import MsgType
@@ -4,7 +4,7 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
from PIL import Image, ImageDraw
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
from configs import TOOL_CONFIG
import re
from server.agent.container import container
@@ -1,7 +1,7 @@
"""
简单的单参数输入工具实现,用于查询现在天气的情况
"""
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
import requests

def weather(location: str, api_key: str):
@@ -1,7 +1,9 @@
# Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
-from langchain.pydantic_v1 import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
wolfram_alpha_appid = "your key"


def wolfram(query: str):
    wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
    ans = wolfram.run(query)
server/api.py (file deleted, 225 lines)
@@ -1,225 +0,0 @@
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                          get_server_configs, get_prompt_template)
from typing import List, Literal


async def document():
    return RedirectResponse(url="/docs")


def create_app(run_mode: str = None):
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION
    )
    MakeFastAPIOffline(app)
    # Add CORS middleware to allow all origins
    # 在config.py中设置OPEN_DOMAIN=True,允许跨域
    # set OPEN_DOMAIN=True in config.py to allow cross-domain
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    mount_app_routes(app, run_mode=run_mode)
    return app


def mount_app_routes(app: FastAPI, run_mode: str = None):
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    # Tag: Chat
    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)

    app.post("/chat/feedback",
             tags=["Chat"],
             summary="返回llm模型对话评分",
             )(chat_feedback)

    # 知识库相关接口
    mount_knowledge_routes(app)
    # 摘要相关接口
    mount_filename_summary_routes(app)

    # 服务器相关接口
    app.post("/server/configs",
             tags=["Server State"],
             summary="获取服务器原始配置信息",
             )(get_server_configs)

    @app.post("/server/get_prompt_template",
              tags=["Server State"],
              summary="获取服务区配置的 prompt 模板")
    def get_server_prompt_template(
        type: Literal["llm_chat", "knowledge_base_chat"] = Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
        name: str = Body("default", description="模板名称"),
    ) -> str:
        return get_prompt_template(type=type, name=name)

    # 其它接口
    app.post("/other/completion",
             tags=["Other"],
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    # 媒体文件
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")


def mount_knowledge_routes(app: FastAPI):
    from server.chat.file_chat import upload_temp_docs, file_chat
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                  update_docs, download_doc, recreate_vector_store,
                                                  search_docs, update_info)

    app.post("/chat/file_chat",
             tags=["Knowledge Base Management"],
             summary="文件对话"
             )(file_chat)

    app.get("/knowledge_base/list_knowledge_bases",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库列表")(list_kbs)

    app.post("/knowledge_base/create_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="创建知识库"
             )(create_kb)

    app.post("/knowledge_base/delete_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库"
             )(delete_kb)

    app.get("/knowledge_base/list_files",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库内的文件列表"
            )(list_files)

    app.post("/knowledge_base/search_docs",
             tags=["Knowledge Base Management"],
             response_model=List[DocumentWithVSId],
             summary="搜索知识库"
             )(search_docs)

    app.post("/knowledge_base/upload_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="上传文件到知识库,并/或进行向量化"
             )(upload_docs)

    app.post("/knowledge_base/delete_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库内指定文件"
             )(delete_docs)

    app.post("/knowledge_base/update_info",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新知识库介绍"
             )(update_info)

    app.post("/knowledge_base/update_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新现有文件到知识库"
             )(update_docs)

    app.get("/knowledge_base/download_doc",
            tags=["Knowledge Base Management"],
            summary="下载对应的知识文件")(download_doc)

    app.post("/knowledge_base/recreate_vector_store",
             tags=["Knowledge Base Management"],
             summary="根据content中文档重建向量库,流式输出处理进度。"
             )(recreate_vector_store)

    app.post("/knowledge_base/upload_temp_docs",
             tags=["Knowledge Base Management"],
             summary="上传文件到临时目录,用于文件对话。"
             )(upload_temp_docs)


def mount_filename_summary_routes(app: FastAPI):
    from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
                                                      summary_doc_ids_to_vector_store)

    app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="单个知识库根据文件名称摘要"
             )(summary_file_to_vector_store)
    app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="单个知识库根据doc_ids摘要",
             response_model=BaseResponse,
             )(summary_doc_ids_to_vector_store)
    app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="重建单个知识库文件摘要"
             )(recreate_summary_vector_store)


def run_api(host, port, **kwargs):
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=kwargs.get("ssl_keyfile"),
                    ssl_certfile=kwargs.get("ssl_certfile"),
                    )
    else:
        uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)
    # 初始化消息
    args = parser.parse_args()
    args_dict = vars(args)

    app = create_app()

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
server/api_server/api_schemas.py (new file, 97 lines)
@@ -0,0 +1,97 @@
from __future__ import annotations

import re
from typing import Dict, List, Literal, Optional, Union

from fastapi import UploadFile
from server.pydantic_types import BaseModel, Field, AnyUrl, root_validator
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionToolChoiceOptionParam,
    ChatCompletionToolParam,
    completion_create_params,
)

from configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG


class OpenAIBaseInput(BaseModel):
    user: Optional[str] = None
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict] = None
    extra_query: Optional[Dict] = None
    extra_body: Optional[Dict] = None
    timeout: Optional[float] = None


class OpenAIChatInput(OpenAIBaseInput):
    messages: List[ChatCompletionMessageParam]
    model: str = DEFAULT_LLM_MODEL
    frequency_penalty: Optional[float] = None
    function_call: Optional[completion_create_params.FunctionCall] = None
    functions: List[completion_create_params.Function] = None
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: completion_create_params.ResponseFormat = None
    seed: Optional[int] = None
    stop: Union[Optional[str], List[str]] = None
    stream: Optional[bool] = None
    temperature: Optional[float] = TEMPERATURE
    tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
    tools: List[ChatCompletionToolParam] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None


class OpenAIEmbeddingsInput(OpenAIBaseInput):
    input: Union[str, List[str]]
    model: str
    dimensions: Optional[int] = None
    encoding_format: Optional[Literal["float", "base64"]] = None


class OpenAIImageBaseInput(OpenAIBaseInput):
    model: str
    n: int = 1
    response_format: Optional[Literal["url", "b64_json"]] = None
    size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = "256x256"


class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
    prompt: str
    quality: Literal["standard", "hd"] = None
    style: Optional[Literal["vivid", "natural"]] = None


class OpenAIImageVariationsInput(OpenAIImageBaseInput):
    image: Union[UploadFile, AnyUrl]


class OpenAIImageEditsInput(OpenAIImageVariationsInput):
    prompt: str
    mask: Union[UploadFile, AnyUrl]


class OpenAIAudioTranslationsInput(OpenAIBaseInput):
    file: Union[UploadFile, AnyUrl]
    model: str
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: float = TEMPERATURE


class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
    language: Optional[str] = None
    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class OpenAIAudioSpeechInput(OpenAIBaseInput):
    input: str
    model: str
    voice: str
    response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
    speed: Optional[float] = None
server/api_server/chat_routes.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import List

from fastapi import APIRouter, Request

from server.chat.chat import chat
from server.chat.feedback import chat_feedback
from server.chat.file_chat import file_chat


chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])

chat_router.post("/chat",
                 summary="与llm模型对话(通过LLMChain)",
                 )(chat)

chat_router.post("/feedback",
                 summary="返回llm模型对话评分",
                 )(chat_feedback)

chat_router.post("/file_chat",
                 summary="文件对话"
                 )(file_chat)
server/api_server/kb_routes.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import List

from fastapi import APIRouter, Request

from server.chat.file_chat import upload_temp_docs
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                              update_docs, download_doc, recreate_vector_store,
                                              search_docs, update_info)
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
                                                  summary_doc_ids_to_vector_store)
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import BaseResponse, ListResponse


kb_router = APIRouter(prefix="/knowledge_base", tags=["Knowledge Base Management"])


kb_router.get("/list_knowledge_bases",
              response_model=ListResponse,
              summary="获取知识库列表")(list_kbs)

kb_router.post("/create_knowledge_base",
               response_model=BaseResponse,
               summary="创建知识库"
               )(create_kb)

kb_router.post("/delete_knowledge_base",
               response_model=BaseResponse,
               summary="删除知识库"
               )(delete_kb)

kb_router.get("/list_files",
              response_model=ListResponse,
              summary="获取知识库内的文件列表"
              )(list_files)

kb_router.post("/search_docs",
               response_model=List[DocumentWithVSId],
               summary="搜索知识库"
               )(search_docs)

kb_router.post("/upload_docs",
               response_model=BaseResponse,
               summary="上传文件到知识库,并/或进行向量化"
               )(upload_docs)

kb_router.post("/delete_docs",
               response_model=BaseResponse,
               summary="删除知识库内指定文件"
               )(delete_docs)

kb_router.post("/update_info",
               response_model=BaseResponse,
               summary="更新知识库介绍"
               )(update_info)

kb_router.post("/update_docs",
               response_model=BaseResponse,
               summary="更新现有文件到知识库"
               )(update_docs)

kb_router.get("/download_doc",
              summary="下载对应的知识文件")(download_doc)

kb_router.post("/recreate_vector_store",
               summary="根据content中文档重建向量库,流式输出处理进度。"
               )(recreate_vector_store)

kb_router.post("/upload_temp_docs",
               summary="上传文件到临时目录,用于文件对话。"
               )(upload_temp_docs)


summary_router = APIRouter(prefix="/kb_summary_api")
summary_router.post("/summary_file_to_vector_store",
                    summary="单个知识库根据文件名称摘要"
                    )(summary_file_to_vector_store)
summary_router.post("/summary_doc_ids_to_vector_store",
                    summary="单个知识库根据doc_ids摘要",
                    response_model=BaseResponse,
                    )(summary_doc_ids_to_vector_store)
summary_router.post("/recreate_summary_vector_store",
                    summary="重建单个知识库文件摘要"
                    )(recreate_summary_vector_store)

kb_router.include_router(summary_router)
server/api_server/openai_routes.py (new file, 179 lines)
@@ -0,0 +1,179 @@
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple, AsyncGenerator

from fastapi import APIRouter, Request
from openai import AsyncClient
from sse_starlette.sse import EventSourceResponse

from .api_schemas import *
from configs import logger
from server.utils import get_model_info, get_config_platforms, get_OpenAIClient


DEFAULT_API_CONCURRENCIES = 5  # 默认单个模型最大并发数
model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {}  # key: (model_name, platform)
openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"])


@asynccontextmanager
async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient, None]:
    '''
    对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
    '''
    max_semaphore = 0
    selected_platform = ""
    model_infos = get_model_info(model_name=model_name, multiple=True)
    for m, c in model_infos.items():
        key = (m, c["platform_name"])
        api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
        if key not in model_semaphores:
            model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
        semaphore = model_semaphores[key]
        if semaphore._value >= api_concurrencies:
            selected_platform = c["platform_name"]
            break
        elif semaphore._value > max_semaphore:
            selected_platform = c["platform_name"]

    key = (m, selected_platform)
    semaphore = model_semaphores[key]
    try:
        await semaphore.acquire()
        yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
    except Exception:
        logger.error(f"failed when request to {key}", exc_info=True)
    finally:
        semaphore.release()


async def openai_request(method, body):
    '''
    helper function to make openai request
    '''
    async def generator():
        async for chunk in await method(**params):
            yield {"data": chunk.json()}

    params = body.dict(exclude_unset=True)
    if hasattr(body, "stream") and body.stream:
        return EventSourceResponse(generator())
    else:
        return (await method(**params)).dict()


@openai_router.get("/models")
async def list_models() -> Dict:
    '''
    整合所有平台的模型列表。
    由于 openai sdk 不支持重名模型,对于重名模型,只返回其中响应速度最快的一个。在请求其它接口时会自动按照模型忙闲状态进行调度。
    '''
    async def task(name: str):
        try:
            client = get_OpenAIClient(name, is_async=True)
            models = await client.models.list()
            models = models.dict(exclude=["data", "object"])
            for x in models:
                models[x]["platform_name"] = name
            return models
        except Exception:
            logger.error(f"failed request to platform: {name}", exc_info=True)
            return {}

    result = {}
    tasks = [asyncio.create_task(task(name)) for name in get_config_platforms()]
    for t in asyncio.as_completed(tasks):
        for n, v in (await t).items():
            if n not in result:
                result[n] = v

    return result


@openai_router.post("/chat/completions")
async def create_chat_completions(
    request: Request,
    body: OpenAIChatInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.chat.completions.create, body)


@openai_router.post("/completions")
async def create_completions(
    request: Request,
    body: OpenAIChatInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.completions.create, body)


@openai_router.post("/embeddings")
async def create_embeddings(
    request: Request,
    body: OpenAIEmbeddingsInput,
):
    params = body.dict(exclude_unset=True)
    client = get_OpenAIClient(model_name=body.model)
    return (await client.embeddings.create(**params)).dict()


@openai_router.post("/images/generations")
async def create_image_generations(
    request: Request,
    body: OpenAIImageGenerationsInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.images.generate, body)


@openai_router.post("/images/variations")
async def create_image_variations(
    request: Request,
    body: OpenAIImageVariationsInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.images.create_variation, body)


@openai_router.post("/images/edit")
async def create_image_edit(
    request: Request,
    body: OpenAIImageEditsInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.images.edit, body)


@openai_router.post("/audio/translations", deprecated="暂不支持")
async def create_audio_translations(
    request: Request,
    body: OpenAIAudioTranslationsInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.audio.translations.create, body)


@openai_router.post("/audio/transcriptions", deprecated="暂不支持")
async def create_audio_transcriptions(
    request: Request,
    body: OpenAIAudioTranscriptionsInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.audio.transcriptions.create, body)


@openai_router.post("/audio/speech", deprecated="暂不支持")
async def create_audio_speech(
    request: Request,
    body: OpenAIAudioSpeechInput,
):
    async with acquire_model_client(body.model) as client:
        return await openai_request(client.audio.speech.create, body)


@openai_router.post("/files", deprecated="暂不支持")
async def files():
    ...
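Since openai_request wraps streaming responses in an EventSourceResponse, clients consume them through the openai SDK's ordinary streaming interface. A sketch, with the same port/model assumptions as earlier:

    # Sketch: consume the SSE stream produced by openai_request when stream=True.
    import openai

    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:7861/v1")
    stream = client.chat.completions.create(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": "tell me a story"}],
        stream=True,  # the gateway answers with an SSE EventSourceResponse
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="", flush=True)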
server/api_server/server_app.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import argparse
from typing import Literal

from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
import uvicorn

from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from server.api_server.chat_routes import chat_router
from server.api_server.kb_routes import kb_router
from server.api_server.openai_routes import openai_router
from server.api_server.server_routes import server_router
from server.api_server.tool_routes import tool_router
from server.chat.completion import completion
from server.utils import MakeFastAPIOffline


def create_app(run_mode: str = None):
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION
    )
    MakeFastAPIOffline(app)
    # Add CORS middleware to allow all origins
    # 在config.py中设置OPEN_DOMAIN=True,允许跨域
    # set OPEN_DOMAIN=True in config.py to allow cross-domain
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    @app.get("/", summary="swagger 文档", include_in_schema=False)
    async def document():
        return RedirectResponse(url="/docs")

    app.include_router(chat_router)
    app.include_router(kb_router)
    app.include_router(tool_router)
    app.include_router(openai_router)
    app.include_router(server_router)

    # 其它接口
    app.post("/other/completion",
             tags=["Other"],
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    # 媒体文件
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")

    return app


def run_api(host, port, **kwargs):
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=kwargs.get("ssl_keyfile"),
                    ssl_certfile=kwargs.get("ssl_certfile"),
                    )
    else:
        uvicorn.run(app, host=host, port=port)


app = create_app()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)
    # 初始化消息
    args = parser.parse_args()
    args_dict = vars(args)

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
server/api_server/server_routes.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Literal

from fastapi import APIRouter, Body

from server.utils import get_server_configs, get_prompt_template


server_router = APIRouter(prefix="/server", tags=["Server State"])


# 服务器相关接口
server_router.post("/configs",
                   summary="获取服务器原始配置信息",
                   )(get_server_configs)


@server_router.post("/get_prompt_template",
                    summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
    type: Literal["llm_chat", "knowledge_base_chat"] = Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
    name: str = Body("default", description="模板名称"),
) -> str:
    return get_prompt_template(type=type, name=name)
(binary image asset: 7.1 KiB before, 7.1 KiB after; content unchanged)
server/api_server/tool_routes.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import List

from fastapi import APIRouter, Request, Body

from configs import logger
from server.utils import BaseResponse


tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])


@tool_router.get("/", response_model=BaseResponse)
async def list_tools():
    import importlib
    from server.agent.tools_factory import tools_registry
    importlib.reload(tools_registry)

    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools}
    return {"data": data}


@tool_router.post("/call", response_model=BaseResponse)
async def call_tool(
    name: str = Body(examples=["calculate"]),
    kwargs: dict = Body({}, examples=[{"a": 1, "b": 2, "operator": "+"}]),
):
    import importlib
    from server.agent.tools_factory import tools_registry
    importlib.reload(tools_registry)

    tool_names = {t.name: t for t in tools_registry.all_tools}
    if tool := tool_names.get(name):
        try:
            result = await tool.ainvoke(kwargs)
            return {"data": result}
        except Exception:
            msg = f"failed to call tool '{name}'"
            logger.error(msg, exc_info=True)
            return {"code": 500, "msg": msg}
    else:
        return {"code": 500, "msg": f"no tool named '{name}'"}
@@ -1,5 +1,5 @@
from functools import lru_cache
-from pydantic import BaseModel, Field
+from server.pydantic_types import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from typing import List, Tuple, Dict, Union
@@ -1,11 +1,11 @@
## 指定制定列的csv文件加载器
-from langchain.document_loaders import CSVLoader
+from langchain_community.document_loaders import CSVLoader
import csv
from io import TextIOWrapper
from typing import Dict, List, Optional
from langchain.docstore.document import Document
-from langchain.document_loaders.helpers import detect_file_encodings
+from langchain_community.document_loaders.helpers import detect_file_encodings


class FilteredCSVLoader(CSVLoader):
@@ -1,4 +1,4 @@
-from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
@@ -1,5 +1,5 @@
from typing import List
-from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from server.document_loaders.ocr import get_ocr
@@ -1,5 +1,5 @@
from typing import List
-from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
import cv2
from PIL import Image
import numpy as np
@@ -1,4 +1,4 @@
-from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
@@ -10,7 +10,6 @@ from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,
                                         files2docs_in_thread, KnowledgeFile)
from fastapi.responses import FileResponse
from sse_starlette import EventSourceResponse
-from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
@@ -120,7 +119,7 @@ def upload_docs(
    chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
    chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
    zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
-    docs: Json = Form({}, description="自定义的docs,需要转为json字符串"),
+    docs: str = Form("", description="自定义的docs,需要转为json字符串"),
    not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
@@ -133,6 +132,7 @@ def upload_docs(
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

+    docs = json.loads(docs) if docs else {}
    failed_files = {}
    file_names = list(docs.keys())

@@ -221,7 +221,7 @@ def update_docs(
    chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
    zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
    override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
-    docs: Json = Body({}, description="自定义的docs,需要转为json字符串"),
+    docs: str = Body("", description="自定义的docs,需要转为json字符串"),
    not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
@@ -236,6 +236,7 @@ def update_docs(

    failed_files = {}
    kb_files = []
+    docs = json.loads(docs) if docs else {}

    # 生成需要加载docs的文件列表
    for file_name in file_names:
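One consumer-visible consequence of the Json -> str change above: callers must now serialize the custom docs mapping themselves before sending it. A sketch, under the assumption that the endpoint's other form fields (files, knowledge_base_name) keep their existing names:

    # Sketch: `docs` is now sent as a JSON *string*, no longer a JSON object.
    import json
    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/knowledge_base/upload_docs",
        files=[("files", open("faq.md", "rb"))],          # illustrative file name
        data={
            "knowledge_base_name": "samples",
            "docs": json.dumps({"faq.md": []}),           # serialized client-side
        },
    )
    print(resp.json())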
@@ -6,9 +6,9 @@ from chromadb.api.types import (GetResult, QueryResult)
from langchain.docstore.document import Document

from configs import SCORE_THRESHOLD
-from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
-                                                   KBService, SupportedVSType)
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
+from server.utils import get_Embeddings


def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@@ -75,7 +75,7 @@ class ChromaKBService(KBService):

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
            Tuple[Document, float]]:
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)
@@ -4,7 +4,8 @@ import shutil
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path, EmbeddingsFunAdapter
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
+from server.utils import get_Embeddings
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
@@ -60,8 +61,9 @@ class FaissKBService(KBService):
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
-                  ) -> List[Document]:
+                  ) -> List[Tuple[Document, float]]:
+        embed_func = get_Embeddings(self.embed_model)
+        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
-            embeddings = vs.embeddings.embed_query(query)
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@@ -72,11 +74,10 @@ class FaissKBService(KBService):
               **kwargs,
               ) -> List[Dict]:
-        texts = [x.page_content for x in docs]
-        metadatas = [x.metadata for x in docs]
        with self.load_vector_store().acquire() as vs:
+            texts = [x.page_content for x in docs]
+            metadatas = [x.metadata for x in docs]
            embeddings = vs.embeddings.embed_documents(texts)

            ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
                                    metadatas=metadatas)
        if not kwargs.get("not_refresh_vs_cache"):
@@ -7,9 +7,10 @@ import os
from configs import kbs_config
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name

-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
-    score_threshold_process
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
+    score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings


class MilvusKBService(KBService):
@@ -49,7 +50,7 @@ class MilvusKBService(KBService):
        return SupportedVSType.MILVUS

    def _load_milvus(self):
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
+        self.milvus = Milvus(embedding_function=get_Embeddings(self.embed_model),
                             collection_name=self.kb_name,
                             connection_args=kbs_config.get("milvus"),
                             index_params=kbs_config.get("milvus_kwargs")["index_params"],
@@ -66,7 +67,7 @@ class MilvusKBService(KBService):

    def do_search(self, query: str, top_k: int, score_threshold: float):
        self._load_milvus()
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)
@@ -7,9 +7,10 @@ from sqlalchemy import text

from configs import kbs_config

-from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
-    score_threshold_process
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService, \
+    score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings
import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
@@ -20,7 +21,7 @@ class PGKBService(KBService):
    engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
+        self.pg_vector = PGVector(embedding_function=get_Embeddings(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection=PGKBService.engine,
@@ -59,7 +60,7 @@ class PGKBService(KBService):
        shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float):
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)
@@ -3,9 +3,10 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Zilliz
from configs import kbs_config
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
-    score_threshold_process
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
+    score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
+from server.utils import get_Embeddings


class ZillizKBService(KBService):
@@ -46,7 +47,7 @@ class ZillizKBService(KBService):

    def _load_zilliz(self):
        zilliz_args = kbs_config.get("zilliz")
-        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
+        self.zilliz = Zilliz(embedding_function=get_Embeddings(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
@@ -59,7 +60,7 @@ class ZillizKBService(KBService):

    def do_search(self, query: str, top_k: int, score_threshold: float):
        self._load_zilliz()
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embed_func = get_Embeddings(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)
@@ -12,7 +12,7 @@ from configs import (
)
import importlib
from server.text_splitter import zh_title_enhance as func_zh_title_enhance
-import langchain.document_loaders
+import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
@@ -136,7 +136,7 @@ class JSONLinesLoader(JSONLoader):
        self._json_lines = True


-langchain.document_loaders.JSONLinesLoader = JSONLinesLoader
+langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader


def get_LoaderClass(file_extension):
@@ -155,13 +155,13 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
                            "RapidOCRDocLoader", "RapidOCRPPTLoader"]:
            document_loaders_module = importlib.import_module("server.document_loaders")
        else:
-            document_loaders_module = importlib.import_module("langchain.document_loaders")
+            document_loaders_module = importlib.import_module("langchain_community.document_loaders")
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
-        document_loaders_module = importlib.import_module("langchain.document_loaders")
+        document_loaders_module = importlib.import_module("langchain_community.document_loaders")
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name == "UnstructuredFileLoader":
server/pydantic_types.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from langchain_core.pydantic_v1 import *
from pydantic.fields import FieldInfo
from pydantic.schema import model_schema

# from pydantic.v1 import *
# from pydantic.v1.fields import FieldInfo
# from pydantic.v1.schema import model_schema
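This shim is the migration plan in miniature: every module imports its pydantic symbols from one place, so the later v2 upgrade should only mean flipping the three active imports to the commented pydantic.v1 ones. A sketch of how the rest of the codebase consumes it (the model below is illustrative, not from the diff):

    # Sketch: schemas are defined against the shim, never against pydantic directly,
    # so the backing implementation can change in exactly one file.
    from server.pydantic_types import BaseModel, Field

    class WeatherInput(BaseModel):  # hypothetical example model
        location: str = Field(description="city name to query")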
server/utils.py (133 lines changed)
@@ -3,11 +3,11 @@ from pathlib import Path
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.pydantic_v1 import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.llms import OpenAI
+import httpx
import openai
from typing import (
    Optional,
    Callable,
@@ -22,8 +22,10 @@ from typing import (
)
import logging

-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE
-from server.minx_chat_openai import MinxChatOpenAI
+from configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
+                     DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE)
+from server.pydantic_types import BaseModel, Field
+from server.minx_chat_openai import MinxChatOpenAI  # TODO: still used?


async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -40,6 +42,14 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
    event.set()


+def get_config_platforms() -> Dict[str, Dict]:
+    import importlib
+    from configs import model_config
+    importlib.reload(model_config)
+
+    return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS}
+
+
def get_config_models(
    model_name: str = None,
    model_type: Literal["llm", "embed", "image", "multimodal"] = None,
@@ -88,40 +98,54 @@ def get_config_models(
    return result


-def get_model_info(model_name: str, platform_name: str = None) -> Dict:
+def get_model_info(model_name: str = None, platform_name: str = None, multiple: bool = False) -> Dict:
    '''
    获取配置的模型信息,主要是 api_base_url, api_key
+    如果指定 multiple=True,则返回所有重名模型;否则仅返回第一个
    '''
    result = get_config_models(model_name=model_name, platform_name=platform_name)
    if len(result) > 0:
-        return list(result.values())[0]
+        if multiple:
+            return result
+        else:
+            return list(result.values())[0]
    else:
        return {}


def get_ChatOpenAI(
-    model_name: str,
+    model_name: str = DEFAULT_LLM_MODEL,
    temperature: float = TEMPERATURE,
    max_tokens: int = None,
    streaming: bool = True,
    callbacks: List[Callable] = [],
    verbose: bool = True,
+    local_wrap: bool = True,  # use local wrapped api
    **kwargs: Any,
) -> ChatOpenAI:
    model_info = get_model_info(model_name)
-    try:
-        model = ChatOpenAI(
+    params = dict(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
-        openai_api_key=model_info.get("api_key"),
-        openai_api_base=model_info.get("api_base_url"),
-        openai_proxy=model_info.get("api_proxy"),
        **kwargs
    )
+    try:
+        if local_wrap:
+            params.update(
+                openai_api_base=f"{api_address()}/v1",
+                openai_api_key="EMPTY",
+            )
+        else:
+            params.update(
+                openai_api_base=model_info.get("api_base_url"),
+                openai_api_key=model_info.get("api_key"),
+                openai_proxy=model_info.get("api_proxy"),
+            )
+        model = ChatOpenAI(**params)
    except Exception as e:
        logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True)
        model = None
@@ -136,41 +160,100 @@ def get_OpenAI(
    echo: bool = True,
    callbacks: List[Callable] = [],
    verbose: bool = True,
+    local_wrap: bool = True,  # use local wrapped api
    **kwargs: Any,
) -> OpenAI:
    # TODO: 从API获取模型信息
    model_info = get_model_info(model_name)
-    model = OpenAI(
+    params = dict(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
-        openai_api_key=model_info.get("api_key"),
-        openai_api_base=model_info.get("api_base_url"),
-        openai_proxy=model_info.get("api_proxy"),
        echo=echo,
        **kwargs
    )
+    try:
+        if local_wrap:
+            params.update(
+                openai_api_base=f"{api_address()}/v1",
+                openai_api_key="EMPTY",
+            )
+        else:
+            params.update(
+                openai_api_base=model_info.get("api_base_url"),
+                openai_api_key=model_info.get("api_key"),
+                openai_proxy=model_info.get("api_proxy"),
+            )
+        model = OpenAI(**params)
+    except Exception as e:
+        logger.error(f"failed to create OpenAI for model: {model_name}.", exc_info=True)
+        model = None
    return model


-def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings:
+def get_Embeddings(
+    embed_model: str = DEFAULT_EMBEDDING_MODEL,
+    local_wrap: bool = True,  # use local wrapped api
+) -> Embeddings:
    from langchain_community.embeddings.openai import OpenAIEmbeddings
    from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154

    model_info = get_model_info(model_name=embed_model)
    params = dict(model=embed_model)
    try:
        if local_wrap:
            params.update(
                openai_api_base=f"{api_address()}/v1",
                openai_api_key="EMPTY",
            )
        else:
            params.update(
                openai_api_base=model_info.get("api_base_url"),
                openai_api_key=model_info.get("api_key"),
                openai_proxy=model_info.get("api_proxy"),
            )
        if model_info.get("platform_type") == "openai":
            return OpenAIEmbeddings(**params)
        else:
            return LocalAIEmbeddings(**params)
    except Exception as e:
        logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)


+def get_OpenAIClient(
+    platform_name: str = None,
+    model_name: str = None,
+    is_async: bool = True,
+) -> Union[openai.Client, openai.AsyncClient]:
+    '''
+    construct an openai Client for specified platform or model
+    '''
+    if platform_name is None:
+        platform_name = get_model_info(model_name=model_name, platform_name=platform_name)["platform_name"]
+    platform_info = get_config_platforms().get(platform_name)
+    assert platform_info, f"cannot find configured platform: {platform_name}"
+    params = {
-        "model": embed_model,
-        "base_url": model_info.get("api_base_url"),
-        "api_key": model_info.get("api_key"),
-        "openai_proxy": model_info.get("api_proxy"),
+        "base_url": platform_info.get("api_base_url"),
+        "api_key": platform_info.get("api_key"),
    }
-    if model_info.get("platform_type") == "openai":
-        return OpenAIEmbeddings(**params)
+    httpx_params = {}
+    if api_proxy := platform_info.get("api_proxy"):
+        httpx_params = {
+            "proxies": api_proxy,
+            "transport": httpx.HTTPTransport(local_address="0.0.0.0"),
+        }
+
+    if is_async:
+        if httpx_params:
+            params["http_client"] = httpx.AsyncClient(**httpx_params)
+        return openai.AsyncClient(**params)
    else:
-        return LocalAIEmbeddings(**params)
+        if httpx_params:
+            params["http_client"] = httpx.Client(**httpx_params)
+        return openai.Client(**params)


class MsgType:
@@ -281,7 +364,7 @@ def iter_over_async(ait, loop=None):

def MakeFastAPIOffline(
    app: FastAPI,
-    static_dir=Path(__file__).parent / "static",
+    static_dir=Path(__file__).parent / "api_server" / "static",
    static_url="/static-offline-docs",
    docs_url: Optional[str] = "/docs",
    redoc_url: Optional[str] = "/redoc",

@@ -42,7 +42,7 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):


def run_api_server(started_event: mp.Event = None, run_mode: str = None):
-    from server.api import create_app
+    from server.api_server.server_app import create_app
    import uvicorn
    from server.utils import set_httpx_config
    set_httpx_config()
@@ -199,13 +199,6 @@ async def start_main_server():
        args.api_worker = True
        args.webui = False

-    elif args.llm_api:
-        args.openai_api = True
-        args.model_worker = True
-        args.api_worker = True
-        args.api = False
-        args.webui = False

    if args.lite:
        args.model_worker = False
        run_mode = "lite"
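Worth calling out is the local_wrap design in the helpers above: by default, get_ChatOpenAI, get_OpenAI and get_Embeddings now point LangChain at the project's own /v1 gateway rather than at the upstream platform, so every model call passes through the load balancer. A usage sketch:

    # Sketch: local_wrap=True (the default) routes through the local gateway;
    # local_wrap=False talks to the configured platform directly.
    from server.utils import get_ChatOpenAI

    balanced = get_ChatOpenAI(model_name="chatglm3-6b")                  # via local /v1
    direct = get_ChatOpenAI(model_name="chatglm3-6b", local_wrap=False)  # bypass gateway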
tests/api/test_openai_wrap.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))

import requests

import openai

from configs import DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL
from server.utils import api_address


api_base_url = f"{api_address()}/v1"
client = openai.Client(
    api_key="EMPTY",
    base_url=api_base_url,
)

def test_chat():
    resp = client.chat.completions.create(
        messages=[{"role": "user", "content": "你是谁"}],
        model=DEFAULT_LLM_MODEL,
    )
    print(resp)
    assert hasattr(resp, "choices") and len(resp.choices) > 0


def test_embeddings():
    resp = client.embeddings.create(input="你是谁", model=DEFAULT_EMBEDDING_MODEL)
    print(resp)
    assert hasattr(resp, "data") and len(resp.data) > 0
tests/api/test_tools.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))

from pprint import pprint
import requests

from server.utils import api_address


api_base_url = f"{api_address()}/tools"

def test_tool_list():
    resp = requests.get(api_base_url)
    assert resp.status_code == 200
    data = resp.json()["data"]
    pprint(data)
    assert "calculate" in data


def test_tool_call():
    data = {
        "name": "calculate",
        "kwargs": {"a": 1, "b": 2, "operator": "+"},
    }
    resp = requests.post(f"{api_base_url}/call", json=data)
    assert resp.status_code == 200
    data = resp.json()["data"]
    pprint(data)
    assert data == 3