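"""Bootstrap the model_providers web service for chatchat.

Starts the provider service and exposes the registered providers as
MODEL_PLATFORMS-style platform entries via ``model_platforms_shard``.
"""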
import asyncio
import logging
import multiprocessing as mp
from typing import Dict, List

from chatchat.configs import (
    MODEL_PROVIDERS_CFG_HOST,
    MODEL_PROVIDERS_CFG_PORT,
    MODEL_PROVIDERS_CFG_PATH_CONFIG,
)
from model_providers import BootstrapWebBuilder
from model_providers.bootstrap_web.entities.model_provider_entities import (
    ProviderResponse,
)
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)

logger = logging.getLogger(__name__)


def init_server(model_platforms_shard: Dict,
                started_event: mp.Event = None,
                model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
                provider_host: str = MODEL_PROVIDERS_CFG_HOST,
                provider_port: int = MODEL_PROVIDERS_CFG_PORT,
                log_path: str = "logs"
                ) -> None:
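    """Start the model_providers bootstrap web service.

    Builds the service from ``model_providers_cfg_path`` on the given host and
    port, registers ``started_event`` with the app, stores the discovered
    provider platforms in ``model_platforms_shard``, and blocks until the
    service exits.
    """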
    logging_conf = get_config_dict(
        "INFO",
        get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"),
        1024 * 1024 * 1024 * 3,
        1024 * 1024 * 1024 * 3,
    )

    try:
        boot = (
            BootstrapWebBuilder()
            .model_providers_cfg_path(
                model_providers_cfg_path=model_providers_cfg_path
            )
            .host(host=provider_host)
            .port(port=provider_port)
            .build()
        )
        boot.set_app_event(started_event=started_event)

        # Share the discovered provider platforms through the shared dict.
        provider_platforms = init_provider_platforms(
            boot.provider_manager.provider_manager
        )
        model_platforms_shard['provider_platforms'] = provider_platforms
        boot.logging_conf(logging_conf=logging_conf)
        boot.run()

        async def pool_join_thread():
            await boot.join()

        # Block until the bootstrap service has shut down.
        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise


def init_provider_platforms(provider_manager: ProviderManager) -> List[Dict]:
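    """Convert registered providers into MODEL_PLATFORMS-style dicts.

    For each provider known to ``provider_manager``, build a platform entry
    pointing at its local ``/v1`` endpoint and bucket its models by type
    (LLM, embedding, image, reranking, speech-to-text, TTS).
    """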
    provider_list: List[ProviderResponse] = ProvidersWrapper(
        provider_manager=provider_manager
    ).get_provider_list()
    logger.info(f"Provider list: {provider_list}")
    # Convert providers into MODEL_PLATFORMS entries
    provider_platforms = []
    for provider in provider_list:
        provider_dict = {
            "platform_name": provider.provider,
            "platform_type": provider.provider,
            "api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1",
            "api_key": "EMPTY",
            "api_concurrencies": 5,
        }

        provider_dict["llm_models"] = []
        provider_dict["embed_models"] = []
        provider_dict["image_models"] = []
        provider_dict["reranking_models"] = []
        provider_dict["speech2text_models"] = []
        provider_dict["tts_models"] = []
        supported_model_str_types = [
            model_type.to_origin_model_type()
            for model_type in provider.supported_model_types
        ]

        for model_type in supported_model_str_types:
            providers_model_type = ProvidersWrapper(
                provider_manager=provider_manager
            ).get_models_by_model_type(model_type=model_type)
            cur_model_type: List[str] = []
            # Collect the models of this type that belong to the current provider
            for provider_model in providers_model_type:
                if provider_model.provider == provider.provider:
                    models = [model.model for model in provider_model.models]
                    cur_model_type.extend(models)

            if cur_model_type:
                if model_type == "text-generation":
                    provider_dict["llm_models"] = cur_model_type
                elif model_type == "embeddings":
                    provider_dict["embed_models"] = cur_model_type
                elif model_type == "text2img":
                    provider_dict["image_models"] = cur_model_type
                elif model_type == "reranking":
                    provider_dict["reranking_models"] = cur_model_type
                elif model_type == "speech2text":
                    provider_dict["speech2text_models"] = cur_model_type
                elif model_type == "tts":
                    provider_dict["tts_models"] = cur_model_type
                else:
                    logger.warning(f"Unsupported model type: {model_type}")

        provider_platforms.append(provider_dict)

    logger.info(f"Provider platforms: {provider_platforms}")

    return provider_platforms