- Pin pydantic to v1 and unify all pydantic import paths across the project, in preparation for a future v2 upgrade

- Refactor api.py:
    - split the routes into per-module routers
    - add OpenAI-compatible forwarding endpoints, which the project now uses by default to get model load balancing (see the sketch below)
    - add /tools endpoints for listing and calling the authored agent tools
    - remove all EmbeddingsFunAdapter usages in favor of get_Embeddings
    - TODO:
        - make the /chat/chat endpoint OpenAI-compatible
        - add an OpenAI-compatible /chat/kb_chat endpoint
        - relocate data directories such as nltk/knowledge_base/logs
liunux4odoo 2024-03-05 10:28:39 +08:00
parent 65466007ae
commit d0846f88cc
45 changed files with 865 additions and 355 deletions
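A minimal sketch of the flow this commit makes the default (port, api_key, and model name are assumptions based on the configs and tests below): clients talk to the project's local OpenAI-compatible /v1 endpoints, and a model name that is configured on several platforms is dispatched to whichever platform currently has free capacity.

import openai

# 7861 is the default port of the refactored API server (see server_app.py below)
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:7861/v1")
resp = client.chat.completions.create(
    model="chatglm3-6b",  # the same name may be configured on multiple platforms
    messages=[{"role": "user", "content": "你好"}],
)
print(resp.choices[0].message.content)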

View File

@ -76,25 +76,34 @@ LLM_MODEL_CONFIG = {
},
}
# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
# 可以通过 loom/xinference/oneapi/fastchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
# - platform_name 可以任意填写,不要重复即可
# - platform_type 可选:openai, xinference, oneapi, fastchat。以后可能根据平台类型做一些功能区分
# - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。
MODEL_PLATFORMS = [
{
"platform_name": "openai-api",
"platform_type": "openai",
"llm_models": [
"gpt-3.5-turbo",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "https://api.openai.com/v1",
"api_key": "sk-",
"api_proxy": "",
},
# {
# "platform_name": "openai-api",
# "platform_type": "openai",
# "api_base_url": "https://api.openai.com/v1",
# "api_key": "sk-",
# "api_proxy": "",
# "api_concurrencies": 5,
# "llm_models": [
# "gpt-3.5-turbo",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
{
"platform_name": "xinference",
"platform_type": "xinference",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
"api_concurrencies": 5,
# 注意:这里填写的是 xinference 部署的模型 UID,而非模型名称
"llm_models": [
"chatglm3-6b",
],
@ -107,40 +116,38 @@ MODEL_PLATFORMS = [
"multimodal_models": [
"qwen-vl",
],
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
},
{
"platform_name": "oneapi",
"platform_type": "oneapi",
"api_key": "",
"llm_models": [
"qwen-turbo",
"qwen-plus",
"chatglm_turbo",
"chatglm_std",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "http://127.0.0.1:3000/v1",
"api_key": "sk-xxx",
},
# {
# "platform_name": "oneapi",
# "platform_type": "oneapi",
# "api_base_url": "http://127.0.0.1:3000/v1",
# "api_key": "",
# "api_concurrencies": 5,
# "llm_models": [
# "qwen-turbo",
# "qwen-plus",
# "chatglm_turbo",
# "chatglm_std",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
{
"platform_name": "loom",
"platform_type": "loom",
"api_key": "",
"llm_models": [
"chatglm3-6b",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "http://127.0.0.1:7860/v1",
"api_key": "EMPTY",
},
# {
# "platform_name": "loom",
# "platform_type": "loom",
# "api_base_url": "http://127.0.0.1:7860/v1",
# "api_key": "",
# "api_concurrencies": 5,
# "llm_models": [
# "chatglm3-6b",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
]
LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")

View File

@ -2,7 +2,7 @@ import sys
sys.path.append(".")
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files)
from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL
from configs.model_config import DEFAULT_EMBEDDING_MODEL
from datetime import datetime

View File

@ -41,6 +41,38 @@ youtube-search==2.1.2
duckduckgo-search==3.9.9
metaphor-python==0.1.23
httpx==0.26.0
httpx_sse==0.4.0
watchdog==3.0.0
pyjwt==2.8.0
elasticsearch
numexpr>=2.8.8
strsimpy>=0.2.1
markdownify>=0.11.6
tqdm>=4.66.1
websockets>=12.0
numpy>=1.26.3
pandas~=2.1.4
pydantic<2
httpx[brotli,http2,socks]>=0.25.2
# optional document loaders
# rapidocr_paddle[gpu]>=1.3.0.post5
# jq>=1.6.0
# html2text
# beautifulsoup4>=4.12.2
# pysrt>=1.1.2
# Agent and Search Tools
# arxiv>=2.1.0
# youtube-search>=2.1.2
# duckduckgo-search>=4.1.0
# metaphor-python>=0.1.23
# WebUI requirements
streamlit==1.30.0
streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1
@ -48,7 +80,3 @@ streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
httpx_sse==0.4.0
watchdog==3.0.0
pyjwt==2.8.0

View File

@ -16,10 +16,8 @@ from langchain.output_parsers import OutputFixingParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.pydantic_v1 import Field
from server.pydantic_types import Field, typing, model_schema
from pydantic import typing
from pydantic.schema import model_schema
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,8 @@
# LangChain 的 ArxivQueryRun 工具
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
def arxiv(query: str):
tool = ArxivQueryRun()
return tool.run(tool_input=query)

View File

@ -1,6 +1,6 @@
import base64
import os
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
def save_base64_audio(base64_audio, file_path):
audio_data = base64.b64decode(base64_audio)

View File

@ -1,4 +1,4 @@
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
def calculate(a: float, b: float, operator: str) -> float:
if operator == "+":

View File

@ -1,4 +1,4 @@
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import TOOL_CONFIG

View File

@ -1,5 +1,5 @@
from urllib.parse import urlencode
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG

View File

@ -1,5 +1,7 @@
from langchain_community.tools import YouTubeSearchTool
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
def search_youtube(query: str):
tool = YouTubeSearchTool()
return tool.run(tool_input=query)

View File

@ -1,6 +1,8 @@
# LangChain 的 Shell 工具
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from langchain_community.tools import ShellTool
def shell(query: str):
tool = ShellTool()
return tool.run(tool_input=query)

View File

@ -6,9 +6,8 @@ from typing import List
import uuid
from langchain.agents import tool
from langchain.pydantic_v1 import Field
from server.pydantic_types import Field, FieldInfo
import openai
from pydantic.fields import FieldInfo
from configs.basic_config import MEDIA_PATH
from server.utils import MsgType

View File

@ -4,7 +4,7 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
from PIL import Image, ImageDraw
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from configs import TOOL_CONFIG
import re
from server.agent.container import container

View File

@ -1,7 +1,7 @@
"""
简单的单参数输入工具实现,用于查询现在天气的情况
"""
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
import requests
def weather(location: str, api_key: str):

View File

@ -1,7 +1,9 @@
# Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain.pydantic_v1 import BaseModel, Field
from server.pydantic_types import BaseModel, Field
wolfram_alpha_appid = "your key"
def wolfram(query: str):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
ans = wolfram.run(query)

View File

@ -1,225 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)
from typing import List, Literal
async def document():
return RedirectResponse(url="/docs")
def create_app(run_mode: str = None):
app = FastAPI(
title="Langchain-Chatchat API Server",
version=VERSION
)
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
mount_app_routes(app, run_mode=run_mode)
return app
def mount_app_routes(app: FastAPI, run_mode: str = None):
app.get("/",
response_model=BaseResponse,
summary="swagger 文档")(document)
# Tag: Chat
app.post("/chat/chat",
tags=["Chat"],
summary="与llm模型对话(通过LLMChain)",
)(chat)
app.post("/chat/feedback",
tags=["Chat"],
summary="返回llm模型对话评分",
)(chat_feedback)
# 知识库相关接口
mount_knowledge_routes(app)
# 摘要相关接口
mount_filename_summary_routes(app)
# 服务器相关接口
app.post("/server/configs",
tags=["Server State"],
summary="获取服务器原始配置信息",
)(get_server_configs)
@app.post("/server/get_prompt_template",
tags=["Server State"],
summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型:可选值 llm_chat,knowledge_base_chat"),
name: str = Body("default", description="模板名称"),
) -> str:
return get_prompt_template(type=type, name=name)
# 其它接口
app.post("/other/completion",
tags=["Other"],
summary="要求llm模型补全(通过LLMChain)",
)(completion)
# 媒体文件
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
def mount_knowledge_routes(app: FastAPI):
from server.chat.file_chat import upload_temp_docs, file_chat
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, update_info)
app.post("/chat/file_chat",
tags=["Knowledge Base Management"],
summary="文件对话"
)(file_chat)
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库列表")(list_kbs)
app.post("/knowledge_base/create_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="创建知识库"
)(create_kb)
app.post("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_files",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_files)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithVSId],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="上传文件到知识库,并/或进行向量化"
)(upload_docs)
app.post("/knowledge_base/delete_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_docs)
app.post("/knowledge_base/update_info",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="更新知识库介绍"
)(update_info)
app.post("/knowledge_base/update_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="更新现有文件到知识库"
)(update_docs)
app.get("/knowledge_base/download_doc",
tags=["Knowledge Base Management"],
summary="下载对应的知识文件")(download_doc)
app.post("/knowledge_base/recreate_vector_store",
tags=["Knowledge Base Management"],
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
app.post("/knowledge_base/upload_temp_docs",
tags=["Knowledge Base Management"],
summary="上传文件到临时目录,用于文件对话。"
)(upload_temp_docs)
def mount_filename_summary_routes(app: FastAPI):
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
summary_doc_ids_to_vector_store)
app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="单个知识库根据文件名称摘要"
)(summary_file_to_vector_store)
app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="单个知识库根据doc_ids摘要",
response_model=BaseResponse,
)(summary_doc_ids_to_vector_store)
app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="重建单个知识库文件摘要"
)(recreate_summary_vector_store)
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
else:
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
' 基于本地知识库的 ChatGLM 问答')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
app = create_app()
run_api(host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import re
from typing import Dict, List, Literal, Optional, Union
from fastapi import UploadFile
from server.pydantic_types import BaseModel, Field, AnyUrl, root_validator
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
completion_create_params,
)
from configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
class OpenAIBaseInput(BaseModel):
user: Optional[str] = None
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict] = None
extra_query: Optional[Dict] = None
extra_body: Optional[Dict] = None
timeout: Optional[float] = None
class OpenAIChatInput(OpenAIBaseInput):
messages: List[ChatCompletionMessageParam]
model: str = DEFAULT_LLM_MODEL
frequency_penalty: Optional[float] = None
function_call: Optional[completion_create_params.FunctionCall] = None
functions: List[completion_create_params.Function] = None
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[float] = None
response_format: completion_create_params.ResponseFormat = None
seed: Optional[int] = None
stop: Union[Optional[str], List[str]] = None
stream: Optional[bool] = None
temperature: Optional[float] = TEMPERATURE
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
tools: List[ChatCompletionToolParam] = None
top_logprobs: Optional[int] = None
top_p: Optional[float] = None
class OpenAIEmbeddingsInput(OpenAIBaseInput):
input: Union[str, List[str]]
model: str
dimensions: Optional[int] = None
encoding_format: Optional[Literal["float", "base64"]] = None
class OpenAIImageBaseInput(OpenAIBaseInput):
model: str
n: int = 1
response_format: Optional[Literal["url", "b64_json"]] = None
size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = "256x256"
class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
prompt: str
quality: Literal["standard", "hd"] = None
style: Optional[Literal["vivid", "natural"]] = None
class OpenAIImageVariationsInput(OpenAIImageBaseInput):
image: Union[UploadFile, AnyUrl]
class OpenAIImageEditsInput(OpenAIImageVariationsInput):
prompt: str
mask: Union[UploadFile, AnyUrl]
class OpenAIAudioTranslationsInput(OpenAIBaseInput):
file: Union[UploadFile, AnyUrl]
model: str
prompt: Optional[str] = None
response_format: Optional[str] = None
temperature: float = TEMPERATURE
class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
language: Optional[str] = None
timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
class OpenAIAudioSpeechInput(OpenAIBaseInput):
input: str
model: str
voice: str
response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
speed: Optional[float] = None
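One design note on these schemas: the routes in openai_routes.py (below) forward body.dict(exclude_unset=True) to the upstream sdk, so the many Optional fields above only reach a platform when the caller actually sets them. A quick sketch of that pydantic v1 behavior (values are illustrative):

from server.api_server.api_schemas import OpenAIChatInput

body = OpenAIChatInput(messages=[{"role": "user", "content": "hi"}])
# unset fields such as model and temperature are dropped, despite having defaults
print(body.dict(exclude_unset=True))
# {'messages': [{'role': 'user', 'content': 'hi'}]}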

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Request
from server.chat.chat import chat
from server.chat.feedback import chat_feedback
from server.chat.file_chat import file_chat
chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
chat_router.post("/chat",
summary="与llm模型对话(通过LLMChain)",
)(chat)
chat_router.post("/feedback",
summary="返回llm模型对话评分",
)(chat_feedback)
chat_router.post("/file_chat",
summary="文件对话"
)(file_chat)

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Request
from server.chat.file_chat import upload_temp_docs
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, update_info)
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
summary_doc_ids_to_vector_store)
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import BaseResponse, ListResponse
kb_router = APIRouter(prefix="/knowledge_base", tags=["Knowledge Base Management"])
kb_router.get("/list_knowledge_bases",
response_model=ListResponse,
summary="获取知识库列表")(list_kbs)
kb_router.post("/create_knowledge_base",
response_model=BaseResponse,
summary="创建知识库"
)(create_kb)
kb_router.post("/delete_knowledge_base",
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
kb_router.get("/list_files",
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_files)
kb_router.post("/search_docs",
response_model=List[DocumentWithVSId],
summary="搜索知识库"
)(search_docs)
kb_router.post("/upload_docs",
response_model=BaseResponse,
summary="上传文件到知识库,并/或进行向量化"
)(upload_docs)
kb_router.post("/delete_docs",
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_docs)
kb_router.post("/update_info",
response_model=BaseResponse,
summary="更新知识库介绍"
)(update_info)
kb_router.post("/update_docs",
response_model=BaseResponse,
summary="更新现有文件到知识库"
)(update_docs)
kb_router.get("/download_doc",
summary="下载对应的知识文件")(download_doc)
kb_router.post("/recreate_vector_store",
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
kb_router.post("/upload_temp_docs",
summary="上传文件到临时目录,用于文件对话。"
)(upload_temp_docs)
summary_router = APIRouter(prefix="/kb_summary_api")
summary_router.post("/summary_file_to_vector_store",
summary="单个知识库根据文件名称摘要"
)(summary_file_to_vector_store)
summary_router.post("/summary_doc_ids_to_vector_store",
summary="单个知识库根据doc_ids摘要",
response_model=BaseResponse,
)(summary_doc_ids_to_vector_store)
summary_router.post("/recreate_summary_vector_store",
summary="重建单个知识库文件摘要"
)(recreate_summary_vector_store)
kb_router.include_router(summary_router)

View File

@ -0,0 +1,179 @@
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple, AsyncGenerator
from fastapi import APIRouter, Request
from openai import AsyncClient
from sse_starlette.sse import EventSourceResponse
from .api_schemas import *
from configs import logger
from server.utils import get_model_info, get_config_platforms, get_OpenAIClient
DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数
model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {} # key: (model_name, platform)
openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"])
@asynccontextmanager
async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
'''
对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
'''
max_semaphore = 0
selected_platform = ""
model_infos = get_model_info(model_name=model_name, multiple=True)
for m, c in model_infos.items():
key = (m, c["platform_name"])
api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
if key not in model_semaphores:
model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
semaphore = model_semaphores[key]
if semaphore._value >= api_concurrencies:
selected_platform = c["platform_name"]
break
elif semaphore._value > max_semaphore:
selected_platform = c["platform_name"]
key = (m, selected_platform)
semaphore = model_semaphores[key]
try:
await semaphore.acquire()
yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
except Exception:
logger.error(f"failed when request to {key}", exc_info=True)
finally:
semaphore.release()
async def openai_request(method, body):
'''
helper function to make openai request
'''
async def generator():
async for chunk in await method(**params):
yield {"data": chunk.json()}
params = body.dict(exclude_unset=True)
if hasattr(body, "stream") and body.stream:
return EventSourceResponse(generator())
else:
return (await method(**params)).dict()
@openai_router.get("/models")
async def list_models() -> Dict:
'''
整合所有平台的模型列表
由于 openai sdk 不支持重名模型,对于重名模型,只返回其中响应速度最快的一个;在请求其它接口时,会自动按照模型忙闲状态进行调度
'''
async def task(name: str):
try:
client = get_OpenAIClient(name, is_async=True)
models = await client.models.list()
models = models.dict(exclude=["data", "object"])
for x in models:
models[x]["platform_name"] = name
return models
except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True)
return {}
result = {}
tasks = [asyncio.create_task(task(name)) for name in get_config_platforms()]
for t in asyncio.as_completed(tasks):
for n, v in (await t).items():
if n not in result:
result[n] = v
return result
@openai_router.post("/chat/completions")
async def create_chat_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.chat.completions.create, body)
@openai_router.post("/completions")
async def create_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.completions.create, body)
@openai_router.post("/embeddings")
async def create_embeddings(
request: Request,
body: OpenAIEmbeddingsInput,
):
params = body.dict(exclude_unset=True)
client = get_OpenAIClient(model_name=body.model)
return (await client.embeddings.create(**params)).dict()
@openai_router.post("/images/generations")
async def create_image_generations(
request: Request,
body: OpenAIImageGenerationsInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.images.generate, body)
@openai_router.post("/images/variations")
async def create_image_variations(
request: Request,
body: OpenAIImageVariationsInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.images.create_variation, body)
@openai_router.post("/images/edit")
async def create_image_edit(
request: Request,
body: OpenAIImageEditsInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.images.edit, body)
@openai_router.post("/audio/translations", deprecated="暂不支持")
async def create_audio_translations(
request: Request,
body: OpenAIAudioTranslationsInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.audio.translations.create, body)
@openai_router.post("/audio/transcriptions", deprecated="暂不支持")
async def create_audio_transcriptions(
request: Request,
body: OpenAIAudioTranscriptionsInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.audio.transcriptions.create, body)
@openai_router.post("/audio/speech", deprecated="暂不支持")
async def create_audio_speech(
request: Request,
body: OpenAIAudioSpeechInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.audio.speech.create, body)
@openai_router.post("/files", deprecated="暂不支持")
async def files():
...
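For reference, a sketch of consuming the streaming branch of openai_request above with the openai sdk (server address and model name are assumptions):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:7861/v1")
stream = client.chat.completions.create(
    model="chatglm3-6b",  # assumed model name
    messages=[{"role": "user", "content": "讲个笑话"}],
    stream=True,  # served as SSE via EventSourceResponse on the router side
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)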

View File

@ -0,0 +1,91 @@
import argparse
from typing import Literal
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
import uvicorn
from configs import VERSION, MEDIA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from server.api_server.chat_routes import chat_router
from server.api_server.kb_routes import kb_router
from server.api_server.openai_routes import openai_router
from server.api_server.server_routes import server_router
from server.api_server.tool_routes import tool_router
from server.chat.completion import completion
from server.utils import MakeFastAPIOffline
def create_app(run_mode: str=None):
app = FastAPI(
title="Langchain-Chatchat API Server",
version=VERSION
)
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", summary="swagger 文档", include_in_schema=False)
async def document():
return RedirectResponse(url="/docs")
app.include_router(chat_router)
app.include_router(kb_router)
app.include_router(tool_router)
app.include_router(openai_router)
app.include_router(server_router)
# 其它接口
app.post("/other/completion",
tags=["Other"],
summary="要求llm模型补全(通过LLMChain)",
)(completion)
# 媒体文件
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
return app
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
else:
uvicorn.run(app, host=host, port=port)
app = create_app()
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
' 基于本地知识库的 ChatGLM 问答')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
run_api(host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)

View File

@ -0,0 +1,23 @@
from typing import Literal
from fastapi import APIRouter, Body
from server.utils import get_server_configs, get_prompt_template
server_router = APIRouter(prefix="/server", tags=["Server State"])
# 服务器相关接口
server_router.post("/configs",
summary="获取服务器原始配置信息",
)(get_server_configs)
@server_router.post("/get_prompt_template",
summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型:可选值 llm_chat,knowledge_base_chat"),
name: str = Body("default", description="模板名称"),
) -> str:
return get_prompt_template(type=type, name=name)

View File

(binary image file: 7.1 KiB before, 7.1 KiB after)

View File

@ -0,0 +1,43 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Request, Body
from configs import logger
from server.utils import BaseResponse
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse)
async def list_tools():
import importlib
from server.agent.tools_factory import tools_registry
importlib.reload(tools_registry)
data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools}
return {"data": data}
@tool_router.post("/call", response_model=BaseResponse)
async def call_tool(
name: str = Body(examples=["calculate"]),
kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
):
import importlib
from server.agent.tools_factory import tools_registry
importlib.reload(tools_registry)
tool_names = {t.name: t for t in tools_registry.all_tools}
if tool := tool_names.get(name):
try:
result = await tool.ainvoke(kwargs)
return {"data": result}
except Exception:
msg = f"failed to call tool '{name}'"
logger.error(msg, exc_info=True)
return {"code": 500, "msg": msg}
else:
return {"code": 500, "msg": f"no tool named '{name}'"}

View File

@ -1,5 +1,5 @@
from functools import lru_cache
from pydantic import BaseModel, Field
from server.pydantic_types import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from typing import List, Tuple, Dict, Union

View File

@ -1,11 +1,11 @@
## 指定制定列的csv文件加载器
from langchain.document_loaders import CSVLoader
from langchain_community.document_loaders import CSVLoader
import csv
from io import TextIOWrapper
from typing import Dict, List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.helpers import detect_file_encodings
from langchain_community.document_loaders.helpers import detect_file_encodings
class FilteredCSVLoader(CSVLoader):

View File

@ -1,4 +1,4 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm

View File

@ -1,5 +1,5 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from server.document_loaders.ocr import get_ocr

View File

@ -1,5 +1,5 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
import cv2
from PIL import Image
import numpy as np

View File

@ -1,4 +1,4 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm

View File

@ -10,7 +10,6 @@ from server.knowledge_base.utils import (validate_kb_name, list_files_from_folde
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import FileResponse
from sse_starlette import EventSourceResponse
from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
@ -120,7 +119,7 @@ def upload_docs(
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs,需要转为json字符串"),
docs: str = Form("", description="自定义的docs,需要转为json字符串"),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
"""
@ -133,6 +132,7 @@ def upload_docs(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
docs = json.loads(docs) if docs else {}
failed_files = {}
file_names = list(docs.keys())
@ -221,7 +221,7 @@ def update_docs(
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs,需要转为json字符串"),
docs: str = Body("", description="自定义的docs,需要转为json字符串"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
"""
@ -236,6 +236,7 @@ def update_docs(
failed_files = {}
kb_files = []
docs = json.loads(docs) if docs else {}
# 生成需要加载docs的文件列表
for file_name in file_names:

View File

@ -6,9 +6,9 @@ from chromadb.api.types import (GetResult, QueryResult)
from langchain.docstore.document import Document
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
KBService, SupportedVSType)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import get_Embeddings
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@ -75,7 +75,7 @@ class ChromaKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
Tuple[Document, float]]:
embed_func = EmbeddingsFunAdapter(self.embed_model)
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
return _results_to_docs_and_scores(query_result)

View File

@ -4,7 +4,8 @@ import shutil
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path, EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import get_Embeddings
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
@ -60,8 +61,9 @@ class FaissKBService(KBService):
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Document]:
) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
embeddings = vs.embeddings.embed_query(query)
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@ -72,11 +74,10 @@ class FaissKBService(KBService):
**kwargs,
) -> List[Dict]:
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
with self.load_vector_store().acquire() as vs:
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
embeddings = vs.embeddings.embed_documents(texts)
ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
metadatas=metadatas)
if not kwargs.get("not_refresh_vs_cache"):

View File

@ -7,9 +7,10 @@ import os
from configs import kbs_config
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
from server.utils import get_Embeddings
class MilvusKBService(KBService):
@ -49,7 +50,7 @@ class MilvusKBService(KBService):
return SupportedVSType.MILVUS
def _load_milvus(self):
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
self.milvus = Milvus(embedding_function=get_Embeddings(self.embed_model),
collection_name=self.kb_name,
connection_args=kbs_config.get("milvus"),
index_params=kbs_config.get("milvus_kwargs")["index_params"],
@ -66,7 +67,7 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus()
embed_func = EmbeddingsFunAdapter(self.embed_model)
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)

View File

@ -7,9 +7,10 @@ from sqlalchemy import text
from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
from server.utils import get_Embeddings
import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
@ -20,7 +21,7 @@ class PGKBService(KBService):
engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)
def _load_pg_vector(self):
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
self.pg_vector = PGVector(embedding_function=get_Embeddings(self.embed_model),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection=PGKBService.engine,
@ -59,7 +60,7 @@ class PGKBService(KBService):
shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float):
embed_func = EmbeddingsFunAdapter(self.embed_model)
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)

View File

@ -3,9 +3,10 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Zilliz
from configs import kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
from server.utils import get_Embeddings
class ZillizKBService(KBService):
@ -46,7 +47,7 @@ class ZillizKBService(KBService):
def _load_zilliz(self):
zilliz_args = kbs_config.get("zilliz")
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
self.zilliz = Zilliz(embedding_function=get_Embeddings(self.embed_model),
collection_name=self.kb_name, connection_args=zilliz_args)
def do_init(self):
@ -59,7 +60,7 @@ class ZillizKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz()
embed_func = EmbeddingsFunAdapter(self.embed_model)
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)

View File

@ -12,7 +12,7 @@ from configs import (
)
import importlib
from server.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain.document_loaders
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
@ -136,7 +136,7 @@ class JSONLinesLoader(JSONLoader):
self._json_lines = True
langchain.document_loaders.JSONLinesLoader = JSONLinesLoader
langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader
def get_LoaderClass(file_extension):
@ -155,13 +155,13 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module("server.document_loaders")
else:
document_loaders_module = importlib.import_module("langchain.document_loaders")
document_loaders_module = importlib.import_module("langchain_community.document_loaders")
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
document_loaders_module = importlib.import_module("langchain.document_loaders")
document_loaders_module = importlib.import_module("langchain_community.document_loaders")
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader":

server/pydantic_types.py Normal file
View File

@ -0,0 +1,7 @@
from langchain_core.pydantic_v1 import *
from pydantic.fields import FieldInfo
from pydantic.schema import model_schema
# from pydantic.v1 import *
# from pydantic.v1.fields import FieldInfo
# from pydantic.v1.schema import model_schema

View File

@ -3,11 +3,11 @@ from pathlib import Path
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.pydantic_v1 import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.llms import OpenAI
import httpx
import openai
from typing import (
Optional,
Callable,
@ -22,8 +22,10 @@ from typing import (
)
import logging
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE
from server.minx_chat_openai import MinxChatOpenAI
from configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE)
from server.pydantic_types import BaseModel, Field
from server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -40,6 +42,14 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
event.set()
def get_config_platforms() -> Dict[str, Dict]:
import importlib
from configs import model_config
importlib.reload(model_config)
return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS}
def get_config_models(
model_name: str = None,
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
@ -88,40 +98,54 @@ def get_config_models(
return result
def get_model_info(model_name: str, platform_name: str = None) -> Dict:
def get_model_info(model_name: str = None, platform_name: str = None, multiple: bool = False) -> Dict:
'''
获取配置的模型信息,主要是 api_base_url, api_key
如果指定 multiple=True,则返回所有重名模型;否则仅返回第一个
'''
result = get_config_models(model_name=model_name, platform_name=platform_name)
if len(result) > 0:
return list(result.values())[0]
if multiple:
return result
else:
return list(result.values())[0]
else:
return {}
def get_ChatOpenAI(
model_name: str,
model_name: str = DEFAULT_LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
local_wrap: bool = True, # use local wrapped api
**kwargs: Any,
) -> ChatOpenAI:
model_info = get_model_info(model_name)
try:
model = ChatOpenAI(
params = dict(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
**kwargs
)
)
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
)
else:
params.update(
openai_api_base=model_info.get("api_base_url"),
openai_api_key=model_info.get("api_key"),
openai_proxy=model_info.get("api_proxy"),
)
model = ChatOpenAI(**params)
except Exception as e:
logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True)
model = None
@ -136,41 +160,100 @@ def get_OpenAI(
echo: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
local_wrap: bool = True, # use local wrapped api
**kwargs: Any,
) -> OpenAI:
# TODO: 从API获取模型信息
model_info = get_model_info(model_name)
model = OpenAI(
params = dict(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
echo=echo,
**kwargs
)
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
)
else:
params.update(
openai_api_base=model_info.get("api_base_url"),
openai_api_key=model_info.get("api_key"),
openai_proxy=model_info.get("api_proxy"),
)
model = OpenAI(**params)
except Exception as e:
logger.error(f"failed to create OpenAI for model: {model_name}.", exc_info=True)
model = None
return model
def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings:
def get_Embeddings(
embed_model: str = DEFAULT_EMBEDDING_MODEL,
local_wrap: bool = True, # use local wrapped api
) -> Embeddings:
from langchain_community.embeddings.openai import OpenAIEmbeddings
from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
model_info = get_model_info(model_name=embed_model)
params = dict(model=embed_model)
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
)
else:
params.update(
openai_api_base=model_info.get("api_base_url"),
openai_api_key=model_info.get("api_key"),
openai_proxy=model_info.get("api_proxy"),
)
if model_info.get("platform_type") == "openai":
return OpenAIEmbeddings(**params)
else:
return LocalAIEmbeddings(**params)
except Exception as e:
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
def get_OpenAIClient(
platform_name: str=None,
model_name: str=None,
is_async: bool=True,
) -> Union[openai.Client, openai.AsyncClient]:
'''
construct an openai Client for specified platform or model
'''
if platform_name is None:
platform_name = get_model_info(model_name=model_name, platform_name=platform_name)["platform_name"]
platform_info = get_config_platforms().get(platform_name)
assert platform_info, f"cannot find configured platform: {platform_name}"
params = {
"model": embed_model,
"base_url": model_info.get("api_base_url"),
"api_key": model_info.get("api_key"),
"openai_proxy": model_info.get("api_proxy"),
"base_url": platform_info.get("api_base_url"),
"api_key": platform_info.get("api_key"),
}
if model_info.get("platform_type") == "openai":
return OpenAIEmbeddings(**params)
httpx_params = {}
if api_proxy := platform_info.get("api_proxy"):
httpx_params = {
"proxies": api_proxy,
"transport": httpx.HTTPTransport(local_address="0.0.0.0"),
}
if is_async:
if httpx_params:
params["http_client"] = httpx.AsyncClient(**httpx_params)
return openai.AsyncClient(**params)
else:
return LocalAIEmbeddings(**params)
if httpx_params:
params["http_client"] = httpx.Client(**httpx_params)
return openai.Client(**params)
class MsgType:
@ -281,7 +364,7 @@ def iter_over_async(ait, loop=None):
def MakeFastAPIOffline(
app: FastAPI,
static_dir=Path(__file__).parent / "static",
static_dir=Path(__file__).parent / "api_server" / "static",
static_url="/static-offline-docs",
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",

View File

@ -42,7 +42,7 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
from server.api import create_app
from server.api_server.server_app import create_app
import uvicorn
from server.utils import set_httpx_config
set_httpx_config()
@ -199,13 +199,6 @@ async def start_main_server():
args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api_worker = True
args.api = False
args.webui = False
if args.lite:
args.model_worker = False
run_mode = "lite"

View File

@ -0,0 +1,31 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
import requests
import openai
from configs import DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL
from server.utils import api_address
api_base_url = f"{api_address()}/v1"
client = openai.Client(
api_key="EMPTY",
base_url=api_base_url,
)
def test_chat():
resp = client.chat.completions.create(
messages=[{"role": "user", "content": "你是谁"}],
model=DEFAULT_LLM_MODEL,
)
print(resp)
assert hasattr(resp, "choices") and len(resp.choices) > 0
def test_embeddings():
resp = client.embeddings.create(input="你是谁", model=DEFAULT_EMBEDDING_MODEL)
print(resp)
assert hasattr(resp, "data") and len(resp.data) > 0

tests/api/test_tools.py Normal file
View File

@ -0,0 +1,30 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from pprint import pprint
import requests
from server.utils import api_address
api_base_url = f"{api_address()}/tools"
def test_tool_list():
resp = requests.get(api_base_url)
assert resp.status_code == 200
data = resp.json()["data"]
pprint(data)
assert "calculate" in data
def test_tool_call():
data = {
"name": "calculate",
"kwargs": {"a":1,"b":2,"operator":"+"},
}
resp = requests.post(f"{api_base_url}/call", json=data)
assert resp.status_code == 200
data = resp.json()["data"]
pprint(data)
assert data == 3
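Both new test modules assume an API server is already running at api_address(); with that in place they can be run with pytest (e.g. pytest tests/api -s).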