Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-06 23:15:53 +08:00

Support model_providers: integrate platform configuration initialization into the WebUI and API (#3625)

* provider_configuration init of MODEL_PLATFORMS
* Developer manual
* Support model_providers: integrate platform configuration initialization into the WebUI and API

This commit is contained in:
parent c0634828a4
commit b3dee0b1d1
.gitignore (vendored) | 3

@@ -183,4 +183,5 @@ configs/*.py
 /knowledge_base/samples/content/imi_temeplate.txt
 /chatchat/configs/*.py
 /chatchat/configs/*.yaml
 chatchat/data
+/chatchat-server/chatchat/configs/model_providers.yaml
@@ -1,26 +0,0 @@
-log_path: "logs"
-log_level: "DEBUG"
-
-api_server:
-  host: "127.0.0.1"
-  port: 8000
-
-publish_server:
-  host: "127.0.0.1"
-  port: 8001
-
-subscribe_server:
-  host: "127.0.0.1"
-  port: 8002
-
-openai_plugins_folder:
-  - "openai_plugins"
-openai_plugins_load_folder:
-  - "configs"
-
-
-plugins:
-  - openai:
-      name: "openai"
-  - zhipuai:
-      name: "zhipuai"
@@ -1,6 +1,5 @@
 import os
-
 
 # 默认选用的 LLM 名称
 DEFAULT_LLM_MODEL = "chatglm3-6b"
 

@@ -31,7 +30,7 @@ SUPPORT_AGENT_MODELS = [
 
 
 LLM_MODEL_CONFIG = {
     # 意图识别不需要输出,模型后台知道就行
     "preprocess_model": {
         DEFAULT_LLM_MODEL: {
             "temperature": 0.05,

@@ -57,7 +56,7 @@ LLM_MODEL_CONFIG = {
             "prompt_name": "ChatGLM3",
             "callbacks": True
         },
     },
     "postprocess_model": {
         DEFAULT_LLM_MODEL: {
             "temperature": 0.01,

@@ -76,47 +75,15 @@ LLM_MODEL_CONFIG = {
         },
     }
 
-# 可以通过 loom/xinference/oneapi/fastchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
+# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台
+# ### 如果您已经有了一个openai endpoint的能力的地址,可以在这里直接配置
 # - platform_name 可以任意填写,不要重复即可
-# - platform_type 可选:openai, xinference, oneapi, fastchat。以后可能根据平台类型做一些功能区分
+# - platform_type 以后可能根据平台类型做一些功能区分,与platform_name一致即可
 # - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。
 
-MODEL_PLATFORMS = [
-    # {
-    #     "platform_name": "openai-api",
-    #     "platform_type": "openai",
-    #     "api_base_url": "https://api.openai.com/v1",
-    #     "api_key": "sk-",
-    #     "api_proxy": "",
-    #     "api_concurrencies": 5,
-    #     "llm_models": [
-    #         "gpt-3.5-turbo",
-    #     ],
-    #     "embed_models": [],
-    #     "image_models": [],
-    #     "multimodal_models": [],
-    # },
-
-    {
-        "platform_name": "xinference",
-        "platform_type": "xinference",
-        "api_base_url": "http://127.0.0.1:9997/v1",
-        "api_key": "EMPTY",
-        "api_concurrencies": 5,
-        # 注意:这里填写的是 xinference 部署的模型 UID,而非模型名称
-        "llm_models": [
-            "chatglm3-6b",
-        ],
-        "embed_models": [
-            "bge-large-zh-v1.5",
-        ],
-        "image_models": [
-            "sd-turbo",
-        ],
-        "multimodal_models": [
-            "qwen-vl",
-        ],
-    },
-
+# 创建一个全局的共享字典
+MODEL_PLATFORMS = [
     {
         "platform_name": "oneapi",

@@ -152,41 +119,13 @@ MODEL_PLATFORMS = [
         "multimodal_models": [],
     },
 
-    {
-        "platform_name": "ollama",
-        "platform_type": "ollama",
-        "api_base_url": "http://{host}:{port}/v1",
-        "api_key": "sk-",
-        "api_concurrencies": 5,
-        "llm_models": [
-            # Qwen API,其它更多模型请参考https://ollama.com/library
-            "qwen:7b",
-        ],
-        "embed_models": [
-            # 必须升级ollama到0.1.29以上,低版本向量服务有问题
-            "nomic-embed-text"
-        ],
-        "image_models": [],
-        "multimodal_models": [],
-    },
-
-    # {
-    #     "platform_name": "loom",
-    #     "platform_type": "loom",
-    #     "api_base_url": "http://127.0.0.1:7860/v1",
-    #     "api_key": "",
-    #     "api_concurrencies": 5,
-    #     "llm_models": [
-    #         "chatglm3-6b",
-    #     ],
-    #     "embed_models": [],
-    #     "image_models": [],
-    #     "multimodal_models": [],
-    # },
 ]
 
-LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
+MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml")
+MODEL_PROVIDERS_CFG_HOST = "127.0.0.1"
 
+MODEL_PROVIDERS_CFG_PORT = 20000
 # 工具配置项
 TOOL_CONFIG = {
     "search_local_knowledgebase": {
@@ -0,0 +1,29 @@
+openai:
+  model_credential:
+    - model: 'gpt-3.5-turbo'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_key: 'sk-'
+        openai_organization: ''
+        openai_api_base: ''
+    - model: 'gpt-4'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_key: 'sk-'
+        openai_organization: ''
+        openai_api_base: ''
+
+  provider_credential:
+    openai_api_key: 'sk-'
+    openai_organization: ''
+    openai_api_base: ''
+
+xinference:
+  model_credential:
+    - model: 'chatglm3-6b'
+      model_type: 'llm'
+      model_credentials:
+        server_url: 'http://127.0.0.1:9997/'
+        model_uid: 'chatglm3-6b'
+
+
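For reference, a minimal sketch (not part of this commit; the helper name and the PyYAML dependency are assumptions) of reading a model_providers-style YAML file and listing the models configured per provider:

    # Hypothetical helper: inspect a model_providers.yaml like the one added above.
    import yaml  # assumes PyYAML is installed

    def list_configured_models(cfg_path: str) -> dict:
        """Return {provider: [model names]} from a model_providers-style YAML file."""
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        result = {}
        for provider, section in cfg.items():
            creds = (section or {}).get("model_credential", []) or []
            result[provider] = [entry.get("model") for entry in creds]
        return result

    # Example: list_configured_models("model_providers.yaml")
    # -> {"openai": ["gpt-3.5-turbo", "gpt-4"], "xinference": ["chatglm3-6b"]}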
@@ -1,6 +0,0 @@
-{
-  "openai_plugins": [
-    "imitater", "openai"
-  ]
-
-}
chatchat-server/chatchat/model_loaders/init_server.py | 109 (new file)

@@ -0,0 +1,109 @@
+from typing import List, Dict
+from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG
+from model_providers import BootstrapWebBuilder
+from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse
+from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
+from model_providers.core.provider_manager import ProviderManager
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
+import multiprocessing as mp
+import asyncio
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def init_server(model_platforms_shard: Dict,
+                started_event: mp.Event = None,
+                model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
+                provider_host: str = MODEL_PROVIDERS_CFG_HOST,
+                provider_port: int = MODEL_PROVIDERS_CFG_PORT,
+                log_path: str = "logs"
+                ) -> None:
+    logging_conf = get_config_dict(
+        "DEBUG",
+        get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"),
+        122,
+        111,
+    )
+
+    try:
+        boot = (
+            BootstrapWebBuilder()
+            .model_providers_cfg_path(
+                model_providers_cfg_path=model_providers_cfg_path
+            )
+            .host(host=provider_host)
+            .port(port=provider_port)
+            .build()
+        )
+        boot.set_app_event(started_event=started_event)
+
+        provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
+        model_platforms_shard['provider_platforms'] = provider_platforms
+
+        boot.serve(logging_conf=logging_conf)
+
+        async def pool_join_thread():
+            await boot.join()
+
+        asyncio.run(pool_join_thread())
+    except SystemExit:
+        logger.info("SystemExit raised, exiting")
+        raise
+
+
+def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
+    provider_list: List[ProviderResponse] = ProvidersWrapper(
+        provider_manager=provider_manager).get_provider_list()
+    logger.info(f"Provider list: {provider_list}")
+    # 转换MODEL_PLATFORMS
+    provider_platforms = []
+    for provider in provider_list:
+        provider_dict = {
+            "platform_name": provider.provider,
+            "platform_type": provider.provider,
+            "api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1",
+            "api_key": "EMPTY",
+            "api_concurrencies": 5
+        }
+
+        provider_dict["llm_models"] = []
+        provider_dict["embed_models"] = []
+        provider_dict["image_models"] = []
+        provider_dict["multimodal_models"] = []
+        supported_model_str_types = [model_type.to_origin_model_type() for model_type in
+                                     provider.supported_model_types]
+
+        for model_type in supported_model_str_types:
+
+            providers_model_type = ProvidersWrapper(
+                provider_manager=provider_manager
+            ).get_models_by_model_type(model_type=model_type)
+            cur_model_type: List[str] = []
+            # 查询当前provider的模型
+            for provider_model in providers_model_type:
+                if provider_model.provider == provider.provider:
+                    models = [model.model for model in provider_model.models]
+                    cur_model_type.extend(models)
+
+            if cur_model_type:
+                if model_type == "text-generation":
+                    provider_dict["llm_models"] = cur_model_type
+                elif model_type == "text-embedding":
+                    provider_dict["embed_models"] = cur_model_type
+                elif model_type == "text2img":
+                    provider_dict["image_models"] = cur_model_type
+                elif model_type == "multimodal":
+                    provider_dict["multimodal_models"] = cur_model_type
+                else:
+                    logger.warning(f"Unsupported model type: {model_type}")
+
+        provider_platforms.append(provider_dict)
+
+    logger.info(f"Provider platforms: {provider_platforms}")
+
+    return provider_platforms
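As a usage reference, a minimal sketch of launching init_server in its own process and reading back the discovered platforms; it mirrors what startup.py does later in this commit, using only names defined by init_server itself:

    # Sketch only: run init_server in a child process and collect the platforms it publishes.
    import multiprocessing as mp
    from chatchat.model_loaders.init_server import init_server

    if __name__ == "__main__":
        manager = mp.Manager()
        model_platforms_shard = manager.dict()   # shared dict filled by init_server
        started = manager.Event()

        p = mp.Process(
            target=init_server,
            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=started),
            daemon=True,
        )
        p.start()
        started.wait()  # set once the provider web app is up
        print(model_platforms_shard["provider_platforms"])  # MODEL_PLATFORMS-style entries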
@@ -27,9 +27,10 @@ from typing import (
 import logging
 
 from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
-                              DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE)
+                              DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE,
+                              MODEL_PLATFORMS)
 from chatchat.server.pydantic_v2 import BaseModel, Field
 from chatchat.server.minx_chat_openai import MinxChatOpenAI  # TODO: still used?
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):

@@ -47,17 +48,18 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 
 
 def get_config_platforms() -> Dict[str, Dict]:
-    import importlib
-    from chatchat.configs import model_config
-    importlib.reload(model_config)
+    # import importlib
+    # 不能支持重载
+    # from chatchat.configs import model_config
+    # importlib.reload(model_config)
 
-    return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS}
+    return {m["platform_name"]: m for m in MODEL_PLATFORMS}
 
 
 def get_config_models(
         model_name: str = None,
         model_type: Literal["llm", "embed", "image", "multimodal"] = None,
         platform_name: str = None,
 ) -> Dict[str, Dict]:
     '''
     获取配置的模型列表,返回值为:

@@ -71,12 +73,13 @@ def get_config_models(
         "api_proxy": xx,
     }}
     '''
-    import importlib
-    from chatchat.configs import model_config
-    importlib.reload(model_config)
+    # import importlib
+    # 不能支持重载
+    # from chatchat.configs import model_config
+    # importlib.reload(model_config)
 
     result = {}
-    for m in model_config.MODEL_PLATFORMS:
+    for m in MODEL_PLATFORMS:
         if platform_name is not None and platform_name != m.get("platform_name"):
             continue
         if model_type is not None and f"{model_type}_models" not in m:
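For illustration, a short sketch of how the reworked helpers are consumed after this change; the module path chatchat.server.utils is an assumption based on the imports shown elsewhere in this commit:

    # After the change, both helpers read the in-process MODEL_PLATFORMS list directly.
    from chatchat.server.utils import get_config_platforms, get_config_models

    platforms = get_config_platforms()          # {"oneapi": {...}, "<provider>": {...}, ...}
    llms = get_config_models(model_type="llm")  # model name -> its platform config
    print(list(platforms), list(llms))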
@@ -124,24 +127,24 @@ def get_ChatOpenAI(
         streaming: bool = True,
         callbacks: List[Callable] = [],
         verbose: bool = True,
         local_wrap: bool = False,  # use local wrapped api
         **kwargs: Any,
 ) -> ChatOpenAI:
     model_info = get_model_info(model_name)
     params = dict(
         streaming=streaming,
         verbose=verbose,
         callbacks=callbacks,
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
         **kwargs
     )
     try:
         if local_wrap:
             params.update(
-                openai_api_base = f"{api_address()}/v1",
-                openai_api_key = "EMPTY",
+                openai_api_base=f"{api_address()}/v1",
+                openai_api_key="EMPTY",
             )
         else:
             params.update(

@@ -164,7 +167,7 @@ def get_OpenAI(
         echo: bool = True,
         callbacks: List[Callable] = [],
         verbose: bool = True,
         local_wrap: bool = False,  # use local wrapped api
         **kwargs: Any,
 ) -> OpenAI:
     # TODO: 从API获取模型信息

@@ -182,8 +185,8 @@ def get_OpenAI(
     try:
         if local_wrap:
             params.update(
-                openai_api_base = f"{api_address()}/v1",
-                openai_api_key = "EMPTY",
+                openai_api_base=f"{api_address()}/v1",
+                openai_api_key="EMPTY",
             )
         else:
             params.update(

@@ -199,20 +202,20 @@ def get_OpenAI(
 
 
 def get_Embeddings(
         embed_model: str = DEFAULT_EMBEDDING_MODEL,
         local_wrap: bool = False,  # use local wrapped api
 ) -> Embeddings:
     from langchain_community.embeddings.openai import OpenAIEmbeddings
     from langchain_community.embeddings import OllamaEmbeddings
     from chatchat.server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154
 
     model_info = get_model_info(model_name=embed_model)
     params = dict(model=embed_model)
     try:
         if local_wrap:
             params.update(
-                openai_api_base = f"{api_address()}/v1",
-                openai_api_key = "EMPTY",
+                openai_api_base=f"{api_address()}/v1",
+                openai_api_key="EMPTY",
             )
         else:
             params.update(

@@ -223,7 +226,7 @@ def get_Embeddings(
         if model_info.get("platform_type") == "openai":
             return OpenAIEmbeddings(**params)
         elif model_info.get("platform_type") == "ollama":
-            return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
+            return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''),
                                     model=embed_model,
                                     )
         else:

@@ -233,9 +236,9 @@ def get_Embeddings(
 
 
 def get_OpenAIClient(
-        platform_name: str=None,
-        model_name: str=None,
-        is_async: bool=True,
+        platform_name: str = None,
+        model_name: str = None,
+        is_async: bool = True,
 ) -> Union[openai.Client, openai.AsyncClient]:
     '''
     construct an openai Client for specified platform or model

@@ -601,7 +604,7 @@ def run_in_process_pool(
     tasks = []
     max_workers = None
     if sys.platform.startswith("win"):
         max_workers = min(mp.cpu_count(), 60)  # max_workers should not exceed 60 on windows
     with ProcessPoolExecutor(max_workers=max_workers) as pool:
         for kwargs in params:
             tasks.append(pool.submit(func, **kwargs))
@@ -1,4 +1,5 @@
 import asyncio
+import multiprocessing
 from contextlib import asynccontextmanager
 import multiprocessing as mp
 import os

@@ -6,6 +7,7 @@ import subprocess
 import sys
 from multiprocessing import Process
 
+from chatchat.model_loaders.init_server import init_server
 
 # 设置numexpr最大线程数,默认为CPU核心数
 try:

@@ -23,7 +25,7 @@ from chatchat.configs import (
     DEFAULT_EMBEDDING_MODEL,
     TEXT_SPLITTER_NAME,
     API_SERVER,
-    WEBUI_SERVER,
+    WEBUI_SERVER, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT
 )
 from chatchat.server.utils import FastAPI
 from chatchat.server.knowledge_base.migrate import create_tables

@@ -38,15 +40,34 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
         if started_event is not None:
             started_event.set()
         yield
 
     app.router.lifespan_context = lifespan
 
 
-def run_api_server(started_event: mp.Event = None, run_mode: str = None):
+def run_init_server(
+        model_platforms_shard: Dict,
+        started_event: mp.Event = None,
+        run_mode: str = None,
+        model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
+        provider_host: str = MODEL_PROVIDERS_CFG_HOST,
+        provider_port: int = MODEL_PROVIDERS_CFG_PORT):
+    init_server(model_platforms_shard=model_platforms_shard,
+                started_event=started_event,
+                model_providers_cfg_path=model_providers_cfg_path,
+                provider_host=provider_host,
+                provider_port=provider_port)
+
+
+def run_api_server(model_platforms_shard: Dict,
+                   started_event: mp.Event = None,
+                   run_mode: str = None):
     from chatchat.server.api_server.server_app import create_app
     import uvicorn
     from chatchat.server.utils import set_httpx_config
+    from chatchat.configs import MODEL_PLATFORMS
+    MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
+    logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
     set_httpx_config()
 
     app = create_app(run_mode=run_mode)
     _set_app_event(app, started_event)
 

@@ -56,48 +77,65 @@ def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     uvicorn.run(app, host=host, port=port)
 
 
-def run_webui(started_event: mp.Event = None, run_mode: str = None):
+def run_webui(model_platforms_shard: Dict,
+              started_event: mp.Event = None, run_mode: str = None):
     import sys
     from chatchat.server.utils import set_httpx_config
+    from chatchat.configs import MODEL_PLATFORMS
+    MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
+    logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}")
     set_httpx_config()
 
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
-    # 判断系统是否为Windows
-    if sys.platform == "win32":
-        st_exe = os.path.join(os.path.dirname(sys.executable), "Scripts", "streamlit")
-    else:
-        st_exe = os.path.join(os.path.dirname(sys.executable),"streamlit")
     script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py')
-    cmd = [st_exe, "run", script_dir,
-           "--server.address", host,
-           "--server.port", str(port),
-           "--theme.base", "light",
-           "--theme.primaryColor", "#165dff",
-           "--theme.secondaryBackgroundColor", "#f5f5f5",
-           "--theme.textColor", "#000000",
-           ]
+    flag_options = {'server_address': host,
+                    'server_port': port,
+                    'theme_base': 'light',
+                    'theme_primaryColor': '#165dff',
+                    'theme_secondaryBackgroundColor': '#f5f5f5',
+                    'theme_textColor': '#000000',
+                    'global_disableWatchdogWarning': None,
+                    'global_disableWidgetStateDuplicationWarning': None,
+                    'global_showWarningOnDirectExecution': None,
+                    'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None,
+                    'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None,
+                    'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None,
+                    'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None,
+                    'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None,
+                    'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None,
+                    'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None,
+                    'runner_postScriptGC': None, 'runner_fastReruns': None,
+                    'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None,
+                    'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None,
+                    'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None,
+                    'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None,
+                    'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None,
+                    'server_enableWebsocketCompression': None, 'server_enableStaticServing': None,
+                    'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None,
+                    'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None,
+                    'ui_hideSidebarNav': None, 'magic_displayRootDocString': None,
+                    'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None,
+                    'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None,
+                    'theme_backgroundColor': None, 'theme_font': None}
+
+    args = []
     if run_mode == "lite":
-        cmd += [
+        args += [
             "--",
             "lite",
         ]
-    p = subprocess.Popen(cmd)
+    try:
+        # for streamlit >= 1.12.1
+        from streamlit.web import bootstrap
+    except ImportError:
+        from streamlit import bootstrap
+
+    bootstrap.run(script_dir, False, args, flag_options)
     started_event.set()
-    p.wait()
-
-
-def run_loom(started_event: mp.Event = None):
-    from chatchat.configs import LOOM_CONFIG
-
-    cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
-           "-f", LOOM_CONFIG
-           ]
-
-    p = subprocess.Popen(cmd)
-    started_event.set()
-    p.wait()
 
 
 def parse_args() -> argparse.ArgumentParser:
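As a side note, the commit replaces the external "streamlit run" subprocess with an in-process launch. A minimal sketch of that pattern, using only the call shape shown above (the flag values are illustrative):

    # Sketch: launch a Streamlit script in-process instead of via subprocess.Popen.
    try:
        # streamlit >= 1.12.1
        from streamlit.web import bootstrap
    except ImportError:
        from streamlit import bootstrap

    flag_options = {"server_address": "127.0.0.1", "server_port": 8501, "theme_base": "light"}
    # Same positional arguments as the call in run_webui above: script path, False, extra args, flags.
    bootstrap.run("webui.py", False, [], flag_options)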
@@ -106,13 +144,13 @@ def parse_args() -> argparse.ArgumentParser:
         "-a",
         "--all-webui",
         action="store_true",
-        help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
+        help="run model_providers servers,run api.py and webui.py",
         dest="all_webui",
     )
     parser.add_argument(
         "--all-api",
         action="store_true",
-        help="run fastchat's controller/openai_api/model_worker servers, run api.py",
+        help="run model_providers servers, run api.py",
         dest="all_api",
     )
 

@@ -156,11 +194,18 @@ def dump_server_info(after_start=False, args=None):
 
     print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
 
-    print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}")
+    print(f"默认选用的 Embedding 名称: {DEFAULT_EMBEDDING_MODEL}")
 
     if after_start:
         print("\n")
         print(f"服务端运行信息:")
+        if args.api:
+            print(
+                f"    Chatchat Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n"
+                f"    provider_host:{MODEL_PROVIDERS_CFG_HOST}\n"
+                f"    provider_host:{MODEL_PROVIDERS_CFG_HOST}\n")
+
+            print(f"    Chatchat Api Server: {api_address()}")
         if args.webui:
             print(f"    Chatchat WEBUI Server: {webui_address()}")
     print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)

@@ -193,21 +238,16 @@ async def start_main_server():
     args, parser = parse_args()
 
     if args.all_webui:
-        args.openai_api = True
-        args.model_worker = True
         args.api = True
         args.api_worker = True
         args.webui = True
 
     elif args.all_api:
-        args.openai_api = True
-        args.model_worker = True
         args.api = True
         args.api_worker = True
         args.webui = False
 
     if args.lite:
-        args.model_worker = False
         run_mode = "lite"
 
     dump_server_info(args=args)

@@ -216,25 +256,29 @@ async def start_main_server():
     logger.info(f"正在启动服务:")
     logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
 
-    processes = {"online_api": {}, "model_worker": {}}
+    processes = {}
 
     def process_count():
         return len(processes)
 
-    loom_started = manager.Event()
-    # process = Process(
-    #     target=run_loom,
-    #     name=f"run_loom Server",
-    #     kwargs=dict(started_event=loom_started),
-    #     daemon=True,
-    # )
-    # processes["run_loom"] = process
+    # 定义全局配置变量,使用 Manager 创建共享字典
+    model_platforms_shard = manager.dict()
+    model_providers_started = manager.Event()
+    if args.api:
+        process = Process(
+            target=run_init_server,
+            name=f"Model providers Server",
+            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started,
+                        run_mode=run_mode),
+            daemon=True,
+        )
+        processes["model_providers"] = process
 
     api_started = manager.Event()
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            kwargs=dict(started_event=api_started, run_mode=run_mode),
+            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
             daemon=True,
         )
         processes["api"] = process

@@ -244,7 +288,7 @@ async def start_main_server():
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            kwargs=dict(started_event=webui_started, run_mode=run_mode),
+            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode),
             daemon=True,
         )
         processes["webui"] = process

@@ -254,10 +298,10 @@ async def start_main_server():
     else:
         try:
             # 保证任务收到SIGINT后,能够正常退出
-            if p := processes.get("run_loom"):
+            if p := processes.get("model_providers"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"
-                loom_started.wait()  # 等待Loom启动完成
+                model_providers_started.wait()  # 等待model_providers启动完成
 
             if p := processes.get("api"):
                 p.start()

@@ -295,6 +339,8 @@ async def start_main_server():
 
 
 if __name__ == "__main__":
+    # 添加这行代码
+    multiprocessing.freeze_support()
    create_tables()
    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
@@ -1,184 +0,0 @@
-from typing import Tuple, Any
-
-import streamlit as st
-from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
-import logging
-
-logger = logging.getLogger(__name__)
-client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
-
-
-def update_store():
-    logger.info("update_status")
-    st.session_state.status = client.status()
-    logger.info("update_list_plugins")
-    list_plugins = client.list_plugins()
-    st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
-
-    logger.info("update_launch_subscribe_info")
-    launch_subscribe_info = {}
-    for plugin_name in st.session_state.run_plugins_list:
-        launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)
-
-    st.session_state.launch_subscribe_info = launch_subscribe_info
-
-    logger.info("update_list_running_models")
-    list_running_models = {}
-    for plugin_name in st.session_state.run_plugins_list:
-        list_running_models[plugin_name] = client.list_running_models(plugin_name)
-    st.session_state.list_running_models = list_running_models
-
-    logger.info("update_model_config")
-    model_config = {}
-    for plugin_name in st.session_state.run_plugins_list:
-        model_config[plugin_name] = client.list_llm_models(plugin_name)
-    st.session_state.model_config = model_config
-
-
-def start_plugin():
-    import time
-    start_plugins_name = st.session_state.plugins_name
-    if start_plugins_name in st.session_state.run_plugins_list:
-        st.toast(start_plugins_name + " has already been counted.")
-
-        time.sleep(.5)
-    else:
-
-        st.toast("start_plugin " + start_plugins_name + ",starting.")
-        result = client.launch_subscribe(start_plugins_name)
-        st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
-        time.sleep(3)
-        result1 = client.launch_subscribe_start(start_plugins_name)
-
-        st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
-        time.sleep(2)
-        update_store()
-
-
-def start_worker():
-    select_plugins_name = st.session_state.plugins_name
-    select_worker_id = st.session_state.worker_id
-    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
-    already_counted = False
-    for model in start_model_list:
-        if model['worker_id'] == select_worker_id:
-            already_counted = True
-            break
-
-    if already_counted:
-        st.toast(
-            "select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
-        import time
-        time.sleep(.5)
-    else:
-
-        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
-        result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
-        st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
-        import time
-        time.sleep(.5)
-        update_store()
-
-
-def stop_worker():
-    select_plugins_name = st.session_state.plugins_name
-    select_worker_id = st.session_state.worker_id
-    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
-    already_counted = False
-    for model in start_model_list:
-        if model['worker_id'] == select_worker_id:
-            already_counted = True
-            break
-
-    if not already_counted:
-        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
-        import time
-        time.sleep(.5)
-    else:
-
-        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
-        result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
-        st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
-        import time
-        time.sleep(.5)
-        update_store()
-
-
-def build_providers_model_plugins_name():
-    import streamlit_antd_components as sac
-    if "run_plugins_list" not in st.session_state:
-        return []
-    # 按照模型构建sac.menu(菜单
-    menu_items = []
-    for key, value in st.session_state.list_running_models.items():
-        menu_item_children = []
-        for model in value:
-            if "model" in model["providers"]:
-                menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
-
-        menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
-
-    return menu_items
-
-
-def build_providers_embedding_plugins_name():
-    import streamlit_antd_components as sac
-    if "run_plugins_list" not in st.session_state:
-        return []
-    # 按照模型构建sac.menu(菜单
-    menu_items = []
-    for key, value in st.session_state.list_running_models.items():
-        menu_item_children = []
-        for model in value:
-            if "embedding" in model["providers"]:
-                menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
-
-        menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
-
-    return menu_items
-
-
-def find_menu_items_by_index(menu_items, key):
-    for menu_item in menu_items:
-        if menu_item.get('children') is not None:
-            for child in menu_item.get('children'):
-                if child.get('key') == key:
-                    return menu_item, child
-
-    return None, None
-
-
-def set_llm_select(plugins_info, llm_model_worker):
-    st.session_state["select_plugins_info"] = plugins_info
-    st.session_state["select_model_worker"] = llm_model_worker
-
-
-def get_select_model_endpoint() -> Tuple[str, str]:
-    plugins_info = st.session_state["select_plugins_info"]
-    llm_model_worker = st.session_state["select_model_worker"]
-    if plugins_info is None or llm_model_worker is None:
-        raise ValueError("plugins_info or llm_model_worker is None")
-    plugins_name = st.session_state["select_plugins_info"]['label']
-    select_model_name = st.session_state["select_model_worker"]['label']
-    adapter_description = st.session_state.launch_subscribe_info[plugins_name]
-    endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
-    return endpoint_host, select_model_name
-
-
-def set_embed_select(plugins_info, embed_model_worker):
-    st.session_state["select_embed_plugins_info"] = plugins_info
-    st.session_state["select_embed_model_worker"] = embed_model_worker
-
-
-def get_select_embed_endpoint() -> Tuple[str, str]:
-    select_embed_plugins_info = st.session_state["select_embed_plugins_info"]
-    select_embed_model_worker = st.session_state["select_embed_model_worker"]
-    if select_embed_plugins_info is None or select_embed_model_worker is None:
-        raise ValueError("select_embed_plugins_info or select_embed_model_worker is None")
-    embed_plugins_name = st.session_state["select_embed_plugins_info"]['label']
-    select_embed_model_name = st.session_state["select_embed_model_worker"]['label']
-    endpoint_host = None
-    if embed_plugins_name in st.session_state.launch_subscribe_info:
-        adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name]
-        endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
-    return endpoint_host, select_embed_model_name
@@ -1 +0,0 @@
-from .base import openai_plugins_page
@@ -1,67 +0,0 @@
-import streamlit as st
-from loom_openai_plugins_frontend import loom_openai_plugins_frontend
-
-from chatchat.webui_pages.utils import ApiRequest
-from chatchat.webui_pages.loom_view_client import (
-    update_store,
-    start_plugin,
-    start_worker,
-    stop_worker,
-)
-
-
-def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
-
-
-    with (st.container()):
-
-        if "worker_id" not in st.session_state:
-            st.session_state.worker_id = ''
-        if "plugins_name" not in st.session_state and "status" in st.session_state:
-
-            for k, v in st.session_state.status.get("status", {}).get("subscribe_status", []).items():
-                st.session_state.plugins_name = v.get("plugins_names", [])[0]
-                break
-
-        col1, col2 = st.columns([0.8, 0.2])
-
-        with col1:
-            event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
-                                                 run_list_plugins=st.session_state.run_plugins_list,
-                                                 launch_subscribe_info=st.session_state.launch_subscribe_info,
-                                                 list_running_models=st.session_state.list_running_models,
-                                                 model_config=st.session_state.model_config)
-
-        with col2:
-            st.write("操作")
-            if not st.session_state.run_plugins_list:
-                button_type_disabled = False
-                button_start_text = '启动'
-            else:
-                button_type_disabled = True
-                button_start_text = '已启动'
-
-            if event:
-                event_type = event.get("event")
-                if event_type == "BottomNavigationAction":
-                    st.session_state.plugins_name = event.get("data")
-                    st.session_state.worker_id = ''
-                    # 不存在run_plugins_list,打开启动按钮
-                    if st.session_state.plugins_name not in st.session_state.run_plugins_list \
-                            or st.session_state.run_plugins_list:
-                        button_type_disabled = False
-                        button_start_text = '启动'
-                    else:
-                        button_type_disabled = True
-                        button_start_text = '已启动'
-                if event_type == "CardCoverComponent":
-                    st.session_state.worker_id = event.get("data")
-
-            st.button(button_start_text, disabled=button_type_disabled, key="start",
-                      on_click=start_plugin)
-
-            if st.session_state.worker_id != '':
-                st.button("启动" + st.session_state.worker_id, key="start_worker",
-                          on_click=start_worker)
-                st.button("停止" + st.session_state.worker_id, key="stop_worker",
-                          on_click=stop_worker)
@ -216,7 +216,7 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
#
|
#
|
||||||
# https://github.com/tophat/syrupy
|
# https://github.com/tophat/syrupy
|
||||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||||
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
|
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
|
||||||
# Registering custom markers.
|
# Registering custom markers.
|
||||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||||
markers = [
|
markers = [
|
||||||
|
|||||||
chatchat-server/tests/unit_server/test_init_server.py | 6 (new file)

@@ -0,0 +1,6 @@
+from chatchat.model_loaders.init_server import init_server
+
+
+def test_init_server():
+
+    init_server()
model-providers/model_providers/__main__.py | 47 (new file)

@@ -0,0 +1,47 @@
+import argparse
+import asyncio
+import logging
+
+from model_providers import BootstrapWebBuilder
+from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms
+
+logger = logging.getLogger(__name__)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model-providers",
+        type=str,
+        default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
+        help="run model_providers servers",
+        dest="model_providers",
+    )
+    args = parser.parse_args()
+    try:
+        logging_conf = get_config_dict(
+            "DEBUG",
+            get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
+            122,
+            111,
+        )
+        boot = (
+            BootstrapWebBuilder()
+            .model_providers_cfg_path(
+                model_providers_cfg_path=args.model_providers
+            )
+            .host(host="127.0.0.1")
+            .port(port=20000)
+            .build()
+        )
+        boot.set_app_event(started_event=None)
+        boot.serve(logging_conf=logging_conf)
+
+
+        async def pool_join_thread():
+            await boot.join()
+
+
+        asyncio.run(pool_join_thread())
+    except SystemExit:
+        logger.info("SystemExit raised, exiting")
+        raise
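This entry point can be started standalone with "python -m model_providers --model-providers <path-to-model_providers.yaml>". Once it is running, each configured provider is exposed under an OpenAI-compatible base URL of the form http://127.0.0.1:20000/{provider}/v1, which is the same pattern init_provider_platforms writes into MODEL_PLATFORMS. A hedged consumption sketch (provider and model names taken from the example YAML above; not part of the commit):

    # Talk to the proxied provider endpoint with the standard openai client.
    import openai

    client = openai.Client(base_url="http://127.0.0.1:20000/xinference/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="chatglm3-6b",  # a model declared under the xinference provider in model_providers.yaml
        messages=[{"role": "user", "content": "你好"}],
    )
    print(resp.choices[0].message.content)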
@@ -45,6 +45,7 @@ from model_providers.core.bootstrap.openai_protocol import (
     Role,
     UsageInfo,
 )
+from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
 from model_providers.core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,

@@ -111,7 +112,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
 
 
 def _create_template_from_message_type(
     message_type: str, template: Union[str, list]
 ) -> PromptMessage:
     """Create a message prompt template from a message type and template string.
 

@@ -170,7 +171,7 @@ def _create_template_from_message_type(
 
 
 def _convert_to_message(
     message: MessageLikeRepresentation,
 ) -> Union[PromptMessage]:
     """Instantiate a message from a variety of message formats.
 

@@ -212,7 +213,7 @@ def _convert_to_message(
 
 
 async def _stream_openai_chat_completion(
     response: Generator,
 ) -> AsyncGenerator[str, None]:
     request_id, model = None, None
     for chunk in response:

@@ -362,11 +363,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         started_event.set()
 
     async def workspaces_model_providers(self, request: Request):
-        provider_list = self.get_provider_list(model_type=request.get("model_type"))
+        provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
+            model_type=request.get("model_type"))
         return ProviderListResponse(data=provider_list)
 
     async def workspaces_model_types(self, model_type: str, request: Request):
-        models_by_model_type = self.get_models_by_model_type(model_type=model_type)
+        models_by_model_type = ProvidersWrapper(
+            provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
         return ProviderModelTypeResponse(data=models_by_model_type)
 
     async def list_models(self, provider: str, request: Request):

@@ -399,7 +403,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return ModelList(data=models_list)
 
     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"

@@ -409,7 +413,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return EmbeddingsResponse(**dictify(response))
 
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"

@@ -469,9 +473,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
 
 
 def run(
     cfg: Dict,
     logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
@@ -4,22 +4,11 @@ from typing import List, Optional
 
 from fastapi import Request
 
-from model_providers.bootstrap_web.entities.model_provider_entities import (
-    CustomConfigurationResponse,
-    CustomConfigurationStatus,
-    ModelResponse,
-    ProviderResponse,
-    ProviderWithModelsResponse,
-    SystemConfigurationResponse,
-)
 from model_providers.core.bootstrap.openai_protocol import (
     ChatCompletionRequest,
     EmbeddingsRequest,
 )
-from model_providers.core.entities.model_entities import ModelStatus
-from model_providers.core.entities.provider_entities import ProviderType
 from model_providers.core.model_manager import ModelManager
-from model_providers.core.model_runtime.entities.model_entities import ModelType
 
 
 class Bootstrap:

@@ -43,137 +32,6 @@ class Bootstrap:
     def provider_manager(self, provider_manager: ModelManager):
         self._provider_manager = provider_manager
 
-    def get_provider_list(
-        self, model_type: Optional[str] = None
-    ) -> List[ProviderResponse]:
-        """
-        get provider list.
-
-        :param model_type: model type
-        :return:
-        """
-        # 合并两个字典的键
-        provider = set(
-            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
-        )
-        provider.update(
-            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
-        )
-        # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.provider_manager.get_configurations(provider=provider)
-        )
-
-        provider_responses = []
-        for provider_configuration in provider_configurations.values():
-            if model_type:
-                model_type_entity = ModelType.value_of(model_type)
-                if (
-                    model_type_entity
-                    not in provider_configuration.provider.supported_model_types
-                ):
-                    continue
-
-            provider_response = ProviderResponse(
-                provider=provider_configuration.provider.provider,
-                label=provider_configuration.provider.label,
-                description=provider_configuration.provider.description,
-                icon_small=provider_configuration.provider.icon_small,
-                icon_large=provider_configuration.provider.icon_large,
-                background=provider_configuration.provider.background,
-                help=provider_configuration.provider.help,
-                supported_model_types=provider_configuration.provider.supported_model_types,
-                configurate_methods=provider_configuration.provider.configurate_methods,
-                provider_credential_schema=provider_configuration.provider.provider_credential_schema,
-                model_credential_schema=provider_configuration.provider.model_credential_schema,
-                preferred_provider_type=ProviderType.value_of("custom"),
-                custom_configuration=CustomConfigurationResponse(
-                    status=CustomConfigurationStatus.ACTIVE
-                    if provider_configuration.is_custom_configuration_available()
-                    else CustomConfigurationStatus.NO_CONFIGURE
-                ),
-                system_configuration=SystemConfigurationResponse(enabled=False),
-            )
-
-            provider_responses.append(provider_response)
-
-        return provider_responses
-
-    def get_models_by_model_type(
-        self, model_type: str
-    ) -> List[ProviderWithModelsResponse]:
-        """
-        get models by model type.
-
-        :param model_type: model type
-        :return:
-        """
-        # 合并两个字典的键
-        provider = set(
-            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
-        )
-        provider.update(
-            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
-        )
-        # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.provider_manager.get_configurations(provider=provider)
-        )
-
-        # Get provider available models
-        models = provider_configurations.get_models(
-            model_type=ModelType.value_of(model_type)
-        )
-
-        # Group models by provider
-        provider_models = {}
-        for model in models:
-            if model.provider.provider not in provider_models:
-                provider_models[model.provider.provider] = []
-
-            if model.deprecated:
-                continue
-
-            provider_models[model.provider.provider].append(model)
-
-        # convert to ProviderWithModelsResponse list
-        providers_with_models: list[ProviderWithModelsResponse] = []
-        for provider, models in provider_models.items():
-            if not models:
-                continue
-
-            first_model = models[0]
-
-            has_active_models = any(
-                [model.status == ModelStatus.ACTIVE for model in models]
-            )
-
-            providers_with_models.append(
-                ProviderWithModelsResponse(
-                    provider=provider,
-                    label=first_model.provider.label,
-                    icon_small=first_model.provider.icon_small,
-                    icon_large=first_model.provider.icon_large,
-                    status=CustomConfigurationStatus.ACTIVE
-                    if has_active_models
-                    else CustomConfigurationStatus.NO_CONFIGURE,
-                    models=[
-                        ModelResponse(
-                            model=model.model,
-                            label=model.label,
-                            model_type=model.model_type,
-                            features=model.features,
-                            fetch_from=model.fetch_from,
-                            model_properties=model.model_properties,
-                            status=model.status,
-                        )
-                        for model in models
-                    ],
-                )
-            )
-
-        return providers_with_models
-
     @classmethod
     @abstractmethod
     def from_config(cls, cfg=None):
@ -0,0 +1,153 @@
from typing import Optional, List

from model_providers.bootstrap_web.entities.model_provider_entities import (
    CustomConfigurationResponse,
    CustomConfigurationStatus,
    ModelResponse,
    ProviderResponse,
    ProviderWithModelsResponse,
    SystemConfigurationResponse,
)

from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager


class ProvidersWrapper:
    def __init__(self, provider_manager: ProviderManager):
        self.provider_manager = provider_manager

    def get_provider_list(
        self, model_type: Optional[str] = None
    ) -> List[ProviderResponse]:
        """
        get provider list.

        :param model_type: model type
        :return:
        """
        # Merge the keys of the two dictionaries
        provider = set(
            self.provider_manager.provider_name_to_provider_records_dict.keys()
        )
        provider.update(
            self.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
        provider_configurations = (
            self.provider_manager.get_configurations(provider=provider)
        )

        provider_responses = []
        for provider_configuration in provider_configurations.values():
            if model_type:
                model_type_entity = ModelType.value_of(model_type)
                if (
                    model_type_entity
                    not in provider_configuration.provider.supported_model_types
                ):
                    continue

            provider_response = ProviderResponse(
                provider=provider_configuration.provider.provider,
                label=provider_configuration.provider.label,
                description=provider_configuration.provider.description,
                icon_small=provider_configuration.provider.icon_small,
                icon_large=provider_configuration.provider.icon_large,
                background=provider_configuration.provider.background,
                help=provider_configuration.provider.help,
                supported_model_types=provider_configuration.provider.supported_model_types,
                configurate_methods=provider_configuration.provider.configurate_methods,
                provider_credential_schema=provider_configuration.provider.provider_credential_schema,
                model_credential_schema=provider_configuration.provider.model_credential_schema,
                preferred_provider_type=ProviderType.value_of("custom"),
                custom_configuration=CustomConfigurationResponse(
                    status=CustomConfigurationStatus.ACTIVE
                    if provider_configuration.is_custom_configuration_available()
                    else CustomConfigurationStatus.NO_CONFIGURE
                ),
                system_configuration=SystemConfigurationResponse(enabled=False),
            )

            provider_responses.append(provider_response)

        return provider_responses

    def get_models_by_model_type(
        self, model_type: str
    ) -> List[ProviderWithModelsResponse]:
        """
        get models by model type.

        :param model_type: model type
        :return:
        """
        # Merge the keys of the two dictionaries
        provider = set(
            self.provider_manager.provider_name_to_provider_records_dict.keys()
        )
        provider.update(
            self.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
        provider_configurations = (
            self.provider_manager.get_configurations(provider=provider)
        )

        # Get provider available models
        models = provider_configurations.get_models(
            model_type=ModelType.value_of(model_type)
        )

        # Group models by provider
        provider_models = {}
        for model in models:
            if model.provider.provider not in provider_models:
                provider_models[model.provider.provider] = []

            if model.deprecated:
                continue

            provider_models[model.provider.provider].append(model)

        # convert to ProviderWithModelsResponse list
        providers_with_models: list[ProviderWithModelsResponse] = []
        for provider, models in provider_models.items():
            if not models:
                continue

            first_model = models[0]

            has_active_models = any(
                [model.status == ModelStatus.ACTIVE for model in models]
            )

            providers_with_models.append(
                ProviderWithModelsResponse(
                    provider=provider,
                    label=first_model.provider.label,
                    icon_small=first_model.provider.icon_small,
                    icon_large=first_model.provider.icon_large,
                    status=CustomConfigurationStatus.ACTIVE
                    if has_active_models
                    else CustomConfigurationStatus.NO_CONFIGURE,
                    models=[
                        ModelResponse(
                            model=model.model,
                            label=model.label,
                            model_type=model.model_type,
                            features=model.features,
                            fetch_from=model.fetch_from,
                            model_properties=model.model_properties,
                            status=model.status,
                        )
                        for model in models
                    ],
                )
            )

        return providers_with_models
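For reference, a minimal usage sketch of the new ProvidersWrapper, assuming a ProviderManager instance is already available (for example, the one built during platform initialization). The import path of ProvidersWrapper and the helper function below are illustrative assumptions, not part of this commit, and the "llm" string is assumed to be a valid ModelType value in model_providers:

from typing import List

from model_providers.core.provider_manager import ProviderManager
# NOTE: assumed module path; import ProvidersWrapper from wherever this
# commit actually places the new file.
from model_providers.bootstrap_web.providers_wrapper import ProvidersWrapper


def list_llm_providers(provider_manager: ProviderManager) -> List[str]:
    """Return the names of providers exposing at least one non-deprecated LLM."""
    wrapper = ProvidersWrapper(provider_manager=provider_manager)
    # "llm" is assumed to match ModelType.value_of() in model_providers.
    providers_with_models = wrapper.get_models_by_model_type(model_type="llm")
    return [entry.provider for entry in providers_with_models]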