mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-29 18:29:44 +08:00
provider_configuration.py
查询所有的平台信息,包含计费策略和配置schema_validators(参数必填信息校验规则)
/workspaces/current/model-providers
查询平台模型分类的详细默认信息,包含了模型类型,模型参数,模型状态
workspaces/current/models/model-types/{model_type}
This commit is contained in:
parent
bfcf2775f5
commit
a1fe8d714f
@ -0,0 +1,183 @@
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from model_providers.core.entities.model_entities import (
|
||||
ModelStatus,
|
||||
ModelWithProviderEntity,
|
||||
)
|
||||
from model_providers.core.entities.provider_entities import (
|
||||
ProviderQuotaType,
|
||||
ProviderType,
|
||||
QuotaConfiguration,
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
ProviderModel,
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
ProviderHelpEntity,
|
||||
SimpleProviderEntity,
|
||||
)
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration response.
|
||||
"""
|
||||
|
||||
status: CustomConfigurationStatus
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration response.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: list[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
preferred_provider_type: ProviderType
|
||||
custom_configuration: CustomConfigurationResponse
|
||||
system_configuration: SystemConfigurationResponse
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
#
|
||||
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
# if self.icon_small is not None:
|
||||
# self.icon_small = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
# )
|
||||
#
|
||||
# if self.icon_large is not None:
|
||||
# self.icon_large = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
# )
|
||||
|
||||
|
||||
class ProviderListResponse(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: List[ProviderResponse] = []
|
||||
|
||||
|
||||
class ModelResponse(ProviderModel):
|
||||
"""
|
||||
Model class for model response.
|
||||
"""
|
||||
|
||||
status: ModelStatus
|
||||
|
||||
|
||||
class ProviderWithModelsResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider with models response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ModelResponse]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
# if self.icon_small is not None:
|
||||
# self.icon_small = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
# )
|
||||
#
|
||||
# if self.icon_large is not None:
|
||||
# self.icon_large = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
# )
|
||||
|
||||
|
||||
class ProviderModelTypeResponse(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: List[ProviderWithModelsResponse] = []
|
||||
|
||||
|
||||
class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
"""
|
||||
Simple provider entity response.
|
||||
"""
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
# if self.icon_small is not None:
|
||||
# self.icon_small = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
# )
|
||||
#
|
||||
# if self.icon_large is not None:
|
||||
# self.icon_large = I18nObject(
|
||||
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
# )
|
||||
|
||||
|
||||
class DefaultModelResponse(BaseModel):
|
||||
"""
|
||||
Default model entity.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
|
||||
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
super().__init__(**model.dict())
|
||||
@ -24,6 +24,10 @@ from sse_starlette import EventSourceResponse
|
||||
from uvicorn import Config, Server
|
||||
|
||||
from model_providers.bootstrap_web.common import create_stream_chunk
|
||||
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||
ProviderListResponse,
|
||||
ProviderModelTypeResponse,
|
||||
)
|
||||
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
||||
from model_providers.core.bootstrap.openai_protocol import (
|
||||
ChatCompletionMessage,
|
||||
@ -301,6 +305,18 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
self._router.add_api_route(
|
||||
"/workspaces/current/model-providers",
|
||||
self.workspaces_model_providers,
|
||||
response_model=ProviderListResponse,
|
||||
methods=["GET"],
|
||||
)
|
||||
self._router.add_api_route(
|
||||
"/workspaces/current/models/model-types/{model_type}",
|
||||
self.workspaces_model_types,
|
||||
response_model=ProviderModelTypeResponse,
|
||||
methods=["GET"],
|
||||
)
|
||||
self._router.add_api_route(
|
||||
"/{provider}/v1/models",
|
||||
self.list_models,
|
||||
@ -345,6 +361,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
if started_event is not None:
|
||||
started_event.set()
|
||||
|
||||
async def workspaces_model_providers(self, request: Request):
|
||||
provider_list = self.get_provider_list(model_type=request.get("model_type"))
|
||||
return ProviderListResponse(data=provider_list)
|
||||
|
||||
async def workspaces_model_types(self, model_type: str, request: Request):
|
||||
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
|
||||
return ProviderModelTypeResponse(data=models_by_model_type)
|
||||
|
||||
async def list_models(self, provider: str, request: Request):
|
||||
logger.info(f"Received list_models request for provider: {provider}")
|
||||
# 返回ModelType所有的枚举
|
||||
|
||||
@ -1,13 +1,25 @@
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
CustomConfigurationStatus,
|
||||
ModelResponse,
|
||||
ProviderResponse,
|
||||
ProviderWithModelsResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from model_providers.core.bootstrap.openai_protocol import (
|
||||
ChatCompletionRequest,
|
||||
EmbeddingsRequest,
|
||||
)
|
||||
from model_providers.core.entities.model_entities import ModelStatus
|
||||
from model_providers.core.entities.provider_entities import ProviderType
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class Bootstrap:
|
||||
@ -18,9 +30,150 @@ class Bootstrap:
|
||||
"""任务队列"""
|
||||
_QUEUE: deque = deque()
|
||||
|
||||
_provider_manager: ModelManager
|
||||
|
||||
def __init__(self):
|
||||
self._version = "v0.0.1"
|
||||
|
||||
@property
|
||||
def provider_manager(self) -> ModelManager:
|
||||
return self._provider_manager
|
||||
|
||||
@provider_manager.setter
|
||||
def provider_manager(self, provider_manager: ModelManager):
|
||||
self._provider_manager = provider_manager
|
||||
|
||||
def get_provider_list(
|
||||
self, model_type: Optional[str] = None
|
||||
) -> List[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
provider_responses = []
|
||||
for provider_configuration in provider_configurations.values():
|
||||
if model_type:
|
||||
model_type_entity = ModelType.value_of(model_type)
|
||||
if (
|
||||
model_type_entity
|
||||
not in provider_configuration.provider.supported_model_types
|
||||
):
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
icon_small=provider_configuration.provider.icon_small,
|
||||
icon_large=provider_configuration.provider.icon_large,
|
||||
background=provider_configuration.provider.background,
|
||||
help=provider_configuration.provider.help,
|
||||
supported_model_types=provider_configuration.provider.supported_model_types,
|
||||
configurate_methods=provider_configuration.provider.configurate_methods,
|
||||
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
|
||||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=ProviderType.value_of("custom"),
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(enabled=False),
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
|
||||
return provider_responses
|
||||
|
||||
def get_models_by_model_type(
|
||||
self, model_type: str
|
||||
) -> List[ProviderWithModelsResponse]:
|
||||
"""
|
||||
get models by model type.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# 合并两个字典的键
|
||||
provider = set(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||
)
|
||||
provider.update(
|
||||
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||
)
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = (
|
||||
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||
)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
for model in models:
|
||||
if model.provider.provider not in provider_models:
|
||||
provider_models[model.provider.provider] = []
|
||||
|
||||
if model.deprecated:
|
||||
continue
|
||||
|
||||
provider_models[model.provider.provider].append(model)
|
||||
|
||||
# convert to ProviderWithModelsResponse list
|
||||
providers_with_models: list[ProviderWithModelsResponse] = []
|
||||
for provider, models in provider_models.items():
|
||||
if not models:
|
||||
continue
|
||||
|
||||
first_model = models[0]
|
||||
|
||||
has_active_models = any(
|
||||
[model.status == ModelStatus.ACTIVE for model in models]
|
||||
)
|
||||
|
||||
providers_with_models.append(
|
||||
ProviderWithModelsResponse(
|
||||
provider=provider,
|
||||
label=first_model.provider.label,
|
||||
icon_small=first_model.provider.icon_small,
|
||||
icon_large=first_model.provider.icon_large,
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if has_active_models
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
models=[
|
||||
ModelResponse(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
)
|
||||
for model in models
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return providers_with_models
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, cfg=None):
|
||||
@ -44,19 +197,9 @@ class Bootstrap:
|
||||
|
||||
|
||||
class OpenAIBootstrapBaseWeb(Bootstrap):
|
||||
_provider_manager: ModelManager
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def provider_manager(self) -> ModelManager:
|
||||
return self._provider_manager
|
||||
|
||||
@provider_manager.setter
|
||||
def provider_manager(self, provider_manager: ModelManager):
|
||||
self._provider_manager = provider_manager
|
||||
|
||||
@abstractmethod
|
||||
async def list_models(self, provider: str, request: Request):
|
||||
pass
|
||||
|
||||
@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
|
||||
else:
|
||||
return None
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
Check custom configuration available.
|
||||
:return:
|
||||
"""
|
||||
return (
|
||||
self.custom_configuration.provider is not None
|
||||
or len(self.custom_configuration.models) > 0
|
||||
)
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
@ -6,12 +6,82 @@ from pydantic import BaseModel
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
CUSTOM = "custom"
|
||||
SYSTEM = "system"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ProviderType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
TIMES = "times"
|
||||
TOKENS = "tokens"
|
||||
CREDITS = "credits"
|
||||
|
||||
|
||||
class SystemConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for system configuration status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
UNSUPPORTED = "unsupported"
|
||||
|
||||
|
||||
class RestrictModel(BaseModel):
|
||||
model: str
|
||||
base_model_name: Optional[str] = None
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class QuotaConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider quota configuration.
|
||||
"""
|
||||
|
||||
quota_type: ProviderQuotaType
|
||||
quota_unit: QuotaUnit
|
||||
quota_limit: int
|
||||
quota_used: int
|
||||
is_valid: bool
|
||||
restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: Optional[dict] = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -49,7 +49,9 @@ class ModelProviderFactory:
|
||||
if init_cache:
|
||||
self.get_providers()
|
||||
|
||||
def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
|
||||
def get_providers(
|
||||
self, provider_name: Union[str, set] = ""
|
||||
) -> list[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
@ -60,20 +62,36 @@ class ModelProviderFactory:
|
||||
# traverse all model_provider_extensions
|
||||
providers = []
|
||||
for name, model_provider_extension in model_provider_extensions.items():
|
||||
if provider_name in (name, ""):
|
||||
# get model_provider instance
|
||||
model_provider_instance = model_provider_extension.provider_instance
|
||||
if isinstance(provider_name, str):
|
||||
if provider_name in (name, ""):
|
||||
# get model_provider instance
|
||||
model_provider_instance = model_provider_extension.provider_instance
|
||||
|
||||
# get provider schema
|
||||
provider_schema = model_provider_instance.get_provider_schema()
|
||||
# get provider schema
|
||||
provider_schema = model_provider_instance.get_provider_schema()
|
||||
|
||||
for model_type in provider_schema.supported_model_types:
|
||||
# get predefined models for given model type
|
||||
models = model_provider_instance.models(model_type)
|
||||
if models:
|
||||
provider_schema.models.extend(models)
|
||||
for model_type in provider_schema.supported_model_types:
|
||||
# get predefined models for given model type
|
||||
models = model_provider_instance.models(model_type)
|
||||
if models:
|
||||
provider_schema.models.extend(models)
|
||||
|
||||
providers.append(provider_schema)
|
||||
providers.append(provider_schema)
|
||||
elif isinstance(provider_name, set):
|
||||
if name in provider_name:
|
||||
# get model_provider instance
|
||||
model_provider_instance = model_provider_extension.provider_instance
|
||||
|
||||
# get provider schema
|
||||
provider_schema = model_provider_instance.get_provider_schema()
|
||||
|
||||
for model_type in provider_schema.supported_model_types:
|
||||
# get predefined models for given model type
|
||||
models = model_provider_instance.models(model_type)
|
||||
if models:
|
||||
provider_schema.models.extend(models)
|
||||
|
||||
providers.append(provider_schema)
|
||||
|
||||
# return providers
|
||||
return providers
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
@ -45,7 +45,7 @@ class ProviderManager:
|
||||
provider_name_to_provider_model_records_dict
|
||||
)
|
||||
|
||||
def get_configurations(self, provider: str) -> ProviderConfigurations:
|
||||
def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
|
||||
"""
|
||||
Get model provider configurations.
|
||||
|
||||
@ -155,7 +155,7 @@ class ProviderManager:
|
||||
|
||||
default_model = {}
|
||||
# Get provider configurations
|
||||
provider_configurations = self.get_configurations()
|
||||
provider_configurations = self.get_configurations(provider="openai")
|
||||
|
||||
# get available models from provider_configurations
|
||||
available_models = provider_configurations.get_models(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user