Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-19 21:37:20 +08:00)
Support model_providers and integrate platform-configuration initialization into the WebUI and API (#3625)
* provider_configuration init of MODEL_PLATFORMS * developer manual * support model_providers and integrate platform-configuration initialization into the WebUI and API
This commit is contained in:
parent c0634828a4
commit b3dee0b1d1
3  .gitignore  vendored
@@ -183,4 +183,5 @@ configs/*.py
/knowledge_base/samples/content/imi_temeplate.txt
/chatchat/configs/*.py
/chatchat/configs/*.yaml
chatchat/data
chatchat/data
/chatchat-server/chatchat/configs/model_providers.yaml
@@ -1,26 +0,0 @@
log_path: "logs"
log_level: "DEBUG"

api_server:
  host: "127.0.0.1"
  port: 8000

publish_server:
  host: "127.0.0.1"
  port: 8001

subscribe_server:
  host: "127.0.0.1"
  port: 8002

openai_plugins_folder:
  - "openai_plugins"
openai_plugins_load_folder:
  - "configs"


plugins:
  - openai:
      name: "openai"
  - zhipuai:
      name: "zhipuai"
@@ -1,6 +1,5 @@
import os


# Name of the default LLM
DEFAULT_LLM_MODEL = "chatglm3-6b"

@@ -31,7 +30,7 @@ SUPPORT_AGENT_MODELS = [


LLM_MODEL_CONFIG = {
    # Intent recognition does not need output; the model backend only needs to know it
    # Intent recognition does not need output; the model backend only needs to know it
    "preprocess_model": {
        DEFAULT_LLM_MODEL: {
            "temperature": 0.05,
@@ -57,7 +56,7 @@ LLM_MODEL_CONFIG = {
            "prompt_name": "ChatGLM3",
            "callbacks": True
        },
    },
    },
    "postprocess_model": {
        DEFAULT_LLM_MODEL: {
            "temperature": 0.01,
@@ -76,47 +75,15 @@ LLM_MODEL_CONFIG = {
        },
    }
}

# You can start a model service with loom/xinference/oneapi/fastchat and then configure its URL and KEY here.
# model_providers can expose different platforms' APIs as an OpenAI endpoint; after startup the matching platforms are appended to the variable below automatically.
# ### If you already have an address that speaks the OpenAI endpoint protocol, you can configure it here directly.
# - platform_name can be any value, as long as it is not duplicated.
# - platform_type is one of: openai, xinference, oneapi, fastchat. Features may be differentiated by platform type later.
# - platform_type may later be used to differentiate features by platform type; keep it the same as platform_name.
# - List the models deployed by each framework in the corresponding list. Different frameworks may load models with the same name; the project load-balances across them automatically.

MODEL_PLATFORMS = [
    # {
    #     "platform_name": "openai-api",
    #     "platform_type": "openai",
    #     "api_base_url": "https://api.openai.com/v1",
    #     "api_key": "sk-",
    #     "api_proxy": "",
    #     "api_concurrencies": 5,
    #     "llm_models": [
    #         "gpt-3.5-turbo",
    #     ],
    #     "embed_models": [],
    #     "image_models": [],
    #     "multimodal_models": [],
    # },

    {
        "platform_name": "xinference",
        "platform_type": "xinference",
        "api_base_url": "http://127.0.0.1:9997/v1",
        "api_key": "EMPTY",
        "api_concurrencies": 5,
        # Note: fill in the model UIDs deployed in xinference here, not the model names
        "llm_models": [
            "chatglm3-6b",
        ],
        "embed_models": [
            "bge-large-zh-v1.5",
        ],
        "image_models": [
            "sd-turbo",
        ],
        "multimodal_models": [
            "qwen-vl",
        ],
    },
# Create a global shared dict
MODEL_PLATFORMS = [

    {
        "platform_name": "oneapi",
@@ -152,41 +119,13 @@ MODEL_PLATFORMS = [
        "multimodal_models": [],
    },

    {
        "platform_name": "ollama",
        "platform_type": "ollama",
        "api_base_url": "http://{host}:{port}/v1",
        "api_key": "sk-",
        "api_concurrencies": 5,
        "llm_models": [
            # Qwen API; for more models see https://ollama.com/library
            "qwen:7b",
        ],
        "embed_models": [
            # ollama must be upgraded to 0.1.29 or later; the embedding service is broken in older versions
            "nomic-embed-text"
        ],
        "image_models": [],
        "multimodal_models": [],
    },

    # {
    #     "platform_name": "loom",
    #     "platform_type": "loom",
    #     "api_base_url": "http://127.0.0.1:7860/v1",
    #     "api_key": "",
    #     "api_concurrencies": 5,
    #     "llm_models": [
    #         "chatglm3-6b",
    #     ],
    #     "embed_models": [],
    #     "image_models": [],
    #     "multimodal_models": [],
    # },
]

LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml")
MODEL_PROVIDERS_CFG_HOST = "127.0.0.1"

MODEL_PROVIDERS_CFG_PORT = 20000
# Tool configuration
TOOL_CONFIG = {
    "search_local_knowledgebase": {
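Editor's note: each MODEL_PLATFORMS entry describes an OpenAI-compatible endpoint. The following is a minimal, illustrative sketch (not part of the diff) of how such an entry can be turned into an openai client; the project's own helper for this appears to be get_OpenAIClient() in the server utils diff further below, so treat the function name client_for_platform and the lookup logic here as assumptions.

# Illustrative sketch only: build an OpenAI client from a MODEL_PLATFORMS entry.
import openai

from chatchat.configs import MODEL_PLATFORMS


def client_for_platform(platform_name: str) -> openai.Client:
    # Pick the configured platform by name and reuse its base URL / key.
    platform = next(p for p in MODEL_PLATFORMS if p["platform_name"] == platform_name)
    return openai.Client(
        base_url=platform["api_base_url"],
        api_key=platform["api_key"],
    )


# Example: talk to the locally deployed xinference platform.
# client = client_for_platform("xinference")
# client.chat.completions.create(model="chatglm3-6b",
#                                messages=[{"role": "user", "content": "hi"}])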
@@ -0,0 +1,29 @@
openai:
  model_credential:
    - model: 'gpt-3.5-turbo'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''
    - model: 'gpt-4'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''

  provider_credential:
    openai_api_key: 'sk-'
    openai_organization: ''
    openai_api_base: ''

xinference:
  model_credential:
    - model: 'chatglm3-6b'
      model_type: 'llm'
      model_credentials:
        server_url: 'http://127.0.0.1:9997/'
        model_uid: 'chatglm3-6b'
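Editor's note: the YAML above (presumably the model_providers.yaml referenced by MODEL_PROVIDERS_CFG_PATH_CONFIG) follows a provider-to-credentials layout. A hedged sketch of reading it with PyYAML; the real parsing and validation happen inside the model_providers package when BootstrapWebBuilder is given model_providers_cfg_path, so the helper below is illustrative only.

# Illustrative only: inspect model_providers.yaml; actual loading is done by model_providers.
import yaml  # PyYAML


def load_provider_credentials(path: str = "model_providers.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    # e.g. cfg["xinference"]["model_credential"][0]["model_credentials"]["server_url"]
    return cfg


if __name__ == "__main__":
    cfg = load_provider_credentials()
    for provider, sections in cfg.items():
        models = [m["model"] for m in sections.get("model_credential", [])]
        print(provider, models)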
@@ -1,6 +0,0 @@
{
  "openai_plugins": [
    "imitater", "openai"
  ]

}
109  chatchat-server/chatchat/model_loaders/init_server.py  Normal file
@@ -0,0 +1,109 @@
from typing import List, Dict
from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG
from model_providers import BootstrapWebBuilder
from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)
import multiprocessing as mp
import asyncio
import logging

logger = logging.getLogger(__name__)


def init_server(model_platforms_shard: Dict,
                started_event: mp.Event = None,
                model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
                provider_host: str = MODEL_PROVIDERS_CFG_HOST,
                provider_port: int = MODEL_PROVIDERS_CFG_PORT,
                log_path: str = "logs"
                ) -> None:
    logging_conf = get_config_dict(
        "DEBUG",
        get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"),
        122,
        111,
    )

    try:
        boot = (
            BootstrapWebBuilder()
            .model_providers_cfg_path(
                model_providers_cfg_path=model_providers_cfg_path
            )
            .host(host=provider_host)
            .port(port=provider_port)
            .build()
        )
        boot.set_app_event(started_event=started_event)

        provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
        model_platforms_shard['provider_platforms'] = provider_platforms

        boot.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await boot.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise


def init_provider_platforms(provider_manager: ProviderManager) -> List[Dict]:
    provider_list: List[ProviderResponse] = ProvidersWrapper(
        provider_manager=provider_manager).get_provider_list()
    logger.info(f"Provider list: {provider_list}")
    # Convert into MODEL_PLATFORMS entries
    provider_platforms = []
    for provider in provider_list:
        provider_dict = {
            "platform_name": provider.provider,
            "platform_type": provider.provider,
            "api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1",
            "api_key": "EMPTY",
            "api_concurrencies": 5
        }

        provider_dict["llm_models"] = []
        provider_dict["embed_models"] = []
        provider_dict["image_models"] = []
        provider_dict["multimodal_models"] = []
        supported_model_str_types = [model_type.to_origin_model_type() for model_type in
                                     provider.supported_model_types]

        for model_type in supported_model_str_types:

            providers_model_type = ProvidersWrapper(
                provider_manager=provider_manager
            ).get_models_by_model_type(model_type=model_type)
            cur_model_type: List[str] = []
            # Collect the models that belong to the current provider
            for provider_model in providers_model_type:
                if provider_model.provider == provider.provider:
                    models = [model.model for model in provider_model.models]
                    cur_model_type.extend(models)

            if cur_model_type:
                if model_type == "text-generation":
                    provider_dict["llm_models"] = cur_model_type
                elif model_type == "text-embedding":
                    provider_dict["embed_models"] = cur_model_type
                elif model_type == "text2img":
                    provider_dict["image_models"] = cur_model_type
                elif model_type == "multimodal":
                    provider_dict["multimodal_models"] = cur_model_type
                else:
                    logger.warning(f"Unsupported model type: {model_type}")

        provider_platforms.append(provider_dict)

    logger.info(f"Provider platforms: {provider_platforms}")

    return provider_platforms
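Editor's note: a minimal usage sketch of the module above, assuming the init_server signature shown. The real wiring lives in chatchat/startup.py (run_init_server / start_main_server, later in this diff); the main() helper here is only for illustration.

# Sketch: run init_server in a child process and read back the shared platform list.
import multiprocessing as mp
from multiprocessing import Process

from chatchat.model_loaders.init_server import init_server


def main():
    manager = mp.Manager()
    model_platforms_shard = manager.dict()   # shared dict filled by init_server
    started = manager.Event()

    p = Process(
        target=init_server,
        kwargs=dict(model_platforms_shard=model_platforms_shard,
                    started_event=started),
        daemon=True,
    )
    p.start()
    started.wait()  # block until the model_providers web service is up

    # The provider platforms can now be merged into MODEL_PLATFORMS,
    # as run_api_server / run_webui do in startup.py.
    print(model_platforms_shard["provider_platforms"])


if __name__ == "__main__":
    main()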
@ -27,9 +27,10 @@ from typing import (
|
||||
import logging
|
||||
|
||||
from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
|
||||
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE)
|
||||
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE,
|
||||
MODEL_PLATFORMS)
|
||||
from chatchat.server.pydantic_v2 import BaseModel, Field
|
||||
from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
|
||||
from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
|
||||
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
@@ -47,17 +48,18 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):


def get_config_platforms() -> Dict[str, Dict]:
    import importlib
    from chatchat.configs import model_config
    importlib.reload(model_config)
    # import importlib
    # reload is not supported
    # from chatchat.configs import model_config
    # importlib.reload(model_config)

    return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS}
    return {m["platform_name"]: m for m in MODEL_PLATFORMS}


def get_config_models(
        model_name: str = None,
        model_type: Literal["llm", "embed", "image", "multimodal"] = None,
        platform_name: str = None,
        model_name: str = None,
        model_type: Literal["llm", "embed", "image", "multimodal"] = None,
        platform_name: str = None,
) -> Dict[str, Dict]:
    '''
    Get the list of configured models; the return value is:
@@ -71,12 +73,13 @@ def get_config_models(
    "api_proxy": xx,
    }}
    '''
    import importlib
    from chatchat.configs import model_config
    importlib.reload(model_config)
    # import importlib
    # reload is not supported
    # from chatchat.configs import model_config
    # importlib.reload(model_config)

    result = {}
    for m in model_config.MODEL_PLATFORMS:
    for m in MODEL_PLATFORMS:
        if platform_name is not None and platform_name != m.get("platform_name"):
            continue
        if model_type is not None and f"{model_type}_models" not in m:
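Editor's note: with this change the two helpers read MODEL_PLATFORMS directly instead of reloading model_config. A short usage sketch, assuming only the signatures shown in this hunk; the exact keys of each returned entry come from the (partially shown) docstring, so treat them as assumptions.

# Illustrative usage of the helpers changed above.
from chatchat.server.utils import get_config_platforms, get_config_models

# All configured platforms, keyed by platform_name.
platforms = get_config_platforms()
print(list(platforms))  # e.g. ["xinference", "oneapi", "ollama", ...]

# All LLMs exposed by one platform; each value carries the platform's
# connection details (api_base_url, api_key, api_proxy, ...) per the docstring.
llms = get_config_models(model_type="llm", platform_name="xinference")
for name, info in llms.items():
    print(name, info.get("api_base_url"))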
@ -124,24 +127,24 @@ def get_ChatOpenAI(
|
||||
streaming: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
**kwargs: Any,
|
||||
) -> ChatOpenAI:
|
||||
model_info = get_model_info(model_name)
|
||||
params = dict(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
try:
|
||||
if local_wrap:
|
||||
params.update(
|
||||
openai_api_base = f"{api_address()}/v1",
|
||||
openai_api_key = "EMPTY",
|
||||
openai_api_base=f"{api_address()}/v1",
|
||||
openai_api_key="EMPTY",
|
||||
)
|
||||
else:
|
||||
params.update(
|
||||
@ -164,7 +167,7 @@ def get_OpenAI(
|
||||
echo: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
**kwargs: Any,
|
||||
) -> OpenAI:
|
||||
# TODO: 从API获取模型信息
|
||||
@ -182,8 +185,8 @@ def get_OpenAI(
|
||||
try:
|
||||
if local_wrap:
|
||||
params.update(
|
||||
openai_api_base = f"{api_address()}/v1",
|
||||
openai_api_key = "EMPTY",
|
||||
openai_api_base=f"{api_address()}/v1",
|
||||
openai_api_key="EMPTY",
|
||||
)
|
||||
else:
|
||||
params.update(
|
||||
@ -199,20 +202,20 @@ def get_OpenAI(
|
||||
|
||||
|
||||
def get_Embeddings(
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
) -> Embeddings:
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
|
||||
from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
|
||||
|
||||
model_info = get_model_info(model_name=embed_model)
|
||||
params = dict(model=embed_model)
|
||||
try:
|
||||
if local_wrap:
|
||||
params.update(
|
||||
openai_api_base = f"{api_address()}/v1",
|
||||
openai_api_key = "EMPTY",
|
||||
openai_api_base=f"{api_address()}/v1",
|
||||
openai_api_key="EMPTY",
|
||||
)
|
||||
else:
|
||||
params.update(
|
||||
@ -223,7 +226,7 @@ def get_Embeddings(
|
||||
if model_info.get("platform_type") == "openai":
|
||||
return OpenAIEmbeddings(**params)
|
||||
elif model_info.get("platform_type") == "ollama":
|
||||
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
|
||||
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''),
|
||||
model=embed_model,
|
||||
)
|
||||
else:
|
||||
@ -233,9 +236,9 @@ def get_Embeddings(
|
||||
|
||||
|
||||
def get_OpenAIClient(
|
||||
platform_name: str=None,
|
||||
model_name: str=None,
|
||||
is_async: bool=True,
|
||||
platform_name: str = None,
|
||||
model_name: str = None,
|
||||
is_async: bool = True,
|
||||
) -> Union[openai.Client, openai.AsyncClient]:
|
||||
'''
|
||||
construct an openai Client for specified platform or model
|
||||
@ -601,7 +604,7 @@ def run_in_process_pool(
|
||||
tasks = []
|
||||
max_workers = None
|
||||
if sys.platform.startswith("win"):
|
||||
max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows
|
||||
max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as pool:
|
||||
for kwargs in params:
|
||||
tasks.append(pool.submit(func, **kwargs))
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
from contextlib import asynccontextmanager
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
@ -6,6 +7,7 @@ import subprocess
|
||||
import sys
|
||||
from multiprocessing import Process
|
||||
|
||||
from chatchat.model_loaders.init_server import init_server
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
try:
|
||||
@ -23,7 +25,7 @@ from chatchat.configs import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
TEXT_SPLITTER_NAME,
|
||||
API_SERVER,
|
||||
WEBUI_SERVER,
|
||||
WEBUI_SERVER, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT
|
||||
)
|
||||
from chatchat.server.utils import FastAPI
|
||||
from chatchat.server.knowledge_base.migrate import create_tables
|
||||
@ -38,15 +40,34 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
|
||||
if started_event is not None:
|
||||
started_event.set()
|
||||
yield
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
|
||||
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
|
||||
def run_init_server(
|
||||
model_platforms_shard: Dict,
|
||||
started_event: mp.Event = None,
|
||||
run_mode: str = None,
|
||||
model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
|
||||
provider_host: str = MODEL_PROVIDERS_CFG_HOST,
|
||||
provider_port: int = MODEL_PROVIDERS_CFG_PORT):
|
||||
init_server(model_platforms_shard=model_platforms_shard,
|
||||
started_event=started_event,
|
||||
model_providers_cfg_path=model_providers_cfg_path,
|
||||
provider_host=provider_host,
|
||||
provider_port=provider_port)
|
||||
|
||||
|
||||
def run_api_server(model_platforms_shard: Dict,
|
||||
started_event: mp.Event = None,
|
||||
run_mode: str = None):
|
||||
from chatchat.server.api_server.server_app import create_app
|
||||
import uvicorn
|
||||
from chatchat.server.utils import set_httpx_config
|
||||
from chatchat.configs import MODEL_PLATFORMS
|
||||
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
|
||||
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
|
||||
set_httpx_config()
|
||||
|
||||
app = create_app(run_mode=run_mode)
|
||||
_set_app_event(app, started_event)
|
||||
|
||||
@ -56,48 +77,65 @@ def run_api_server(started_event: mp.Event = None, run_mode: str = None):
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_webui(started_event: mp.Event = None, run_mode: str = None):
|
||||
def run_webui(model_platforms_shard: Dict,
|
||||
started_event: mp.Event = None, run_mode: str = None):
|
||||
import sys
|
||||
from chatchat.server.utils import set_httpx_config
|
||||
|
||||
from chatchat.configs import MODEL_PLATFORMS
|
||||
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
|
||||
logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}")
|
||||
set_httpx_config()
|
||||
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
# 判断系统是否为Windows
|
||||
if sys.platform == "win32":
|
||||
st_exe = os.path.join(os.path.dirname(sys.executable), "Scripts", "streamlit")
|
||||
else:
|
||||
st_exe = os.path.join(os.path.dirname(sys.executable),"streamlit")
|
||||
|
||||
script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py')
|
||||
cmd = [st_exe, "run", script_dir,
|
||||
"--server.address", host,
|
||||
"--server.port", str(port),
|
||||
"--theme.base", "light",
|
||||
"--theme.primaryColor", "#165dff",
|
||||
"--theme.secondaryBackgroundColor", "#f5f5f5",
|
||||
"--theme.textColor", "#000000",
|
||||
]
|
||||
|
||||
flag_options = {'server_address': host,
|
||||
'server_port': port,
|
||||
'theme_base': 'light',
|
||||
'theme_primaryColor': '#165dff',
|
||||
'theme_secondaryBackgroundColor': '#f5f5f5',
|
||||
'theme_textColor': '#000000',
|
||||
'global_disableWatchdogWarning': None,
|
||||
'global_disableWidgetStateDuplicationWarning': None,
|
||||
'global_showWarningOnDirectExecution': None,
|
||||
'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None,
|
||||
'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None,
|
||||
'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None,
|
||||
'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None,
|
||||
'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None,
|
||||
'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None,
|
||||
'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None,
|
||||
'runner_postScriptGC': None, 'runner_fastReruns': None,
|
||||
'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None,
|
||||
'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None,
|
||||
'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None,
|
||||
'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None,
|
||||
'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None,
|
||||
'server_enableWebsocketCompression': None, 'server_enableStaticServing': None,
|
||||
'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None,
|
||||
'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None,
|
||||
'ui_hideSidebarNav': None, 'magic_displayRootDocString': None,
|
||||
'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None,
|
||||
'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None,
|
||||
'theme_backgroundColor': None, 'theme_font': None}
|
||||
|
||||
args = []
|
||||
if run_mode == "lite":
|
||||
cmd += [
|
||||
args += [
|
||||
"--",
|
||||
"lite",
|
||||
]
|
||||
p = subprocess.Popen(cmd)
|
||||
|
||||
try:
|
||||
# for streamlit >= 1.12.1
|
||||
from streamlit.web import bootstrap
|
||||
except ImportError:
|
||||
from streamlit import bootstrap
|
||||
|
||||
bootstrap.run(script_dir, False, args, flag_options)
|
||||
started_event.set()
|
||||
p.wait()
|
||||
|
||||
|
||||
def run_loom(started_event: mp.Event = None):
|
||||
from chatchat.configs import LOOM_CONFIG
|
||||
|
||||
cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
|
||||
"-f", LOOM_CONFIG
|
||||
]
|
||||
|
||||
p = subprocess.Popen(cmd)
|
||||
started_event.set()
|
||||
p.wait()
|
||||
|
||||
|
||||
def parse_args() -> argparse.ArgumentParser:
|
||||
@ -106,13 +144,13 @@ def parse_args() -> argparse.ArgumentParser:
|
||||
"-a",
|
||||
"--all-webui",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
|
||||
help="run model_providers servers,run api.py and webui.py",
|
||||
dest="all_webui",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all-api",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
|
||||
help="run model_providers servers, run api.py",
|
||||
dest="all_api",
|
||||
)
|
||||
|
||||
@ -156,11 +194,18 @@ def dump_server_info(after_start=False, args=None):
|
||||
|
||||
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
|
||||
|
||||
print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}")
|
||||
print(f"默认选用的 Embedding 名称: {DEFAULT_EMBEDDING_MODEL}")
|
||||
|
||||
if after_start:
|
||||
print("\n")
|
||||
print(f"服务端运行信息:")
|
||||
if args.api:
|
||||
print(
|
||||
f" Chatchat Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n"
|
||||
f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n"
|
||||
f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n")
|
||||
|
||||
print(f" Chatchat Api Server: {api_address()}")
|
||||
if args.webui:
|
||||
print(f" Chatchat WEBUI Server: {webui_address()}")
|
||||
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
|
||||
@ -193,21 +238,16 @@ async def start_main_server():
|
||||
args, parser = parse_args()
|
||||
|
||||
if args.all_webui:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.api_worker = True
|
||||
args.webui = True
|
||||
|
||||
elif args.all_api:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.api_worker = True
|
||||
args.webui = False
|
||||
|
||||
if args.lite:
|
||||
args.model_worker = False
|
||||
run_mode = "lite"
|
||||
|
||||
dump_server_info(args=args)
|
||||
@ -216,25 +256,29 @@ async def start_main_server():
|
||||
logger.info(f"正在启动服务:")
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
|
||||
processes = {"online_api": {}, "model_worker": {}}
|
||||
processes = {}
|
||||
|
||||
def process_count():
|
||||
return len(processes)
|
||||
|
||||
loom_started = manager.Event()
|
||||
# process = Process(
|
||||
# target=run_loom,
|
||||
# name=f"run_loom Server",
|
||||
# kwargs=dict(started_event=loom_started),
|
||||
# daemon=True,
|
||||
# )
|
||||
# processes["run_loom"] = process
|
||||
# 定义全局配置变量,使用 Manager 创建共享字典
|
||||
model_platforms_shard = manager.dict()
|
||||
model_providers_started = manager.Event()
|
||||
if args.api:
|
||||
process = Process(
|
||||
target=run_init_server,
|
||||
name=f"Model providers Server",
|
||||
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started,
|
||||
run_mode=run_mode),
|
||||
daemon=True,
|
||||
)
|
||||
processes["model_providers"] = process
|
||||
api_started = manager.Event()
|
||||
if args.api:
|
||||
process = Process(
|
||||
target=run_api_server,
|
||||
name=f"API Server",
|
||||
kwargs=dict(started_event=api_started, run_mode=run_mode),
|
||||
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
|
||||
daemon=True,
|
||||
)
|
||||
processes["api"] = process
|
||||
@ -244,7 +288,7 @@ async def start_main_server():
|
||||
process = Process(
|
||||
target=run_webui,
|
||||
name=f"WEBUI Server",
|
||||
kwargs=dict(started_event=webui_started, run_mode=run_mode),
|
||||
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode),
|
||||
daemon=True,
|
||||
)
|
||||
processes["webui"] = process
|
||||
@ -254,10 +298,10 @@ async def start_main_server():
|
||||
else:
|
||||
try:
|
||||
# 保证任务收到SIGINT后,能够正常退出
|
||||
if p := processes.get("run_loom"):
|
||||
if p := processes.get("model_providers"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
loom_started.wait() # 等待Loom启动完成
|
||||
model_providers_started.wait() # 等待model_providers启动完成
|
||||
|
||||
if p := processes.get("api"):
|
||||
p.start()
|
||||
@ -295,6 +339,8 @@ async def start_main_server():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 添加这行代码
|
||||
multiprocessing.freeze_support()
|
||||
create_tables()
|
||||
if sys.version_info < (3, 10):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@ -1,184 +0,0 @@
|
||||
from typing import Tuple, Any
|
||||
|
||||
import streamlit as st
|
||||
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
|
||||
|
||||
|
||||
def update_store():
|
||||
logger.info("update_status")
|
||||
st.session_state.status = client.status()
|
||||
logger.info("update_list_plugins")
|
||||
list_plugins = client.list_plugins()
|
||||
st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
|
||||
|
||||
logger.info("update_launch_subscribe_info")
|
||||
launch_subscribe_info = {}
|
||||
for plugin_name in st.session_state.run_plugins_list:
|
||||
launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)
|
||||
|
||||
st.session_state.launch_subscribe_info = launch_subscribe_info
|
||||
|
||||
logger.info("update_list_running_models")
|
||||
list_running_models = {}
|
||||
for plugin_name in st.session_state.run_plugins_list:
|
||||
list_running_models[plugin_name] = client.list_running_models(plugin_name)
|
||||
st.session_state.list_running_models = list_running_models
|
||||
|
||||
logger.info("update_model_config")
|
||||
model_config = {}
|
||||
for plugin_name in st.session_state.run_plugins_list:
|
||||
model_config[plugin_name] = client.list_llm_models(plugin_name)
|
||||
st.session_state.model_config = model_config
|
||||
|
||||
|
||||
def start_plugin():
|
||||
import time
|
||||
start_plugins_name = st.session_state.plugins_name
|
||||
if start_plugins_name in st.session_state.run_plugins_list:
|
||||
st.toast(start_plugins_name + " has already been counted.")
|
||||
|
||||
time.sleep(.5)
|
||||
else:
|
||||
|
||||
st.toast("start_plugin " + start_plugins_name + ",starting.")
|
||||
result = client.launch_subscribe(start_plugins_name)
|
||||
st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
|
||||
time.sleep(3)
|
||||
result1 = client.launch_subscribe_start(start_plugins_name)
|
||||
|
||||
st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
|
||||
time.sleep(2)
|
||||
update_store()
|
||||
|
||||
|
||||
def start_worker():
|
||||
select_plugins_name = st.session_state.plugins_name
|
||||
select_worker_id = st.session_state.worker_id
|
||||
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
|
||||
already_counted = False
|
||||
for model in start_model_list:
|
||||
if model['worker_id'] == select_worker_id:
|
||||
already_counted = True
|
||||
break
|
||||
|
||||
if already_counted:
|
||||
st.toast(
|
||||
"select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
|
||||
import time
|
||||
time.sleep(.5)
|
||||
else:
|
||||
|
||||
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
|
||||
result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
|
||||
st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
|
||||
import time
|
||||
time.sleep(.5)
|
||||
update_store()
|
||||
|
||||
|
||||
def stop_worker():
|
||||
select_plugins_name = st.session_state.plugins_name
|
||||
select_worker_id = st.session_state.worker_id
|
||||
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
|
||||
already_counted = False
|
||||
for model in start_model_list:
|
||||
if model['worker_id'] == select_worker_id:
|
||||
already_counted = True
|
||||
break
|
||||
|
||||
if not already_counted:
|
||||
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
|
||||
import time
|
||||
time.sleep(.5)
|
||||
else:
|
||||
|
||||
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
|
||||
result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
|
||||
st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
|
||||
import time
|
||||
time.sleep(.5)
|
||||
update_store()
|
||||
|
||||
|
||||
def build_providers_model_plugins_name():
|
||||
import streamlit_antd_components as sac
|
||||
if "run_plugins_list" not in st.session_state:
|
||||
return []
|
||||
# 按照模型构建sac.menu(菜单
|
||||
menu_items = []
|
||||
for key, value in st.session_state.list_running_models.items():
|
||||
menu_item_children = []
|
||||
for model in value:
|
||||
if "model" in model["providers"]:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
|
||||
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
|
||||
|
||||
return menu_items
|
||||
|
||||
|
||||
def build_providers_embedding_plugins_name():
|
||||
import streamlit_antd_components as sac
|
||||
if "run_plugins_list" not in st.session_state:
|
||||
return []
|
||||
# 按照模型构建sac.menu(菜单
|
||||
menu_items = []
|
||||
for key, value in st.session_state.list_running_models.items():
|
||||
menu_item_children = []
|
||||
for model in value:
|
||||
if "embedding" in model["providers"]:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
|
||||
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
|
||||
|
||||
return menu_items
|
||||
|
||||
|
||||
def find_menu_items_by_index(menu_items, key):
|
||||
for menu_item in menu_items:
|
||||
if menu_item.get('children') is not None:
|
||||
for child in menu_item.get('children'):
|
||||
if child.get('key') == key:
|
||||
return menu_item, child
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def set_llm_select(plugins_info, llm_model_worker):
|
||||
st.session_state["select_plugins_info"] = plugins_info
|
||||
st.session_state["select_model_worker"] = llm_model_worker
|
||||
|
||||
|
||||
def get_select_model_endpoint() -> Tuple[str, str]:
|
||||
plugins_info = st.session_state["select_plugins_info"]
|
||||
llm_model_worker = st.session_state["select_model_worker"]
|
||||
if plugins_info is None or llm_model_worker is None:
|
||||
raise ValueError("plugins_info or llm_model_worker is None")
|
||||
plugins_name = st.session_state["select_plugins_info"]['label']
|
||||
select_model_name = st.session_state["select_model_worker"]['label']
|
||||
adapter_description = st.session_state.launch_subscribe_info[plugins_name]
|
||||
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
|
||||
return endpoint_host, select_model_name
|
||||
|
||||
|
||||
def set_embed_select(plugins_info, embed_model_worker):
|
||||
st.session_state["select_embed_plugins_info"] = plugins_info
|
||||
st.session_state["select_embed_model_worker"] = embed_model_worker
|
||||
|
||||
|
||||
def get_select_embed_endpoint() -> Tuple[str, str]:
|
||||
select_embed_plugins_info = st.session_state["select_embed_plugins_info"]
|
||||
select_embed_model_worker = st.session_state["select_embed_model_worker"]
|
||||
if select_embed_plugins_info is None or select_embed_model_worker is None:
|
||||
raise ValueError("select_embed_plugins_info or select_embed_model_worker is None")
|
||||
embed_plugins_name = st.session_state["select_embed_plugins_info"]['label']
|
||||
select_embed_model_name = st.session_state["select_embed_model_worker"]['label']
|
||||
endpoint_host = None
|
||||
if embed_plugins_name in st.session_state.launch_subscribe_info:
|
||||
adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name]
|
||||
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
|
||||
return endpoint_host, select_embed_model_name
|
||||
@ -1 +0,0 @@
|
||||
from .base import openai_plugins_page
|
||||
@ -1,67 +0,0 @@
|
||||
import streamlit as st
|
||||
from loom_openai_plugins_frontend import loom_openai_plugins_frontend
|
||||
|
||||
from chatchat.webui_pages.utils import ApiRequest
|
||||
from chatchat.webui_pages.loom_view_client import (
|
||||
update_store,
|
||||
start_plugin,
|
||||
start_worker,
|
||||
stop_worker,
|
||||
)
|
||||
|
||||
|
||||
def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
|
||||
|
||||
|
||||
with (st.container()):
|
||||
|
||||
if "worker_id" not in st.session_state:
|
||||
st.session_state.worker_id = ''
|
||||
if "plugins_name" not in st.session_state and "status" in st.session_state:
|
||||
|
||||
for k, v in st.session_state.status.get("status", {}).get("subscribe_status", []).items():
|
||||
st.session_state.plugins_name = v.get("plugins_names", [])[0]
|
||||
break
|
||||
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
|
||||
with col1:
|
||||
event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
|
||||
run_list_plugins=st.session_state.run_plugins_list,
|
||||
launch_subscribe_info=st.session_state.launch_subscribe_info,
|
||||
list_running_models=st.session_state.list_running_models,
|
||||
model_config=st.session_state.model_config)
|
||||
|
||||
with col2:
|
||||
st.write("操作")
|
||||
if not st.session_state.run_plugins_list:
|
||||
button_type_disabled = False
|
||||
button_start_text = '启动'
|
||||
else:
|
||||
button_type_disabled = True
|
||||
button_start_text = '已启动'
|
||||
|
||||
if event:
|
||||
event_type = event.get("event")
|
||||
if event_type == "BottomNavigationAction":
|
||||
st.session_state.plugins_name = event.get("data")
|
||||
st.session_state.worker_id = ''
|
||||
# 不存在run_plugins_list,打开启动按钮
|
||||
if st.session_state.plugins_name not in st.session_state.run_plugins_list \
|
||||
or st.session_state.run_plugins_list:
|
||||
button_type_disabled = False
|
||||
button_start_text = '启动'
|
||||
else:
|
||||
button_type_disabled = True
|
||||
button_start_text = '已启动'
|
||||
if event_type == "CardCoverComponent":
|
||||
st.session_state.worker_id = event.get("data")
|
||||
|
||||
st.button(button_start_text, disabled=button_type_disabled, key="start",
|
||||
on_click=start_plugin)
|
||||
|
||||
if st.session_state.worker_id != '':
|
||||
st.button("启动" + st.session_state.worker_id, key="start_worker",
|
||||
on_click=start_worker)
|
||||
st.button("停止" + st.session_state.worker_id, key="stop_worker",
|
||||
on_click=stop_worker)
|
||||
@@ -216,7 +216,7 @@ build-backend = "poetry.core.masonry.api"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
6  chatchat-server/tests/unit_server/test_init_server.py  Normal file
@@ -0,0 +1,6 @@
from chatchat.model_loaders.init_server import init_server


def test_init_server():

    init_server()
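Editor's note: as written, this smoke test calls init_server() without the required model_platforms_shard argument (the signature added in init_server.py above has no default for it), so it would raise a TypeError before exercising anything. A hedged sketch of what the test presumably needs; the timeout and teardown below are assumptions, not part of the commit.

# Sketch only: init_server() blocks serving the model_providers web app,
# so run it in a daemon process and wait on the started event.
import multiprocessing as mp
from multiprocessing import Process

from chatchat.model_loaders.init_server import init_server


def test_init_server():
    manager = mp.Manager()
    shard = manager.dict()
    started = manager.Event()
    p = Process(target=init_server,
                kwargs=dict(model_platforms_shard=shard, started_event=started),
                daemon=True)
    p.start()
    assert started.wait(timeout=60)
    assert "provider_platforms" in shard
    p.terminate()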
47  model-providers/model_providers/__main__.py  Normal file
@@ -0,0 +1,47 @@
import argparse
import asyncio
import logging

from model_providers import BootstrapWebBuilder
from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms

logger = logging.getLogger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-providers",
        type=str,
        default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
        help="run model_providers servers",
        dest="model_providers",
    )
    args = parser.parse_args()
    try:
        logging_conf = get_config_dict(
            "DEBUG",
            get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
            122,
            111,
        )
        boot = (
            BootstrapWebBuilder()
            .model_providers_cfg_path(
                model_providers_cfg_path=args.model_providers
            )
            .host(host="127.0.0.1")
            .port(port=20000)
            .build()
        )
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)


        async def pool_join_thread():
            await boot.join()


        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
@ -45,6 +45,7 @@ from model_providers.core.bootstrap.openai_protocol import (
|
||||
Role,
|
||||
UsageInfo,
|
||||
)
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
@ -111,7 +112,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
|
||||
|
||||
|
||||
def _create_template_from_message_type(
|
||||
message_type: str, template: Union[str, list]
|
||||
message_type: str, template: Union[str, list]
|
||||
) -> PromptMessage:
|
||||
"""Create a message prompt template from a message type and template string.
|
||||
|
||||
@ -170,7 +171,7 @@ def _create_template_from_message_type(
|
||||
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
message: MessageLikeRepresentation,
|
||||
) -> Union[PromptMessage]:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
@ -212,7 +213,7 @@ def _convert_to_message(
|
||||
|
||||
|
||||
async def _stream_openai_chat_completion(
|
||||
response: Generator,
|
||||
response: Generator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
request_id, model = None, None
|
||||
for chunk in response:
|
||||
@ -362,11 +363,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
started_event.set()
|
||||
|
||||
async def workspaces_model_providers(self, request: Request):
|
||||
provider_list = self.get_provider_list(model_type=request.get("model_type"))
|
||||
|
||||
provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
|
||||
model_type=request.get("model_type"))
|
||||
return ProviderListResponse(data=provider_list)
|
||||
|
||||
async def workspaces_model_types(self, model_type: str, request: Request):
|
||||
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
|
||||
return ProviderModelTypeResponse(data=models_by_model_type)
|
||||
|
||||
async def list_models(self, provider: str, request: Request):
|
||||
@ -399,7 +403,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
return ModelList(data=models_list)
|
||||
|
||||
async def create_embeddings(
|
||||
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||
):
|
||||
logger.info(
|
||||
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
||||
@ -409,7 +413,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
return EmbeddingsResponse(**dictify(response))
|
||||
|
||||
async def create_chat_completion(
|
||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||
):
|
||||
logger.info(
|
||||
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
|
||||
@ -469,9 +473,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
|
||||
|
||||
def run(
|
||||
cfg: Dict,
|
||||
logging_conf: Optional[dict] = None,
|
||||
started_event: mp.Event = None,
|
||||
cfg: Dict,
|
||||
logging_conf: Optional[dict] = None,
|
||||
started_event: mp.Event = None,
|
||||
):
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
try:
|
||||
|
||||
@ -4,22 +4,11 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
CustomConfigurationStatus,
|
||||
ModelResponse,
|
||||
ProviderResponse,
|
||||
ProviderWithModelsResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from model_providers.core.bootstrap.openai_protocol import (
|
||||
ChatCompletionRequest,
|
||||
EmbeddingsRequest,
|
||||
)
|
||||
from model_providers.core.entities.model_entities import ModelStatus
|
||||
from model_providers.core.entities.provider_entities import ProviderType
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class Bootstrap:
|
||||
@ -43,137 +32,6 @@ class Bootstrap:
|
||||
def provider_manager(self, provider_manager: ModelManager):
|
||||
self._provider_manager = provider_manager
|
||||
|
||||
def get_provider_list(
|
||||
self, model_type: Optional[str] = None
|
||||
) -> List[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
provider_responses = []
|
||||
for provider_configuration in provider_configurations.values():
|
||||
if model_type:
|
||||
model_type_entity = ModelType.value_of(model_type)
|
||||
if (
|
||||
model_type_entity
|
||||
not in provider_configuration.provider.supported_model_types
|
||||
):
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
icon_small=provider_configuration.provider.icon_small,
|
||||
icon_large=provider_configuration.provider.icon_large,
|
||||
background=provider_configuration.provider.background,
|
||||
help=provider_configuration.provider.help,
|
||||
supported_model_types=provider_configuration.provider.supported_model_types,
|
||||
configurate_methods=provider_configuration.provider.configurate_methods,
|
||||
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
|
||||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=ProviderType.value_of("custom"),
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(enabled=False),
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
|
||||
return provider_responses
|
||||
|
||||
def get_models_by_model_type(
|
||||
self, model_type: str
|
||||
) -> List[ProviderWithModelsResponse]:
|
||||
"""
|
||||
get models by model type.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
for model in models:
|
||||
if model.provider.provider not in provider_models:
|
||||
provider_models[model.provider.provider] = []
|
||||
|
||||
if model.deprecated:
|
||||
continue
|
||||
|
||||
provider_models[model.provider.provider].append(model)
|
||||
|
||||
# convert to ProviderWithModelsResponse list
|
||||
providers_with_models: list[ProviderWithModelsResponse] = []
|
||||
for provider, models in provider_models.items():
|
||||
if not models:
|
||||
continue
|
||||
|
||||
first_model = models[0]
|
||||
|
||||
has_active_models = any(
|
||||
[model.status == ModelStatus.ACTIVE for model in models]
|
||||
)
|
||||
|
||||
providers_with_models.append(
|
||||
ProviderWithModelsResponse(
|
||||
provider=provider,
|
||||
label=first_model.provider.label,
|
||||
icon_small=first_model.provider.icon_small,
|
||||
icon_large=first_model.provider.icon_large,
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if has_active_models
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
models=[
|
||||
ModelResponse(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
)
|
||||
for model in models
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return providers_with_models
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, cfg=None):
|
||||
|
||||
@ -0,0 +1,153 @@
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
CustomConfigurationStatus,
|
||||
ModelResponse,
|
||||
ProviderResponse,
|
||||
ProviderWithModelsResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
|
||||
from model_providers.core.entities.model_entities import ModelStatus
|
||||
from model_providers.core.entities.provider_entities import ProviderType
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
class ProvidersWrapper:
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
def get_provider_list(
|
||||
self, model_type: Optional[str] = None
|
||||
) -> List[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
provider_responses = []
|
||||
for provider_configuration in provider_configurations.values():
|
||||
if model_type:
|
||||
model_type_entity = ModelType.value_of(model_type)
|
||||
if (
|
||||
model_type_entity
|
||||
not in provider_configuration.provider.supported_model_types
|
||||
):
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
icon_small=provider_configuration.provider.icon_small,
|
||||
icon_large=provider_configuration.provider.icon_large,
|
||||
background=provider_configuration.provider.background,
|
||||
help=provider_configuration.provider.help,
|
||||
supported_model_types=provider_configuration.provider.supported_model_types,
|
||||
configurate_methods=provider_configuration.provider.configurate_methods,
|
||||
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
|
||||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=ProviderType.value_of("custom"),
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(enabled=False),
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
|
||||
return provider_responses
|
||||
|
||||
def get_models_by_model_type(
|
||||
self, model_type: str
|
||||
) -> List[ProviderWithModelsResponse]:
|
||||
"""
|
||||
get models by model type.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
for model in models:
|
||||
if model.provider.provider not in provider_models:
|
||||
provider_models[model.provider.provider] = []
|
||||
|
||||
if model.deprecated:
|
||||
continue
|
||||
|
||||
provider_models[model.provider.provider].append(model)
|
||||
|
||||
# convert to ProviderWithModelsResponse list
|
||||
providers_with_models: list[ProviderWithModelsResponse] = []
|
||||
for provider, models in provider_models.items():
|
||||
if not models:
|
||||
continue
|
||||
|
||||
first_model = models[0]
|
||||
|
||||
has_active_models = any(
|
||||
[model.status == ModelStatus.ACTIVE for model in models]
|
||||
)
|
||||
|
||||
providers_with_models.append(
|
||||
ProviderWithModelsResponse(
|
||||
provider=provider,
|
||||
label=first_model.provider.label,
|
||||
icon_small=first_model.provider.icon_small,
|
||||
icon_large=first_model.provider.icon_large,
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if has_active_models
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
models=[
|
||||
ModelResponse(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
)
|
||||
for model in models
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return providers_with_models
|
||||