mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
统一模型类型编码
This commit is contained in:
parent
1f378256e4
commit
7b6cc81ff5
@ -69,10 +69,7 @@ LLM_MODEL_CONFIG = {
|
||||
"sd-turbo": {
|
||||
"size": "256*256",
|
||||
}
|
||||
},
|
||||
"multimodal_model": {
|
||||
"qwen-vl": {}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台
|
||||
@ -116,7 +113,9 @@ MODEL_PLATFORMS = [
|
||||
"Embedding-V1",
|
||||
],
|
||||
"image_models": [],
|
||||
"multimodal_models": [],
|
||||
"reranking_models": [],
|
||||
"speech2text_models": [],
|
||||
"tts_models": [],
|
||||
},
|
||||
|
||||
|
||||
|
||||
@ -74,7 +74,9 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
|
||||
provider_dict["llm_models"] = []
|
||||
provider_dict["embed_models"] = []
|
||||
provider_dict["image_models"] = []
|
||||
provider_dict["multimodal_models"] = []
|
||||
provider_dict["reranking_models"] = []
|
||||
provider_dict["speech2text_models"] = []
|
||||
provider_dict["tts_models"] = []
|
||||
supported_model_str_types = [model_type.to_origin_model_type() for model_type in
|
||||
provider.supported_model_types]
|
||||
|
||||
@ -93,12 +95,16 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
|
||||
if cur_model_type:
|
||||
if model_type == "text-generation":
|
||||
provider_dict["llm_models"] = cur_model_type
|
||||
elif model_type == "text-embedding":
|
||||
elif model_type == "embeddings":
|
||||
provider_dict["embed_models"] = cur_model_type
|
||||
elif model_type == "text2img":
|
||||
provider_dict["image_models"] = cur_model_type
|
||||
elif model_type == "multimodal":
|
||||
provider_dict["multimodal_models"] = cur_model_type
|
||||
elif model_type == "reranking":
|
||||
provider_dict["reranking_models"] = cur_model_type
|
||||
elif model_type == "speech2text":
|
||||
provider_dict["speech2text_models"] = cur_model_type
|
||||
elif model_type == "tts":
|
||||
provider_dict["tts_models"] = cur_model_type
|
||||
else:
|
||||
logger.warning(f"Unsupported model type: {model_type}")
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ def get_config_platforms() -> Dict[str, Dict]:
|
||||
|
||||
def get_config_models(
|
||||
model_name: str = None,
|
||||
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
|
||||
model_type: Literal["llm", "embed", "image", "reranking","speech2text","tts"] = None,
|
||||
platform_name: str = None,
|
||||
) -> Dict[str, Dict]:
|
||||
'''
|
||||
@ -88,7 +88,14 @@ def get_config_models(
|
||||
continue
|
||||
|
||||
if model_type is None:
|
||||
model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"]
|
||||
model_types = [
|
||||
"llm_models",
|
||||
"embed_models",
|
||||
"image_models",
|
||||
"reranking_models",
|
||||
"speech2text_models",
|
||||
"tts_models",
|
||||
]
|
||||
else:
|
||||
model_types = [f"{model_type}_models"]
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import List
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
@ -47,3 +48,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
|
||||
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
|
||||
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(providers_file)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
|
||||
for model_type in ModelType.__members__.values():
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=provider_manager
|
||||
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
|
||||
|
||||
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
@ -43,3 +44,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
|
||||
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
|
||||
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(providers_file)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
|
||||
for model_type in ModelType.__members__.values():
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=provider_manager
|
||||
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
|
||||
|
||||
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
|
||||
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(providers_file)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
|
||||
for model_type in ModelType.__members__.values():
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=provider_manager
|
||||
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
|
||||
|
||||
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
@ -7,3 +5,8 @@ xinference:
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
- model: 'bge-m3'
|
||||
model_type: 'embeddings'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'bge-m3'
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
|
||||
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(providers_file)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
|
||||
for model_type in ModelType.__members__.values():
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=provider_manager
|
||||
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
|
||||
|
||||
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
|
||||
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(providers_file)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
|
||||
for model_type in ModelType.__members__.values():
|
||||
models_by_model_type = ProvidersWrapper(
|
||||
provider_manager=provider_manager
|
||||
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
|
||||
|
||||
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user