Merge pull request #3466 from glide-the/dev_model_providers

model_providers bootstrap
This commit is contained in:
glide-the 2024-03-22 01:00:10 +08:00 committed by GitHub
commit 42dc6d18c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 592 additions and 2478 deletions

View File

@ -1,387 +1,18 @@
import os
from typing import cast, Generator
from model_providers.core.entities.application_entities import ModelConfigEntity, AppOrchestrationConfigEntity, \
PromptTemplateEntity, AdvancedChatPromptTemplateEntity, ExternalDataVariableEntity, AgentEntity, AgentToolEntity, \
AgentPromptEntity, DatasetEntity, DatasetRetrieveConfigEntity, FileUploadEntity, TextToSpeechEntity, \
SensitiveWordAvoidanceEntity, AdvancedCompletionPromptTemplateEntity
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, \
QuotaExceededError
from model_providers.core.model_manager import ModelInstance
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage, \
AssistantPromptMessage
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.tools.prompt.template import REACT_PROMPT_TEMPLATES
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
        -> AppOrchestrationConfigEntity:
    """
    Convert app model config dict to entity.

    Looks up the provider/model named under ``app_model_config_dict['model']``,
    verifies that credentials and the model itself are usable, and then maps
    each configuration section (prompt template, external data tools,
    datasets, agent mode, file upload, and the feature toggles) onto the
    keyword arguments of ``AppOrchestrationConfigEntity``.

    :param tenant_id: tenant ID
    :param app_model_config_dict: app model config dict
    :raises ProviderTokenNotInitError: provider token not init error
    :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
    :raises QuotaExceededError: model status is QUOTA_EXCEEDED
    :raises ValueError: the model or its schema cannot be found
    :return: app orchestration config entity
    """
    properties = {}

    copy_app_model_config_dict = app_model_config_dict.copy()

    provider_manager = ProviderManager()
    provider_model_bundle = provider_manager.get_provider_model_bundle(
        tenant_id=tenant_id,
        provider=copy_app_model_config_dict['model']['provider'],
        model_type=ModelType.LLM
    )

    provider_name = provider_model_bundle.configuration.provider.provider
    model_name = copy_app_model_config_dict['model']['name']

    model_type_instance = provider_model_bundle.model_type_instance
    model_type_instance = cast(LargeLanguageModel, model_type_instance)

    # check model credentials
    model_credentials = provider_model_bundle.configuration.get_current_credentials(
        model_type=ModelType.LLM,
        model=copy_app_model_config_dict['model']['name']
    )

    if model_credentials is None:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

    # check model
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=copy_app_model_config_dict['model']['name'],
        model_type=ModelType.LLM
    )

    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    if provider_model.status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    elif provider_model.status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # model config
    # Work on a private copy: the dict copy above is shallow, so removing
    # 'stop' in place would also remove it from the caller's nested dict.
    # `or {}` tolerates a missing or None 'completion_params'.
    completion_params = dict(copy_app_model_config_dict['model'].get('completion_params') or {})
    stop = completion_params.pop('stop', [])

    # get model mode
    model_mode = copy_app_model_config_dict['model'].get('mode')
    if not model_mode:
        mode_enum = model_type_instance.get_model_mode(
            model=copy_app_model_config_dict['model']['name'],
            credentials=model_credentials
        )

        model_mode = mode_enum.value

    model_schema = model_type_instance.get_model_schema(
        copy_app_model_config_dict['model']['name'],
        model_credentials
    )

    if not model_schema:
        raise ValueError(f"Model {model_name} not exist.")

    properties['model_config'] = ModelConfigEntity(
        provider=copy_app_model_config_dict['model']['provider'],
        model=copy_app_model_config_dict['model']['name'],
        model_schema=model_schema,
        mode=model_mode,
        provider_model_bundle=provider_model_bundle,
        credentials=model_credentials,
        parameters=completion_params,
        stop=stop,
    )

    # prompt template
    prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
    if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
        simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
        properties['prompt_template'] = PromptTemplateEntity(
            prompt_type=prompt_type,
            simple_prompt_template=simple_prompt_template
        )
    else:
        advanced_chat_prompt_template = None
        chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
        if chat_prompt_config:
            chat_prompt_messages = []
            for message in chat_prompt_config.get("prompt", []):
                chat_prompt_messages.append({
                    "text": message["text"],
                    "role": PromptMessageRole.value_of(message["role"])
                })

            advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                messages=chat_prompt_messages
            )

        advanced_completion_prompt_template = None
        completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
        if completion_prompt_config:
            completion_prompt_template_params = {
                'prompt': completion_prompt_config['prompt']['text'],
            }

            if 'conversation_histories_role' in completion_prompt_config:
                completion_prompt_template_params['role_prefix'] = {
                    'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
                    'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
                }

            advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                **completion_prompt_template_params
            )

        properties['prompt_template'] = PromptTemplateEntity(
            prompt_type=prompt_type,
            advanced_chat_prompt_template=advanced_chat_prompt_template,
            advanced_completion_prompt_template=advanced_completion_prompt_template
        )

    # external data variables
    properties['external_data_variables'] = []

    # old external_data_tools
    external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
    for external_data_tool in external_data_tools:
        if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
            continue

        properties['external_data_variables'].append(
            ExternalDataVariableEntity(
                variable=external_data_tool['variable'],
                type=external_data_tool['type'],
                config=external_data_tool['config']
            )
        )

    # current external_data_tools
    for variable in copy_app_model_config_dict.get('user_input_form', []):
        # each form entry is a single-key dict; the key is the entry type
        typ = list(variable.keys())[0]
        if typ == 'external_data_tool':
            val = variable[typ]
            properties['external_data_variables'].append(
                ExternalDataVariableEntity(
                    variable=val['variable'],
                    type=val['type'],
                    config=val['config']
                )
            )

    # show retrieve source
    show_retrieve_source = False
    retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
    if retriever_resource_dict:
        if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
            show_retrieve_source = True

    properties['show_retrieve_source'] = show_retrieve_source

    # dataset ids are collected from dataset_configs here and from old-style
    # agent tools further down
    dataset_ids = []
    if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
        datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
            'strategy': 'router',
            'datasets': []
        })

        for dataset in datasets.get('datasets', []):
            keys = list(dataset.keys())
            if len(keys) == 0 or keys[0] != 'dataset':
                continue
            dataset = dataset['dataset']

            if 'enabled' not in dataset or not dataset['enabled']:
                continue

            dataset_id = dataset.get('id', None)
            if dataset_id:
                dataset_ids.append(dataset_id)
    else:
        datasets = {'strategy': 'router', 'datasets': []}

    if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
            and 'enabled' in copy_app_model_config_dict['agent_mode'] \
            and copy_app_model_config_dict['agent_mode']['enabled']:

        agent_dict = copy_app_model_config_dict.get('agent_mode', {})
        agent_strategy = agent_dict.get('strategy', 'cot')

        if agent_strategy == 'function_call':
            strategy = AgentEntity.Strategy.FUNCTION_CALLING
        elif agent_strategy == 'cot' or agent_strategy == 'react':
            strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        else:
            # old configs, try to detect default strategy
            if copy_app_model_config_dict['model']['provider'] == 'openai':
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            else:
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

        agent_tools = []
        for tool in agent_dict.get('tools', []):
            keys = tool.keys()
            if len(keys) >= 4:
                if "enabled" not in tool or not tool["enabled"]:
                    continue

                agent_tool_properties = {
                    'provider_type': tool['provider_type'],
                    'provider_id': tool['provider_id'],
                    'tool_name': tool['tool_name'],
                    'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
                }

                agent_tools.append(AgentToolEntity(**agent_tool_properties))
            elif len(keys) == 1:
                # old standard
                key = list(tool.keys())[0]

                if key != 'dataset':
                    continue

                tool_item = tool[key]

                if "enabled" not in tool_item or not tool_item["enabled"]:
                    continue

                dataset_id = tool_item['id']
                dataset_ids.append(dataset_id)

        if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
                copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
            agent_prompt = agent_dict.get('prompt', None) or {}
            # check model mode
            model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
            if model_mode == 'completion':
                agent_prompt_entity = AgentPromptEntity(
                    first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
                    next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
                )
            else:
                agent_prompt_entity = AgentPromptEntity(
                    first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
                    next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
                )

            properties['agent'] = AgentEntity(
                provider=properties['model_config'].provider,
                model=properties['model_config'].model,
                strategy=strategy,
                prompt=agent_prompt_entity,
                tools=agent_tools,
                max_iteration=agent_dict.get('max_iteration', 5)
            )

    if dataset_ids:
        # dataset configs
        dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
        query_variable = copy_app_model_config_dict.get('dataset_query_variable')

        # A 'dataset_configs' section without 'retrieval_model' falls back to
        # 'single' (the same default used when the section is absent) instead
        # of raising KeyError.
        retrieval_model = dataset_configs.get('retrieval_model', 'single')
        retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(retrieval_model)

        if retrieval_model == 'single':
            properties['dataset'] = DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=retrieve_strategy,
                    single_strategy=datasets.get('strategy', 'router')
                )
            )
        else:
            properties['dataset'] = DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=retrieve_strategy,
                    top_k=dataset_configs.get('top_k'),
                    score_threshold=dataset_configs.get('score_threshold'),
                    reranking_model=dataset_configs.get('reranking_model')
                )
            )

    # file upload
    file_upload_dict = copy_app_model_config_dict.get('file_upload')
    if file_upload_dict:
        if 'image' in file_upload_dict and file_upload_dict['image']:
            if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
                properties['file_upload'] = FileUploadEntity(
                    image_config={
                        'number_limits': file_upload_dict['image']['number_limits'],
                        'detail': file_upload_dict['image']['detail'],
                        'transfer_methods': file_upload_dict['image']['transfer_methods']
                    }
                )

    # opening statement
    properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')

    # suggested questions after answer
    suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
    if suggested_questions_after_answer_dict:
        if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
            properties['suggested_questions_after_answer'] = True

    # more like this
    more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
    if more_like_this_dict:
        if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
            properties['more_like_this'] = True

    # speech to text
    speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
    if speech_to_text_dict:
        if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
            properties['speech_to_text'] = True

    # text to speech
    text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
    if text_to_speech_dict:
        if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
            properties['text_to_speech'] = TextToSpeechEntity(
                enabled=text_to_speech_dict.get('enabled'),
                voice=text_to_speech_dict.get('voice'),
                language=text_to_speech_dict.get('language'),
            )

    # sensitive word avoidance
    sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
    if sensitive_word_avoidance_dict:
        if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
            properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
                type=sensitive_word_avoidance_dict.get('type'),
                config=sensitive_word_avoidance_dict.get('config'),
            )

    return AppOrchestrationConfigEntity(**properties)
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
# provider_manager = ProviderManager()
# provider_model_bundle = provider_manager.get_provider_model_bundle(
# tenant_id="tenant_id",
# provider="copy_app_model_config_dict['model']['provider']",
# model_type=ModelType.LLM
# )
#
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
)
#
# model_instance = ModelInstance(

View File

@ -0,0 +1,236 @@
import asyncio
import os
from typing import Optional, Any, Dict
from fastapi import (APIRouter,
FastAPI,
HTTPException,
Response,
Request,
status
)
import logging
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
import json
import pprint
import tiktoken
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
from uvicorn import Config, Server
from fastapi.middleware.cors import CORSMiddleware
import multiprocessing as mp
import threading
from sse_starlette import EventSourceResponse
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.utils.generic import dictify, jsonify
from model_providers.core.model_runtime.model_providers import model_provider_factory
logger = logging.getLogger(__name__)
async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
    """Invoke the model for a chat request and return whatever ``invoke`` returns.

    The model name, sampling parameters, stop words and stream flag come from
    ``chat_request``; the credentials, prompt message and user id are fixed
    values defined here.

    :raises HTTPException: status 500 carrying ``str(error)`` if anything in
        the invocation (including building its arguments) raises
    """
    try:
        # Everything is built inside the try so that any failure is reported
        # through the same HTTP 500 path.
        invoke_kwargs = dict(
            model=chat_request.model,
            credentials={
                'openai_api_key': "sk-",
                'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            },
            prompt_messages=[UserPromptMessage(content='北京今天的天气怎么样')],
            model_parameters={**chat_request.to_model_parameters_dict()},
            stop=chat_request.stop,
            stream=chat_request.stream,
            user="abc-123",
        )
        return model_type_instance.invoke(**invoke_kwargs)
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
    Bootstrap Server Lifecycle

    Registers ``/v1/models``, ``/v1/embeddings`` and ``/v1/chat/completions``
    on a FastAPI app and runs it with uvicorn on a background thread.
    """

    def __init__(self, host: str, port: int):
        super().__init__()
        self._host = host
        self._port = port
        self._router = APIRouter()
        self._app = FastAPI()
        self._server_thread = None

    @classmethod
    def from_config(cls, cfg=None):
        """Build a server from a mapping with optional ``host`` / ``port`` keys."""
        # cfg defaults to None; treat that as "no overrides" instead of
        # failing on None.get().
        cfg = cfg or {}
        host = cfg.get("host", "127.0.0.1")
        port = cfg.get("port", 20000)

        logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
        return cls(host=host, port=port)

    def serve(self, logging_conf: Optional[dict] = None):
        """Register middleware and routes, then start uvicorn on a new thread."""
        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._router.add_api_route(
            "/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )

        self._router.add_api_route(
            "/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )

        self._app.include_router(self._router)

        config = Config(
            app=self._app, host=self._host, port=self._port, log_config=logging_conf
        )
        server = Server(config)

        self._server_thread = threading.Thread(target=server.run)
        self._server_thread.start()

    async def join(self):
        """Wait until the server thread started by ``serve`` has finished."""
        if self._server_thread is None:
            return
        # Thread.join() is a blocking call that returns None, so it cannot be
        # awaited directly; run it in the default executor instead.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._server_thread.join)

    def set_app_event(self, started_event: mp.Event = None):
        """Set ``started_event`` (if given) when the FastAPI app starts up."""
        @self._app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()

    async def list_models(self, request: Request):
        """Return the model list.

        TODO: not implemented yet; an empty ``ModelList`` is returned so the
        route's ``response_model`` is satisfied rather than returning None.
        """
        return ModelList(data=[])

    @staticmethod
    def _resolve_api_key(request: Request) -> Optional[str]:
        """Return the ``API_KEY`` env var, else the request's bearer token, else None."""
        api_key = os.environ.get("API_KEY")
        if api_key is not None:
            return api_key
        authorization = request.headers.get("Authorization")
        if authorization is None:
            return None
        return authorization.split("Bearer ")[-1]

    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
        """Create embeddings for the request input.

        List inputs are turned into a single text: token ids are decoded with
        the tiktoken encoding for the requested model (``cl100k_base`` when
        the model is unknown) and the pieces are concatenated.
        """
        logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
        authorization = self._resolve_api_key(request)
        # NOTE(review): ZhipuAI is not imported in the visible part of this
        # file; confirm where it is expected to come from.
        client = ZhipuAI(api_key=authorization)

        raw_input = embeddings_request.input
        if isinstance(raw_input, list):
            try:
                encoding = tiktoken.encoding_for_model(embeddings_request.model)
            except KeyError:
                logger.warning("Warning: model not found. Using cl100k_base encoding.")
                encoding = tiktoken.get_encoding("cl100k_base")

            if raw_input and all(isinstance(item, int) for item in raw_input):
                # a flat list of token ids is one token sequence
                text_input = encoding.decode(raw_input)
            else:
                # list of strings and/or list of token-id lists
                text_input = "".join(
                    item if isinstance(item, str) else encoding.decode(item)
                    for item in raw_input
                )
        else:
            text_input = raw_input

        response = client.embeddings.create(
            model=embeddings_request.model,
            input=text_input,
        )
        return EmbeddingsResponse(**dictify(response))

    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
        """Handle a chat completion request, streaming when ``chat_request.stream`` is set."""
        logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
        # NOTE(review): the key is resolved but not used yet; the credentials
        # passed to the model below are hard-coded.
        authorization = self._resolve_api_key(request)
        model_provider_factory.get_providers(provider_name='openai')
        provider_instance = model_provider_factory.get_provider_instance('openai')
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        if chat_request.stream:
            # create_stream_chat_completion is a coroutine function; await it
            # so the invoke result, not a coroutine object, is streamed.
            generator = await create_stream_chat_completion(model_type_instance, chat_request)
            return EventSourceResponse(generator, media_type="text/event-stream")
        else:
            # NOTE(review): model, prompt and parameters are fixed values
            # here; the request's own model/messages are not used.
            response = model_type_instance.invoke(
                model='gpt-4',
                credentials={
                    'openai_api_key': "sk-",
                    'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                    'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
                },
                prompt_messages=[
                    UserPromptMessage(
                        content='北京今天的天气怎么样'
                    )
                ],
                model_parameters={
                    'temperature': 0.7,
                    'top_p': 1.0,
                    'top_k': 1,
                    'plugin_web_search': True,
                },
                stop=['you'],
                stream=False,
                user="abc-123"
            )

            chat_response = ChatCompletionResponse(**dictify(response))

            return chat_response
def run(
        cfg: Dict, logging_conf: Optional[dict] = None,
        started_event: mp.Event = None,
):
    """Configure logging, start the REST server and block until its thread exits.

    :param cfg: configuration mapping; the ``run_openai_api`` entry (default
        ``{}``) is passed to ``RESTFulOpenAIBootstrapBaseWeb.from_config``
    :param logging_conf: optional ``logging.config.dictConfig`` dictionary,
        also handed to the server as its log config
    :param started_event: optional event set when the app has started
    """
    # `import logging` alone does not load the `logging.config` submodule.
    import logging.config

    # logging_conf is optional; dictConfig cannot take None.
    if logging_conf is not None:
        logging.config.dictConfig(logging_conf)  # type: ignore
    try:
        import signal

        # Ignore keyboard interrupts here; signal handling is left to xoscar.
        signal.signal(signal.SIGINT, lambda *_: None)

        api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
        api.set_app_event(started_event=started_event)
        api.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await api.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise

View File

@ -0,0 +1,8 @@
from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.bootstrap_register import bootstrap_register
__all__ = [
"bootstrap_register",
"Bootstrap",
"OpenAIBootstrapBaseWeb",
]

View File

@ -0,0 +1,54 @@
from abc import abstractmethod
from collections import deque
from fastapi import Request
class Bootstrap:
    """Base class for bootstrap components.

    Carries a version string and a class-level task queue. ``run`` and
    ``destroy`` raise ``NotImplementedError`` here and are meant to be
    provided by subclasses.
    """

    # Maximum number of ongoing tasks.
    _MAX_ONGOING_TASKS: int = 1

    # Task queue. This is a class attribute, so the same deque object is
    # shared by every instance and subclass.
    _QUEUE: deque = deque()

    def __init__(self):
        self._version = "v0.0.1"

    # NOTE: this class does not derive from abc.ABC, so @abstractmethod is
    # not enforced at instantiation time.
    @classmethod
    @abstractmethod
    def from_config(cls, cfg=None):
        """Create an instance from ``cfg``; the base implementation ignores it."""
        return cls()

    @property
    def version(self):
        """Version string set in ``__init__``."""
        return self._version

    @property
    def queue(self) -> deque:
        """The shared class-level task queue."""
        return self._QUEUE

    @classmethod
    async def run(cls):
        raise NotImplementedError

    @classmethod
    async def destroy(cls):
        raise NotImplementedError
class OpenAIBootstrapBaseWeb(Bootstrap):
    """Interface of a bootstrap that exposes OpenAI-style web endpoints.

    The request model types are written as string annotations because this
    module does not import ``EmbeddingsRequest`` / ``ChatCompletionRequest``;
    unquoted, they would raise ``NameError`` while the class is being created.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    async def list_models(self, request: Request):
        """Handle a model-list request."""
        pass

    @abstractmethod
    async def create_embeddings(self, request: Request, embeddings_request: "EmbeddingsRequest"):
        """Handle an embeddings request."""
        pass

    @abstractmethod
    async def create_chat_completion(self, request: Request, chat_request: "ChatCompletionRequest"):
        """Handle a chat completion request."""
        pass

View File

@ -0,0 +1,51 @@
from model_providers.core.bootstrap import Bootstrap
class BootstrapRegister:
    """
    Registry manager: maps names to ``Bootstrap`` subclasses.
    """
    mapping = {
        "bootstrap": {},
    }

    @classmethod
    def register_bootstrap(cls, name):
        """Return a class decorator that registers a bootstrap class under ``name``.

        Args:
            name: Key with which the class will be registered.

        The decorator asserts the class is a ``Bootstrap`` subclass, raises
        ``KeyError`` if ``name`` is already taken, and returns the class
        unchanged.
        """
        print(f"register_bootstrap {name}")

        def wrap(task_cls):
            assert issubclass(
                task_cls, Bootstrap
            ), "All tasks must inherit bootstrap class"
            registered = cls.mapping["bootstrap"]
            if name in registered:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, registered[name]
                    )
                )
            registered[name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def get_bootstrap_class(cls, name):
        """Return the class registered under ``name``, or None."""
        return cls.mapping["bootstrap"].get(name)

    @classmethod
    def list_bootstrap(cls):
        """Return the registered names in sorted order."""
        return sorted(cls.mapping["bootstrap"])


bootstrap_register = BootstrapRegister()

View File

@ -0,0 +1,143 @@
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator
from typing_extensions import Literal
class Role(str, Enum):
    """Author role of a chat message (see ``ChatMessage.role``)."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"
class Finish(str, Enum):
    """Allowed values of a choice's ``finish_reason``."""
    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"
class ModelCard(BaseModel):
    """One entry of ``ModelList.data``."""
    id: str
    object: Literal["model"] = "model"
    # defaults to the current Unix time, in whole seconds, at creation
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel):
    """List container of ``ModelCard`` entries."""
    object: Literal["list"] = "list"
    data: List[ModelCard] = []
class Function(BaseModel):
    """A function name with its arguments as a string (see ``FunctionCall.function``)."""
    name: str
    arguments: str
class FunctionDefinition(BaseModel):
    """Declaration of a function: name, description and a parameters mapping."""
    name: str
    description: str
    parameters: Dict[str, Any]
class FunctionCallDefinition(BaseModel):
    """Names a function; type of ``ChatCompletionRequest.function_call``."""
    name: str
class FunctionCall(BaseModel):
    """A tool call entry (see ``ChatCompletionMessage.tool_calls``)."""
    id: str
    type: Literal["function"] = "function"
    function: Function
class FunctionAvailable(BaseModel):
    """A tool offered in a request (see ``ChatCompletionRequest.tools``)."""
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None
class ChatMessage(BaseModel):
    """A request message: a role and its text content."""
    role: Role
    content: str
class ChatCompletionMessage(BaseModel):
    """A response message or stream delta; every field is optional."""
    role: Optional[Role] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None
    function_call: Optional[Function] = None
class UsageInfo(BaseModel):
    """Token usage counts; ``completion_tokens`` may be omitted."""
    prompt_tokens: int
    completion_tokens: Optional[int] = None
    total_tokens: int
class ChatCompletionRequest(BaseModel):
    """Body of a chat completion request."""
    model: str
    messages: List[ChatMessage]
    tools: Optional[List[FunctionAvailable]] = None
    functions: Optional[List[FunctionDefinition]] = None
    function_call: Optional[FunctionCallDefinition] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    n: int = 1
    max_tokens: Optional[int] = None
    # No trailing comma here: `= None,` would make the default the tuple (None,).
    stop: Optional[List[str]] = None
    stream: Optional[bool] = False

    def to_model_parameters_dict(self, *args, **kwargs):
        """Return the request as a dict without the tool and message fields.

        Delegates to the parent ``dict`` method, excluding ``tools``,
        ``messages``, ``functions`` and ``function_call``; extra positional
        and keyword arguments are passed through.
        """
        return super().dict(*args, exclude={'tools', 'messages', 'functions', 'function_call'}, **kwargs)
class ChatCompletionResponseChoice(BaseModel):
    """One choice of a non-streaming chat completion response."""
    index: int
    message: ChatCompletionMessage
    finish_reason: Finish
class ChatCompletionStreamResponseChoice(BaseModel):
    """One choice of a streamed chunk; ``finish_reason`` may be absent."""
    index: int
    delta: ChatCompletionMessage
    finish_reason: Optional[Finish] = None
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion response."""
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    # defaults to the current Unix time, in whole seconds, at creation
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
class ChatCompletionStreamResponse(BaseModel):
    """One chunk of a streamed chat completion response."""
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # defaults to the current Unix time, in whole seconds, at creation
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionStreamResponseChoice]
class EmbeddingsRequest(BaseModel):
    """Body of an embeddings request."""
    # a string, a list of strings, a list of ints, or a list of int lists
    input: Union[str, List[List[int]], List[int], List[str]]
    model: str
    encoding_format: Literal["base64", "float"] = "float"
class Embeddings(BaseModel):
    """One embedding entry of ``EmbeddingsResponse.data``."""
    object: Literal["embedding"] = "embedding"
    embedding: Union[List[float], bytes]
    index: int
class EmbeddingsResponse(BaseModel):
    """Response of an embeddings request."""
    object: Literal["list"] = "list"
    data: List[Embeddings]
    model: str
    usage: UsageInfo

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel
from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
# from model_providers.core.helper import encrypter
from model_providers.core.helper import encrypter
from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
@ -22,7 +22,7 @@ from model_providers.core.model_runtime.model_providers import model_provider_fa
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.extensions.ext_database import db
from model_providers.models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
logger = logging.getLogger(__name__)
@ -172,20 +172,20 @@ class ProviderConfiguration(BaseModel):
original_credentials = {}
# encrypt credentials
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# # if send [__HIDDEN__] in secret input, it will be same as original value
# if value == '[__HIDDEN__]' and key in original_credentials:
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate(
self.provider.provider,
credentials
)
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_record, credentials
@ -315,11 +315,11 @@ class ProviderConfiguration(BaseModel):
original_credentials = {}
# decrypt credentials
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# # if send [__HIDDEN__] in secret input, it will be same as original value
# if value == '[__HIDDEN__]' and key in original_credentials:
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
@ -328,9 +328,9 @@ class ProviderConfiguration(BaseModel):
credentials=credentials
)
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_model_record, credentials
@ -481,10 +481,10 @@ class ProviderConfiguration(BaseModel):
)
# Obfuscate provider credentials
# copy_credentials = credentials.copy()
# for key, value in copy_credentials.items():
# if key in credential_secret_variables:
# copy_credentials[key] = encrypter.obfuscated_token(value)
copy_credentials = credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials

View File

@ -4,7 +4,7 @@ from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.models.provider import ProviderQuotaType
from models.provider import ProviderQuotaType
class QuotaUnit(Enum):

View File

@ -1,38 +0,0 @@
from typing import Optional
class LLMError(Exception):
    """Root of the LLM exception hierarchy.

    ``description`` holds an optional human-readable message; it is None at
    class level and replaced per instance by the constructor argument.
    """

    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        # Stored on the instance; the base Exception initializer is not called.
        self.description = description
class LLMBadRequestError(LLMError):
    """Signals that the LLM answered with a bad-request error."""

    description = "Bad Request"
class ProviderTokenNotInitError(Exception):
    """
    Raised when a provider token has not been initialized.

    The first positional argument, if any, replaces the default description.
    """

    description = "Provider Token Not Init"

    def __init__(self, *args, **kwargs):
        if args:
            self.description = args[0]
        else:
            # no message given: copy the class-level default onto the instance
            self.description = self.description
class QuotaExceededError(Exception):
    """
    Raised when a provider's quota has been used up.
    """

    description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
    """
    Raised when the requested model is currently not supported.
    """

    description = "Model Currently Not Support"

View File

@ -1,90 +0,0 @@
import enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.file.upload_file_parser import UploadFileParser
from model_providers.core.model_runtime.entities.message_entities import ImagePromptMessageContent
from model_providers.extensions.ext_database import db
from model_providers.models.model import UploadFile
class FileType(enum.Enum):
    """Kinds of file that can be attached."""

    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``; raise ValueError if none does."""
        found = next((member for member in FileType if member.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
class FileTransferMethod(enum.Enum):
    """Ways a file can be supplied."""

    REMOTE_URL = 'remote_url'
    LOCAL_FILE = 'local_file'
    TOOL_FILE = 'tool_file'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``; raise ValueError if none does."""
        found = next((member for member in FileTransferMethod if member.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
class FileBelongsTo(enum.Enum):
    """Which side of the conversation a message file belongs to."""
    USER = 'user'
    ASSISTANT = 'assistant'

    @staticmethod
    def value_of(value):
        """Return the member whose ``.value`` equals ``value``.

        :param value: raw value to look up
        :raises ValueError: if no member has that value
        """
        found = next((candidate for candidate in FileBelongsTo if candidate.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
class FileObj(BaseModel):
    """Pydantic model describing one file attached to a message.

    Only ``FileType.IMAGE`` files produce data; every accessor returns
    ``None`` for other types.
    """
    id: Optional[str]
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str]
    upload_file_id: Optional[str]
    file_config: dict

    @property
    def data(self) -> Optional[str]:
        """File data as resolved by ``_get_data`` with default arguments."""
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        """File data resolved with ``force_url=True``."""
        return self._get_data(force_url=True)

    @property
    def prompt_message_content(self) -> ImagePromptMessageContent:
        """Build the image prompt content for this file (``None`` for non-image files).

        The detail level is HIGH only when the ``image`` entry of
        ``file_config`` has ``detail == "high"``; otherwise LOW.
        """
        if self.type != FileType.IMAGE:
            return None

        image_config = self.file_config.get('image')
        payload = self.data
        if image_config.get("detail") == "high":
            detail = ImagePromptMessageContent.DETAIL.HIGH
        else:
            detail = ImagePromptMessageContent.DETAIL.LOW
        return ImagePromptMessageContent(data=payload, detail=detail)

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        """Resolve the file's data according to its transfer method.

        :param force_url: forwarded to ``UploadFileParser.get_image_data`` for local files
        :return: the remote URL, the value produced by ``UploadFileParser`` for
            a local upload, or ``None`` for anything else
        """
        if self.type != FileType.IMAGE:
            return None

        if self.transfer_method == FileTransferMethod.REMOTE_URL:
            return self.url

        if self.transfer_method == FileTransferMethod.LOCAL_FILE:
            # The lookup is scoped to both the upload id and the tenant.
            upload_query = db.session.query(UploadFile).filter(
                UploadFile.id == self.upload_file_id,
                UploadFile.tenant_id == self.tenant_id
            )
            return UploadFileParser.get_image_data(
                upload_file=upload_query.first(),
                force_url=force_url
            )

        return None

View File

@ -1,184 +0,0 @@
from typing import Optional, Union
import requests
from model_providers.core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
from model_providers.extensions.ext_database import db
from model_providers.models.account import Account
from model_providers.models.model import AppModelConfig, EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
    """Validate file arguments of a message and convert them into ``FileObj`` instances."""

    # Upper bound, in seconds, for the HEAD probe of a remote image URL.
    # ``requests`` waits indefinitely when no timeout is passed.
    REMOTE_URL_CHECK_TIMEOUT = 10

    def __init__(self, tenant_id: str, app_id: str) -> None:
        """
        :param tenant_id: tenant used to scope upload-file lookups
        :param app_id: application id (stored; not read by the methods of this class)
        """
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
                                         user: Union[Account, EndUser]) -> list[FileObj]:
        """
        validate and transform files arg

        :param files: raw file dicts (``type``, ``transfer_method`` and ``url`` or ``upload_file_id``)
        :param app_model_config: app model config providing ``file_upload_dict``
        :param user: account or end user the uploaded files must belong to
        :return: validated file objects; empty when the image feature is not configured or disabled
        :raises ValueError: on a malformed entry, too many files, a disallowed
            transfer method, an unreachable remote URL or an unknown upload
        """
        file_upload_config = app_model_config.file_upload_dict

        # Structural checks on the raw dicts, before any conversion.
        for file in files:
            if not isinstance(file, dict):
                raise ValueError('Invalid file format, must be dict')
            if not file.get('type'):
                raise ValueError('Missing file type')
            FileType.value_of(file.get('type'))
            if not file.get('transfer_method'):
                raise ValueError('Missing file transfer method')
            FileTransferMethod.value_of(file.get('transfer_method'))
            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
                if not file.get('url'):
                    raise ValueError('Missing file url')
                if not file.get('url').startswith('http'):
                    raise ValueError('Invalid file url')
            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
                raise ValueError('Missing file upload_file_id')

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                image_config = file_upload_config.get('image')

                # Skip when the image feature is not configured at all or is disabled.
                if not image_config or not image_config['enabled']:
                    continue

                # The limit applies to the image group, not to the whole request.
                if len(file_objs) > image_config['number_limits']:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method
                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f'Invalid file type: {file_obj.type}')

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check remote url valid and is image
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # The upload must belong to this tenant and user and carry an image extension.
                        upload_file = (db.session.query(UploadFile)
                                       .filter(
                            UploadFile.id == file_obj.upload_file_id,
                            UploadFile.tenant_id == self.tenant_id,
                            UploadFile.created_by == user.id,
                            UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
                            UploadFile.extension.in_(IMAGE_EXTENSIONS)
                        ).first())

                        if not upload_file:
                            raise ValueError('Invalid upload file')

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
        """
        transform message files

        :param files: stored message files
        :param app_model_config: app model config providing ``file_upload_dict``
        :return: flat list of file objects (assistant-owned files are skipped)
        """
        # NOTE(review): the annotation allows None, but the attribute access
        # below requires a real config object — confirm what callers pass.
        type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(self, files: list[Union[dict, MessageFile]],
                      file_upload_config: dict) -> dict[FileType, list[FileObj]]:
        """
        transform files to file objs

        :param files: raw dicts or stored message files
        :param file_upload_config: upload config attached to every ``FileObj``
        :return: file objects grouped by file type
        """
        type_file_objs: dict[FileType, list[FileObj]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            if isinstance(file, MessageFile):
                # Files produced by the assistant are not converted.
                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
                    continue

            file_obj = self._to_file_obj(file, file_upload_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
        """
        transform file to file obj

        :param file: raw dict or stored message file
        :param file_upload_config: upload config attached to the result
        :return: the corresponding ``FileObj``
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
            # Only the field matching the transfer method is kept.
            return FileObj(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get('type')),
                transfer_method=transfer_method,
                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                file_config=file_upload_config
            )
        else:
            return FileObj(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                upload_file_id=file.upload_file_id or None,
                file_config=file_upload_config
            )

    def _check_image_remote_url(self, url):
        """Probe ``url`` with a HEAD request.

        :param url: remote URL to check
        :return: ``(True, "")`` on HTTP 200, otherwise ``(False, <reason>)``
        """
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            # A timeout keeps a slow or silent host from blocking the request
            # forever; the resulting Timeout is a RequestException and is
            # reported like any other failure below.
            response = requests.head(url, headers=headers, allow_redirects=True,
                                     timeout=self.REMOTE_URL_CHECK_TIMEOUT)
            if response.status_code == 200:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"

View File

@ -1,8 +0,0 @@
# Module-level registry; the 'manager' slot starts empty and is read by
# ToolFileParser.get_tool_file_manager().
tool_file_manager = {'manager': None}


class ToolFileParser:
    """Accessor for the tool file manager stored in ``tool_file_manager``."""

    @staticmethod
    def get_tool_file_manager() -> 'ToolFileManager':
        """Return whatever is currently stored under ``'manager'`` in the registry."""
        registered = tool_file_manager['manager']
        return registered

View File

@ -1,79 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from flask import current_app
from model_providers.extensions.ext_storage import storage
# Accepted image extensions: all lower-case spellings first, then the same
# names upper-cased, in the same order.
IMAGE_EXTENSIONS = [
    convert(ext)
    for convert in (str.lower, str.upper)
    for ext in ('jpg', 'jpeg', 'png', 'webp', 'gif', 'svg')
]
class UploadFileParser:
    """Turn an uploaded image into model input and sign or verify its preview URL."""

    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        """Return the image as a signed URL or as a base64 data URI.

        :param upload_file: upload record exposing ``extension``, ``key``, ``id`` and ``mime_type``
        :param force_url: return a signed URL regardless of the configured format
        :return: ``None`` when there is no upload, the extension is not an
            accepted image extension, or the stored file cannot be found
        """
        if not upload_file:
            return None

        if upload_file.extension not in IMAGE_EXTENSIONS:
            return None

        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
            return cls.get_signed_temp_image_url(upload_file)

        # get image file base64
        try:
            data = storage.load(upload_file.key)
        except FileNotFoundError:
            logging.error(f'File not found: {upload_file.key}')
            return None

        encoded_string = base64.b64encode(data).decode('utf-8')
        return f'data:{upload_file.mime_type};base64,{encoded_string}'

    @classmethod
    def _sign_image_preview(cls, upload_file_id, timestamp: str, nonce: str) -> str:
        """Return the URL-safe base64 HMAC-SHA256 of the preview payload, keyed with ``SECRET_KEY``."""
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        digest = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(digest).decode()

    @classmethod
    def get_signed_temp_image_url(cls, upload_file) -> str:
        """
        get signed url from upload file

        :param upload_file: UploadFile object
        :return: preview URL carrying ``timestamp``, ``nonce`` and ``sign`` query parameters
        """
        base_url = current_app.config.get('FILES_URL')
        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        encoded_sign = cls._sign_image_preview(upload_file.id, timestamp, nonce)

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature

        :param upload_file_id: file id
        :param timestamp: timestamp
        :param nonce: nonce
        :param sign: signature
        :return: ``True`` only if the signature matches and is at most 5 minutes old
        """
        expected_sign = cls._sign_image_preview(upload_file_id, timestamp, nonce)

        if not isinstance(sign, str):
            return False

        # Constant-time comparison so response timing does not reveal how much
        # of a guessed signature is correct. Bytes are compared because
        # compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(sign.encode('utf-8', 'replace'), expected_sign.encode('utf-8')):
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= 300  # expired after 5 minutes

View File

@ -1,51 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from model_providers.extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
    """Kind of credentials cache entry; the value is used as the prefix of the cache key."""
    PROVIDER = "provider"
    MODEL = "provider_model"
class ProviderCredentialsCache:
    """Redis-backed cache entry for one tenant's provider or model credentials."""

    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        """
        :param tenant_id: tenant id, embedded in the cache key
        :param identity_id: id of the cached record, embedded in the cache key
        :param cache_type: selects the key prefix
        """
        prefix = cache_type.value
        self.cache_key = f"{prefix}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return: the decoded JSON payload, or ``None`` when nothing is cached
            or the payload is not valid JSON
        """
        raw = redis_client.get(self.cache_key)
        if not raw:
            return None

        try:
            return json.loads(raw.decode('utf-8'))
        except JSONDecodeError:
            return None

    def set(self, credentials: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials
        :return:
        """
        payload = json.dumps(credentials)
        # 86400 seconds: the entry expires after one day.
        redis_client.setex(self.cache_key, 86400, payload)

    def delete(self) -> None:
        """
        Delete cached model provider credentials.

        :return:
        """
        redis_client.delete(self.cache_key)

View File

@ -1,250 +0,0 @@
from typing import Optional
from flask import Config, Flask
from pydantic import BaseModel
from model_providers.core.entities.provider_entities import QuotaUnit, RestrictModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.models.provider import ProviderQuotaType
class HostingQuota(BaseModel):
    """Base pydantic model for a hosting quota: a quota type plus a list of ``RestrictModel`` entries."""
    quota_type: ProviderQuotaType
    # Empty by default. NOTE(review): presumably the models this quota is limited to — confirm with consumers.
    restrict_models: list[RestrictModel] = []
class TrialHostingQuota(HostingQuota):
    """Hosting quota whose ``quota_type`` defaults to ``ProviderQuotaType.TRIAL`` and which carries a numeric limit."""
    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
    quota_limit: int = 0
    """Quota limit for the hosting provider models. -1 means unlimited."""
class PaidHostingQuota(HostingQuota):
    """Hosting quota whose ``quota_type`` defaults to ``ProviderQuotaType.PAID``."""
    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
class FreeHostingQuota(HostingQuota):
    """Hosting quota whose ``quota_type`` defaults to ``ProviderQuotaType.FREE``."""
    quota_type: ProviderQuotaType = ProviderQuotaType.FREE
class HostingProvider(BaseModel):
    """Pydantic model holding the hosted setup of one provider: enabled flag, credentials, quota unit and quotas."""
    enabled: bool = False
    # None when no hosted credentials are supplied for the provider.
    credentials: Optional[dict] = None
    quota_unit: Optional[QuotaUnit] = None
    quotas: list[HostingQuota] = []
class HostedModerationConfig(BaseModel):
    """Pydantic model for hosted moderation: an enabled flag and a list of provider names."""
    enabled: bool = False
    providers: list[str] = []
class HostingConfiguration:
    """Build the hosted provider and moderation setup from the Flask app config.

    Nothing is populated unless the config's ``EDITION`` is ``'CLOUD'``.
    """

    # Class-level defaults are kept for backward compatibility; every instance
    # receives its own objects in ``__init__`` so that ``init_app`` never
    # mutates state shared between instances.
    provider_map: dict[str, HostingProvider] = {}
    moderation_config: Optional[HostedModerationConfig] = None

    def __init__(self) -> None:
        self.provider_map = {}
        self.moderation_config = None

    def init_app(self, app: Flask) -> None:
        """Populate ``provider_map`` and ``moderation_config`` from ``app.config``.

        :param app: Flask application whose config is read
        """
        config = app.config

        if config.get('EDITION') != 'CLOUD':
            return

        self.provider_map["azure_openai"] = self.init_azure_openai(config)
        self.provider_map["openai"] = self.init_openai(config)
        self.provider_map["anthropic"] = self.init_anthropic(config)
        self.provider_map["minimax"] = self.init_minimax(config)
        self.provider_map["spark"] = self.init_spark(config)
        self.provider_map["zhipuai"] = self.init_zhipuai(config)

        self.moderation_config = self.init_moderation_config(config)

    def init_azure_openai(self, app_config: Config) -> HostingProvider:
        """Build the Azure OpenAI hosting provider (trial quota counted in TIMES).

        :param app_config: Flask config
        :return: enabled provider with a fixed trial model list when
            ``HOSTED_AZURE_OPENAI_ENABLED`` is set, otherwise a disabled provider
        """
        quota_unit = QuotaUnit.TIMES
        if not app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
            return HostingProvider(
                enabled=False,
                quota_unit=quota_unit,
            )

        credentials = {
            "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
            "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
            "base_model_name": "gpt-35-turbo"
        }

        hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))

        # Every entry uses the deployment name as its base model name as well.
        llm_names = [
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        ]
        embedding_names = [
            "text-embedding-ada-002",
            "text-embedding-3-small",
            "text-embedding-3-large",
        ]
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ] + [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.TEXT_EMBEDDING)
            for name in embedding_names
        ]

        trial_quota = TrialHostingQuota(
            quota_limit=hosted_quota_limit,
            restrict_models=restrict_models
        )

        return HostingProvider(
            enabled=True,
            credentials=credentials,
            quota_unit=quota_unit,
            quotas=[trial_quota]
        )

    def init_openai(self, app_config: Config) -> HostingProvider:
        """Build the OpenAI hosting provider (quota counted in CREDITS).

        :param app_config: Flask config
        :return: enabled provider when a trial or paid quota is switched on,
            otherwise a disabled provider
        """
        quota_unit = QuotaUnit.CREDITS
        quotas = []

        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=trial_models
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
            paid_quota = PaidHostingQuota(
                restrict_models=paid_models
            )
            quotas.append(paid_quota)

        if len(quotas) > 0:
            credentials = {
                "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
            }

            # Base URL and organization are only included when configured.
            if app_config.get("HOSTED_OPENAI_API_BASE"):
                credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")

            if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
                credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")

            return HostingProvider(
                enabled=True,
                credentials=credentials,
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_anthropic(self, app_config: Config) -> HostingProvider:
        """Build the Anthropic hosting provider (quota counted in TOKENS).

        :param app_config: Flask config
        :return: enabled provider when a trial or paid quota is switched on,
            otherwise a disabled provider
        """
        quota_unit = QuotaUnit.TOKENS
        quotas = []

        if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
            paid_quota = PaidHostingQuota()
            quotas.append(paid_quota)

        if len(quotas) > 0:
            credentials = {
                "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
            }

            if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
                credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")

            return HostingProvider(
                enabled=True,
                credentials=credentials,
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def _init_free_quota_provider(app_config: Config, enabled_key: str) -> HostingProvider:
        """Build a TOKENS-based provider with a single free quota, gated by ``enabled_key``."""
        quota_unit = QuotaUnit.TOKENS
        if app_config.get(enabled_key):
            quotas = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_minimax(self, app_config: Config) -> HostingProvider:
        """Build the Minimax hosting provider, enabled by ``HOSTED_MINIMAX_ENABLED``."""
        return self._init_free_quota_provider(app_config, "HOSTED_MINIMAX_ENABLED")

    def init_spark(self, app_config: Config) -> HostingProvider:
        """Build the Spark hosting provider, enabled by ``HOSTED_SPARK_ENABLED``."""
        return self._init_free_quota_provider(app_config, "HOSTED_SPARK_ENABLED")

    def init_zhipuai(self, app_config: Config) -> HostingProvider:
        """Build the ZhipuAI hosting provider, enabled by ``HOSTED_ZHIPUAI_ENABLED``."""
        return self._init_free_quota_provider(app_config, "HOSTED_ZHIPUAI_ENABLED")

    def init_moderation_config(self, app_config: Config) -> HostedModerationConfig:
        """Build the hosted moderation config.

        :param app_config: Flask config
        :return: enabled config with the comma-separated
            ``HOSTED_MODERATION_PROVIDERS`` split into a list when both
            settings are present, otherwise a disabled config
        """
        if app_config.get("HOSTED_MODERATION_ENABLED") \
                and app_config.get("HOSTED_MODERATION_PROVIDERS"):
            return HostedModerationConfig(
                enabled=True,
                providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',')
            )

        return HostedModerationConfig(
            enabled=False
        )

    @staticmethod
    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
        """Parse a comma-separated model list from config into LLM ``RestrictModel`` entries.

        :param app_config: Flask config
        :param env_var: config key holding the comma-separated model names
        :return: one entry per non-blank name, whitespace stripped
        """
        models_str = app_config.get(env_var)
        models_list = models_str.split(",") if models_str else []
        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
                model_name.strip()]

View File

@ -1,257 +0,0 @@
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.errors.error import ProviderTokenNotInitError
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from model_providers.core.provider_manager import ProviderManager
class ModelInstance:
    """
    Model instance class

    Binds a model name to the credentials and model-type implementation taken
    from a provider model bundle, and forwards invocations to it.
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        """
        :param provider_model_bundle: provider model bundle
        :param model: model name
        :raises ProviderTokenNotInitError: if the bundle has no credentials for the model
        """
        self._provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self._provider_model_bundle.model_type_instance

    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle

        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return: the current credentials for the model
        :raises ProviderTokenNotInitError: if no credentials are configured
        """
        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model
        )

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")

        return credentials

    def _require_model_type(self, expected_type: type):
        """Return ``model_type_instance`` after checking it is an ``expected_type``.

        :param expected_type: model-type base class the caller needs
        :raises Exception: if the bound instance is of a different model type
        """
        instance = self.model_type_instance
        if not isinstance(instance, expected_type):
            raise Exception(f"Model type instance is not {expected_type.__name__}")
        return instance

    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None,
                   callbacks: Optional[list[Callback]] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        :raises Exception: if the bound model is not a LargeLanguageModel
        """
        llm = self._require_model_type(LargeLanguageModel)
        return llm.invoke(
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks
        )

    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        :raises Exception: if the bound model is not a TextEmbeddingModel
        """
        embedding_model = self._require_model_type(TextEmbeddingModel)
        return embedding_model.invoke(
            model=self.model,
            credentials=self.credentials,
            texts=texts,
            user=user
        )

    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                      top_n: Optional[int] = None,
                      user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        :raises Exception: if the bound model is not a RerankModel
        """
        rerank_model = self._require_model_type(RerankModel)
        return rerank_model.invoke(
            model=self.model,
            credentials=self.credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user
        )

    def invoke_moderation(self, text: str, user: Optional[str] = None) \
            -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        :raises Exception: if the bound model is not a ModerationModel
        """
        moderation_model = self._require_model_type(ModerationModel)
        return moderation_model.invoke(
            model=self.model,
            credentials=self.credentials,
            text=text,
            user=user
        )

    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        :raises Exception: if the bound model is not a Speech2TextModel
        """
        speech2text_model = self._require_model_type(Speech2TextModel)
        return speech2text_model.invoke(
            model=self.model,
            credentials=self.credentials,
            file=file,
            user=user
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
            -> str:
        """
        Invoke text-to-speech model

        :param content_text: text content to be converted to speech
        :param tenant_id: user tenant id
        :param user: unique user id
        :param voice: model timbre
        :param streaming: output is streaming
        :return: whatever the TTS model's ``invoke`` returns
            (NOTE(review): the ``str`` annotation looks inaccurate — confirm against the TTS models)
        :raises Exception: if the bound model is not a TTSModel
        """
        tts_model = self._require_model_type(TTSModel)
        return tts_model.invoke(
            model=self.model,
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
            voice=voice,
            streaming=streaming
        )

    def get_tts_voices(self, language: str) -> list:
        """
        List the voices of the text-to-speech model

        :param language: tts language
        :return: tts model voices
        :raises Exception: if the bound model is not a TTSModel
        """
        tts_model = self._require_model_type(TTSModel)
        return tts_model.get_tts_model_voices(
            model=self.model,
            credentials=self.credentials,
            language=language
        )
class ModelManager:
    """Create ``ModelInstance`` objects through a ``ProviderManager``."""

    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Get model instance

        :param tenant_id: tenant id
        :param provider: provider name; a falsy value selects the tenant's default model
        :param model_type: model type
        :param model: model name
        :return: model instance bound to the provider's bundle
        """
        if provider:
            bundle = self._provider_manager.get_provider_model_bundle(
                tenant_id=tenant_id,
                provider=provider,
                model_type=model_type
            )
            return ModelInstance(bundle, model)

        # No provider given: fall back to the default model of this type.
        return self.get_default_model_instance(tenant_id, model_type)

    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Get default model instance

        :param tenant_id: tenant id
        :param model_type: model type
        :return: model instance for the tenant's default model of that type
        :raises ProviderTokenNotInitError: if no default model is found
        """
        default_entity = self._provider_manager.get_default_model(
            tenant_id=tenant_id,
            model_type=model_type
        )

        if default_entity:
            return self.get_model_instance(
                tenant_id=tenant_id,
                provider=default_entity.provider.provider,
                model_type=model_type,
                model=default_entity.model
            )

        raise ProviderTokenNotInitError(f"Default model not found for {model_type}")

View File

@ -3,8 +3,7 @@ import copy
from functools import reduce
from io import BytesIO
from typing import Optional
from flask import Response, stream_with_context
from fastapi.responses import StreamingResponse
from openai import AzureOpenAI
from pydub import AudioSegment
@ -37,15 +36,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
if not voice or voice not in [d['value'] for d in
self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
return StreamingResponse(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice), media_type='text/event-stream')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
@ -68,7 +68,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
"""
_tts_invoke text2speech model
@ -103,7 +103,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))
@ -160,7 +160,6 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in TTS_BASE_MODELS:

View File

@ -3,10 +3,9 @@ from functools import reduce
from io import BytesIO
from typing import Optional
from flask import Response, stream_with_context
from openai import OpenAI
from pydub import AudioSegment
from fastapi.responses import StreamingResponse
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
@ -37,12 +36,11 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
return StreamingResponse(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice), media_type='text/event-stream')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
@ -65,7 +63,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
"""
_tts_invoke text2speech model
@ -100,7 +98,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -4,9 +4,8 @@ from io import BytesIO
from typing import Optional
import dashscope
from flask import Response, stream_with_context
from pydub import AudioSegment
from fastapi.responses import StreamingResponse
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
@ -37,12 +36,11 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
tenant_id=tenant_id)),
status=200, mimetype=f'audio/{audio_type}')
return StreamingResponse(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice), media_type='text/event-stream')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
@ -101,7 +99,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))

View File

@ -1,761 +0,0 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from sqlalchemy.exc import IntegrityError
from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from model_providers.core.entities.provider_entities import (
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
QuotaConfiguration,
SystemConfiguration,
)
# from model_providers.core.helper import encrypter
from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.extensions import ext_hosting_provider
from model_providers.extensions.ext_database import db
from model_providers.models.provider import (
Provider,
ProviderModel,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
    """Create a manager whose credential-decoding slots start out empty."""
    # Both attributes are only referenced by the (currently commented-out)
    # decryption code in this class, so they stay None for now.
    self.decoding_rsa_key = self.decoding_cipher_rsa = None
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
    """
    Assemble the provider configurations of a workspace.

    For every provider entity returned by the model provider factory a
    ProviderConfiguration is built that carries:

    - the custom configuration (provider-level and model-level credentials),
    - the system (hosted) configuration with its quota configurations,
    - the preferred provider type,
    - the provider type actually in use.

    Preferred type: the workspace's stored preference when one exists;
    otherwise CUSTOM when any custom provider/model credentials exist,
    SYSTEM when the system configuration is enabled, CUSTOM as last resort.

    Type in use: starts as the preferred type. A SYSTEM preference is
    downgraded to CUSTOM when the system configuration is disabled or has no
    valid quota. A CUSTOM preference is upgraded to SYSTEM when there are no
    custom credentials at all while the system configuration is enabled and
    has at least one valid quota.

    :param tenant_id: workspace id
    :return: ProviderConfigurations keyed by provider name
    """
    # Provider records of the workspace, with missing trial records created on the fly
    records_by_provider = self._init_trial_provider_records(
        tenant_id,
        self._get_all_providers(tenant_id)
    )

    # Provider model records, provider entities and stored preferences
    model_records_by_provider = self._get_all_provider_models(tenant_id)
    provider_entities = model_provider_factory.get_providers()
    preferred_records_by_provider = self._get_all_preferred_model_providers(tenant_id)

    provider_configurations = ProviderConfigurations(
        tenant_id=tenant_id
    )

    for provider_entity in provider_entities:
        provider_name = provider_entity.provider
        provider_records = records_by_provider.get(provider_name) or []
        provider_model_records = model_records_by_provider.get(provider_name) or []

        custom_configuration = self._to_custom_configuration(
            tenant_id,
            provider_entity,
            provider_records,
            provider_model_records
        )
        system_configuration = self._to_system_configuration(
            tenant_id,
            provider_entity,
            provider_records
        )

        # Resolve the preferred provider type
        preferred_record = preferred_records_by_provider.get(provider_name)
        if preferred_record:
            preferred_provider_type = ProviderType.value_of(preferred_record.preferred_provider_type)
        elif custom_configuration.provider or custom_configuration.models:
            preferred_provider_type = ProviderType.CUSTOM
        elif system_configuration.enabled:
            preferred_provider_type = ProviderType.SYSTEM
        else:
            preferred_provider_type = ProviderType.CUSTOM

        # Resolve the provider type that is actually usable right now
        using_provider_type = preferred_provider_type
        if preferred_provider_type == ProviderType.SYSTEM:
            if not system_configuration.enabled:
                using_provider_type = ProviderType.CUSTOM

            if not any(quota.is_valid for quota in system_configuration.quota_configurations):
                using_provider_type = ProviderType.CUSTOM
        elif (
            not (custom_configuration.provider or custom_configuration.models)
            and system_configuration.enabled
            and any(quota.is_valid for quota in system_configuration.quota_configurations)
        ):
            using_provider_type = ProviderType.SYSTEM

        provider_configurations[provider_name] = ProviderConfiguration(
            tenant_id=tenant_id,
            provider=provider_entity,
            preferred_provider_type=preferred_provider_type,
            using_provider_type=using_provider_type,
            system_configuration=system_configuration,
            custom_configuration=custom_configuration
        )

    return provider_configurations
def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle:
    """
    Bundle a provider's configuration with its provider and model-type instances.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :return: ProviderModelBundle for the requested provider and model type
    :raises ValueError: if the workspace has no configuration for the provider
    """
    provider_configuration = self.get_configurations(tenant_id).get(provider)
    if not provider_configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    provider_instance = provider_configuration.get_provider_instance()

    return ProviderModelBundle(
        configuration=provider_configuration,
        provider_instance=provider_instance,
        model_type_instance=provider_instance.get_model_instance(model_type)
    )
def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]:
    """
    Return the workspace's default model for a model type.

    When no TenantDefaultModel record exists yet, one is created and
    committed from the active models of that type: "gpt-3.5-turbo-1106" if
    it is available, otherwise the first available model.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: the default model entity, or None when no record exists and no
        active model is available to create one
    """
    default_model = db.session.query(TenantDefaultModel) \
        .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.to_origin_model_type()
        ).first()

    if not default_model:
        available_models = self.get_configurations(tenant_id).get_models(
            model_type=model_type,
            only_active=True
        )

        if available_models:
            # Prefer the well-known model; fall back to the first active one
            chosen_model = next(
                (candidate for candidate in available_models if candidate.model == "gpt-3.5-turbo-1106"),
                available_models[0]
            )

            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.to_origin_model_type(),
                provider_name=chosen_model.provider.provider,
                model_name=chosen_model.model
            )
            db.session.add(default_model)
            db.session.commit()

    if not default_model:
        return None

    provider_schema = model_provider_factory \
        .get_provider_instance(default_model.provider_name) \
        .get_provider_schema()

    return DefaultModelEntity(
        model=default_model.model_name,
        model_type=model_type,
        provider=DefaultModelProviderEntity(
            provider=provider_schema.provider,
            label=provider_schema.label,
            icon_small=provider_schema.icon_small,
            icon_large=provider_schema.icon_large,
            supported_model_types=provider_schema.supported_model_types
        )
    )
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
        -> TenantDefaultModel:
    """
    Create or update the workspace's default model record for a model type.

    :param tenant_id: workspace id
    :param model_type: model type
    :param provider: provider name
    :param model: model name
    :return: the created or updated TenantDefaultModel record
    :raises ValueError: if the provider has no configuration in the workspace,
        or the model is not among the active models of the given type
    """
    provider_configurations = self.get_configurations(tenant_id)
    if provider not in provider_configurations:
        raise ValueError(f"Provider {provider} does not exist.")

    available_models = provider_configurations.get_models(
        model_type=model_type,
        only_active=True
    )

    # NOTE(review): the name is matched against the active models of every
    # provider, not only `provider` - confirm whether that is intended.
    if not any(available_model.model == model for available_model in available_models):
        raise ValueError(f"Model {model} does not exist.")

    # Records are looked up by the origin model type (here and in
    # get_default_model), so a new record must be stored under that same value;
    # storing model_type.value would create a row the lookups never match.
    origin_model_type = model_type.to_origin_model_type()

    default_model = db.session.query(TenantDefaultModel) \
        .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == origin_model_type
        ).first()

    if default_model:
        default_model.provider_name = provider
        default_model.model_name = model
    else:
        default_model = TenantDefaultModel(
            tenant_id=tenant_id,
            model_type=origin_model_type,
            provider_name=provider,
            model_name=model,
        )
        db.session.add(default_model)

    db.session.commit()

    return default_model
def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]:
    """
    Load the workspace's valid provider records, grouped by provider name.

    :param tenant_id: workspace id
    :return: defaultdict mapping provider name to its list of Provider records
    """
    valid_records = db.session.query(Provider) \
        .filter(
            Provider.tenant_id == tenant_id,
            Provider.is_valid == True  # column comparison, builds the SQL filter
        ).all()

    grouped_records = defaultdict(list)
    for record in valid_records:
        grouped_records[record.provider_name].append(record)

    return grouped_records
def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]:
    """
    Load the workspace's valid provider model records, grouped by provider name.

    :param tenant_id: workspace id
    :return: defaultdict mapping provider name to its list of ProviderModel records
    """
    valid_model_records = db.session.query(ProviderModel) \
        .filter(
            ProviderModel.tenant_id == tenant_id,
            ProviderModel.is_valid == True  # column comparison, builds the SQL filter
        ).all()

    grouped_model_records = defaultdict(list)
    for model_record in valid_model_records:
        grouped_model_records[model_record.provider_name].append(model_record)

    return grouped_model_records
def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
    """
    Load the workspace's preferred-provider-type records, keyed by provider name.

    :param tenant_id: workspace id
    :return: dict mapping provider name to its TenantPreferredModelProvider record
    """
    preference_records = db.session.query(TenantPreferredModelProvider) \
        .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id
        ).all()

    # Later records overwrite earlier ones that share a provider name
    preferences_by_provider = {}
    for preference in preference_records:
        preferences_by_provider[preference.provider_name] = preference

    return preferences_by_provider
def _init_trial_provider_records(self, tenant_id: str,
                                 provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
    """
    Create missing trial-quota provider records for enabled hosted providers.

    For every enabled provider in the hosting configuration that declares a
    TRIAL quota but has no system record of that quota type yet, a Provider
    record is inserted and appended to the mapping (which is mutated in
    place and also returned).

    :param tenant_id: workspace id
    :param provider_name_to_provider_records_dict: provider name to provider records dict
    :return: the same mapping, extended with any trial records found or created
    """
    hosting_configuration = ext_hosting_provider.hosting_configuration

    for provider_name, configuration in hosting_configuration.provider_map.items():
        if not configuration.enabled:
            continue

        provider_records = provider_name_to_provider_records_dict.get(provider_name) or []

        # Quota types that already have a system-type record
        existing_quota_types = {
            ProviderQuotaType.value_of(provider_record.quota_type)
            for provider_record in provider_records
            if provider_record.provider_type == ProviderType.SYSTEM.value
        }

        for quota in configuration.quotas:
            if quota.quota_type != ProviderQuotaType.TRIAL:
                continue
            if ProviderQuotaType.TRIAL in existing_quota_types:
                continue

            try:
                provider_record = Provider(
                    tenant_id=tenant_id,
                    provider_name=provider_name,
                    provider_type=ProviderType.SYSTEM.value,
                    quota_type=ProviderQuotaType.TRIAL.value,
                    quota_limit=quota.quota_limit,
                    quota_used=0,
                    is_valid=True
                )
                db.session.add(provider_record)
                db.session.commit()
            except IntegrityError:
                # The unique constraint fired: a record already exists (for
                # example one that was filtered out earlier because it is
                # not valid). Fetch it and re-validate it.
                db.session.rollback()
                provider_record = db.session.query(Provider) \
                    .filter(
                        Provider.tenant_id == tenant_id,
                        Provider.provider_name == provider_name,
                        Provider.provider_type == ProviderType.SYSTEM.value,
                        Provider.quota_type == ProviderQuotaType.TRIAL.value
                    ).first()

                if provider_record and not provider_record.is_valid:
                    provider_record.is_valid = True
                    db.session.commit()

            # first() may return None; never hand a None "record" to callers,
            # which dereference every entry of the list.
            if provider_record is None:
                continue

            # setdefault works for a plain dict as well as a defaultdict
            provider_name_to_provider_records_dict.setdefault(provider_name, []).append(provider_record)

    return provider_name_to_provider_records_dict
def _to_custom_configuration(self,
                             tenant_id: str,
                             provider_entity: ProviderEntity,
                             provider_records: list[Provider],
                             provider_model_records: list[ProviderModel]) -> CustomConfiguration:
    """
    Build the custom (workspace-supplied) configuration of a provider.

    Credentials are read from the credentials cache when present; otherwise
    they are parsed from the record's ``encrypted_config`` and written back
    to the cache.

    :param tenant_id: workspace id
    :param provider_entity: provider entity
    :param provider_records: provider records
    :param provider_model_records: provider model records
    :return: CustomConfiguration with the provider-level configuration (or
        None) and the list of model-level configurations
    """
    # NOTE: the secret-variable names are still collected, but nothing below
    # consumes them - the decryption step that used them is disabled here.
    provider_schema = provider_entity.provider_credential_schema
    provider_credential_secret_variables = self._extract_secret_variables(
        provider_schema.credential_form_schemas if provider_schema else []
    )

    # Pick the custom provider record: non-system and carrying a config.
    # There is no break, so the last matching record wins.
    custom_provider_record = None
    for candidate in provider_records:
        if candidate.provider_type != ProviderType.SYSTEM.value and candidate.encrypted_config:
            custom_provider_record = candidate

    # Provider-level credentials
    custom_provider_configuration = None
    if custom_provider_record:
        credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id,
            identity_id=custom_provider_record.id,
            cache_type=ProviderCredentialsCacheType.PROVIDER
        )

        provider_credentials = credentials_cache.get()
        if not provider_credentials:
            raw_config = custom_provider_record.encrypted_config
            try:
                if raw_config and not raw_config.startswith("{"):
                    # Legacy data: a bare key instead of a JSON object
                    provider_credentials = {
                        "openai_api_key": raw_config
                    }
                else:
                    provider_credentials = json.loads(raw_config)
            except JSONDecodeError:
                provider_credentials = {}

            credentials_cache.set(
                credentials=provider_credentials
            )

        custom_provider_configuration = CustomProviderConfiguration(
            credentials=provider_credentials
        )

    # Collected for the same (disabled) decryption step as above
    model_schema = provider_entity.model_credential_schema
    model_credential_secret_variables = self._extract_secret_variables(
        model_schema.credential_form_schemas if model_schema else []
    )

    # Model-level credentials
    custom_model_configurations = []
    for model_record in provider_model_records:
        if not model_record.encrypted_config:
            continue

        model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id,
            identity_id=model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL
        )

        model_credentials = model_credentials_cache.get()
        if not model_credentials:
            try:
                model_credentials = json.loads(model_record.encrypted_config)
            except JSONDecodeError:
                # Unparseable model config: skip this model entirely
                continue

            model_credentials_cache.set(
                credentials=model_credentials
            )

        custom_model_configurations.append(
            CustomModelConfiguration(
                model=model_record.model_name,
                model_type=ModelType.value_of(model_record.model_type),
                credentials=model_credentials
            )
        )

    return CustomConfiguration(
        provider=custom_provider_configuration,
        models=custom_model_configurations
    )
def _to_system_configuration(self,
                             tenant_id: str,
                             provider_entity: ProviderEntity,
                             provider_records: list[Provider]) -> SystemConfiguration:
    """
    Build the system (hosted) configuration of a provider.

    Returns a disabled configuration when the provider is absent from or
    disabled in the hosting configuration, or when none of its hosted quotas
    yields a quota configuration.

    :param tenant_id: workspace id
    :param provider_entity: provider entity
    :param provider_records: provider records
    :return: SystemConfiguration with the quota type in use, the quota
        configurations and the credentials to use
    """
    provider_map = ext_hosting_provider.hosting_configuration.provider_map
    provider_key = provider_entity.provider

    if provider_key not in provider_map or not provider_map.get(provider_key).enabled:
        return SystemConfiguration(
            enabled=False
        )

    provider_hosting_configuration = provider_map.get(provider_key)

    # System-type records indexed by quota type (a later record replaces an earlier one)
    record_by_quota_type = {
        ProviderQuotaType.value_of(record.quota_type): record
        for record in provider_records
        if record.provider_type == ProviderType.SYSTEM.value
    }

    quota_configurations = []
    for provider_quota in provider_hosting_configuration.quotas:
        if provider_quota.quota_type in record_by_quota_type:
            record = record_by_quota_type[provider_quota.quota_type]
            quota_used = record.quota_used
            quota_limit = record.quota_limit
            # A limit of -1 is treated as always valid
            is_valid = quota_limit > quota_used or quota_limit == -1
        elif provider_quota.quota_type == ProviderQuotaType.FREE:
            # A FREE quota without a record is listed, but as invalid
            quota_used, quota_limit, is_valid = 0, 0, False
        else:
            continue

        quota_configurations.append(
            QuotaConfiguration(
                quota_type=provider_quota.quota_type,
                quota_unit=provider_hosting_configuration.quota_unit,
                quota_used=quota_used,
                quota_limit=quota_limit,
                is_valid=is_valid,
                restrict_models=provider_quota.restrict_models
            )
        )

    if not quota_configurations:
        return SystemConfiguration(
            enabled=False
        )

    current_quota_type = self._choice_current_using_quota_type(quota_configurations)

    # Hosted credentials by default; a FREE quota uses the record's own credentials
    current_using_credentials = provider_hosting_configuration.credentials
    if current_quota_type == ProviderQuotaType.FREE:
        free_record = record_by_quota_type.get(current_quota_type)
        if not free_record:
            current_using_credentials = {}
            quota_configurations = []
        else:
            credentials_cache = ProviderCredentialsCache(
                tenant_id=tenant_id,
                identity_id=free_record.id,
                cache_type=ProviderCredentialsCacheType.PROVIDER
            )

            cached_credentials = credentials_cache.get()
            if cached_credentials:
                current_using_credentials = cached_credentials
            else:
                # NOTE: credential decryption is disabled in this module;
                # the parsed config is used (and cached) as-is.
                try:
                    current_using_credentials = json.loads(free_record.encrypted_config)
                except JSONDecodeError:
                    current_using_credentials = {}

                credentials_cache.set(
                    credentials=current_using_credentials
                )

    return SystemConfiguration(
        enabled=True,
        current_quota_type=current_quota_type,
        quota_configurations=quota_configurations,
        credentials=current_using_credentials
    )
def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType:
    """
    Choose the quota type to use from the given quota configurations.

    Types are tried in priority order PAID > FREE > TRIAL; the first one
    whose configuration is valid wins. When none is valid, the type of the
    lowest-priority configuration that is present is returned.

    :param quota_configurations: quota configurations to choose from
    :return: the chosen quota type
    :raises ValueError: if none of the three quota types is present
    """
    configuration_by_type = {}
    for configuration in quota_configurations:
        configuration_by_type[configuration.quota_type] = configuration

    fallback_configuration = None
    for candidate_type in (ProviderQuotaType.PAID, ProviderQuotaType.FREE, ProviderQuotaType.TRIAL):
        if candidate_type not in configuration_by_type:
            continue

        fallback_configuration = configuration_by_type[candidate_type]
        if fallback_configuration.is_valid:
            return candidate_type

    if fallback_configuration:
        return fallback_configuration.quota_type

    raise ValueError('No quota type available')
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
    """
    Collect the variable names of all secret-input form fields.

    :param credential_form_schemas: credential form schemas to scan
    :return: variable names of the schemas whose type is SECRET_INPUT, in order
    """
    return [
        form_schema.variable
        for form_schema in credential_form_schemas
        if form_schema.type == FormType.SECRET_INPUT
    ]

View File

@ -1,102 +0,0 @@
ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {{query}}
Thought: {{agent_scratchpad}}"""
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
"""
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
# Lookup table of ReAct prompt templates:
# language -> mode ('chat' or 'completion') -> {'prompt', 'agent_scratchpad'}.
REACT_PROMPT_TEMPLATES = {
    'english': {
        'chat': {
            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
        },
        'completion': {
            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
        }
    }
}

View File

@ -0,0 +1,21 @@
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
    """
    Convert a pydantic model to a dict, leaving out fields that were never set.

    Works with both pydantic major versions: ``model_dump`` is used when the
    object provides it (v2), otherwise the v1 ``dict`` method.

    The API is selected by checking for the method instead of catching every
    exception, so a genuine failure inside ``model_dump`` propagates to the
    caller rather than being hidden behind a second attempt with the v1 API.

    :param data: the pydantic model instance to convert
    :return: dict of the explicitly set fields
    """
    model_dump = getattr(data, "model_dump", None)
    if callable(model_dump):  # pydantic v2
        return model_dump(exclude_unset=True)
    return data.dict(exclude_unset=True)  # pydantic v1
def jsonify(data: "BaseModel") -> str:
    """
    Serialize a pydantic model to a JSON string, leaving out unset fields.

    Non-ASCII characters are emitted as-is (``ensure_ascii=False``).

    Works with both pydantic major versions: ``model_dump`` plus
    ``json.dumps`` when the object provides ``model_dump`` (v2), otherwise
    the v1 ``json`` method.

    The API is selected by checking for the method instead of catching every
    exception, so a serialization error on the v2 path is raised to the
    caller rather than being masked by a retry through the v1 API.

    :param data: the pydantic model instance to serialize
    :return: JSON text of the explicitly set fields
    """
    model_dump = getattr(data, "model_dump", None)
    if callable(model_dump):  # pydantic v2
        return json.dumps(model_dump(exclude_unset=True), ensure_ascii=False)
    return data.json(exclude_unset=True, ensure_ascii=False)  # pydantic v1

View File

@ -0,0 +1,12 @@
import orjson
import os
from pydantic import BaseModel
def json_dumps(o):
    """
    Serialize *o* with orjson, converting pydantic models via their ``dict()``.

    :param o: the object to serialize
    :return: the result of ``orjson.dumps``
    """
    def _default(obj):
        # Fallback hook for objects orjson cannot serialize on its own:
        # only pydantic models are supported, anything else is rejected.
        if not isinstance(obj, BaseModel):
            raise TypeError
        return obj.dict()

    return orjson.dumps(o, default=_default)

View File

@ -1,9 +0,0 @@
from flask import Flask
from model_providers.core.hosting_configuration import HostingConfiguration
hosting_configuration = HostingConfiguration()
def init_app(app: Flask):
    """Initialise the module-level ``hosting_configuration`` with the Flask app."""
    hosting_configuration.init_app(app)

View File

@ -1,7 +0,0 @@
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy()
def init_app(app):
    """Bind the module-level SQLAlchemy instance ``db`` to the given app."""
    db.init_app(app)

View File

@ -1,9 +0,0 @@
from flask import Flask
from model_providers.core.hosting_configuration import HostingConfiguration
hosting_configuration = HostingConfiguration()
def init_app(app: Flask):
    """Initialise the module-level ``hosting_configuration`` with the Flask app."""
    hosting_configuration.init_app(app)

View File

@ -6,7 +6,6 @@ from typing import Union
import boto3
from botocore.exceptions import ClientError
from flask import Flask
class Storage:
@ -16,21 +15,21 @@ class Storage:
self.client = None
self.folder = None
def init_app(self, app: Flask):
self.storage_type = app.config.get('STORAGE_TYPE')
def init_config(self, config: dict):
self.storage_type = config.get('STORAGE_TYPE')
if self.storage_type == 's3':
self.bucket_name = app.config.get('S3_BUCKET_NAME')
self.bucket_name = config.get('S3_BUCKET_NAME')
self.client = boto3.client(
's3',
aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
endpoint_url=app.config.get('S3_ENDPOINT'),
region_name=app.config.get('S3_REGION')
aws_secret_access_key=config.get('S3_SECRET_KEY'),
aws_access_key_id=config.get('S3_ACCESS_KEY'),
endpoint_url=config.get('S3_ENDPOINT'),
region_name=config.get('S3_REGION')
)
else:
self.folder = app.config.get('STORAGE_LOCAL_PATH')
self.folder = config.get('STORAGE_LOCAL_PATH')
if not os.path.isabs(self.folder):
self.folder = os.path.join(app.root_path, self.folder)
self.folder = os.path.join(config.get('root_path'), self.folder)
def save(self, filename, data):
if self.storage_type == 's3':
@ -140,5 +139,3 @@ class Storage:
storage = Storage()
def init_app(app: Flask):
storage.init_app(app)

View File

@ -1 +0,0 @@
# -*- coding:utf-8 -*-

View File

@ -1,28 +0,0 @@
from sqlalchemy import Float, text
from sqlalchemy.dialects.postgresql import UUID
from model_providers.extensions.ext_database import db
class UploadFile(db.Model):
    # ORM model for the 'upload_files' table: one row per uploaded file.
    __tablename__ = 'upload_files'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='upload_file_pkey'),
        # Index for lookups by tenant
        db.Index('upload_file_tenant_idx', 'tenant_id')
    )

    # Primary key, generated by the database (uuid_generate_v4())
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    storage_type = db.Column(db.String(255), nullable=False)
    key = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    size = db.Column(db.Integer, nullable=False)
    extension = db.Column(db.String(255), nullable=False)
    mime_type = db.Column(db.String(255), nullable=True)
    # Creator role; the database defaults it to 'account'
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    # Usage tracking: flag (database default false) plus optional user and time
    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    used_by = db.Column(UUID, nullable=True)
    used_at = db.Column(db.DateTime, nullable=True)
    hash = db.Column(db.String(255), nullable=True)

View File

@ -1,160 +0,0 @@
from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from model_providers.extensions.ext_database import db
class ProviderType(Enum):
    """Kind of a provider record: 'custom' or 'system'."""
    CUSTOM = 'custom'
    SYSTEM = 'system'

    @staticmethod
    def value_of(value):
        """
        Return the member whose value equals *value*.

        :raises ValueError: if no member has that value
        """
        matching_members = [member for member in ProviderType if member.value == value]
        if not matching_members:
            raise ValueError(f"No matching enum found for value '{value}'")
        return matching_members[0]
class ProviderQuotaType(Enum):
    """Quota type of a system provider record."""
    # hosted paid quota
    PAID = 'paid'
    # third-party free quota
    FREE = 'free'
    # hosted trial quota
    TRIAL = 'trial'

    @staticmethod
    def value_of(value):
        """
        Return the member whose value equals *value*.

        :raises ValueError: if no member has that value
        """
        matching_members = [member for member in ProviderQuotaType if member.value == value]
        if not matching_members:
            raise ValueError(f"No matching enum found for value '{value}'")
        return matching_members[0]
class Provider(db.Model):
    """
    Provider model representing the API providers and their configurations.
    """
    __tablename__ = 'providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_pkey'),
        db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
        # At most one record per (tenant, provider, provider type, quota type)
        db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
    )

    # Primary key, generated by the database (uuid_generate_v4())
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    # Holds a ProviderType value; the database defaults it to 'custom'
    provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
    # Serialized provider config; None means no token has been set (see token_is_set)
    encrypted_config = db.Column(db.Text, nullable=True)
    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    last_used = db.Column(db.DateTime, nullable=True)
    # Quota bookkeeping; presumably only meaningful for system records - TODO confirm
    quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
    quota_limit = db.Column(db.BigInteger, nullable=True)
    quota_used = db.Column(db.BigInteger, default=0)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    def __repr__(self):
        return f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>"

    @property
    def token_is_set(self):
        """
        Returns True if the encrypted_config is not None, indicating that the token is set.
        """
        return self.encrypted_config is not None

    @property
    def is_enabled(self):
        """
        Returns True if the provider is enabled.

        A system provider is enabled when it is valid; any other provider
        additionally needs its token to be set.
        """
        if self.provider_type == ProviderType.SYSTEM.value:
            return self.is_valid
        else:
            return self.is_valid and self.token_is_set
class ProviderModel(db.Model):
    """
    ORM mapping for the ``provider_models`` table.

    Rows are unique per (tenant_id, provider_name, model_name, model_type),
    as enforced by ``unique_provider_model_name``.
    """
    __tablename__ = 'provider_models'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_model_pkey'),
        db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
        db.UniqueConstraint(
            'tenant_id', 'provider_name', 'model_name', 'model_type',
            name='unique_provider_model_name',
        ),
    )

    # Identity; the id is generated by the database.
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    model_type = db.Column(db.String(40), nullable=False)

    # Configuration and validity; is_valid defaults to false in the database.
    encrypted_config = db.Column(db.Text, nullable=True)
    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

    # Timestamps; both default to the database clock at insert time.
    created_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
    updated_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
class TenantDefaultModel(db.Model):
    """
    ORM mapping for the ``tenant_default_models`` table, indexed by
    (tenant_id, provider_name, model_type).
    """
    __tablename__ = 'tenant_default_models'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'),
        db.Index(
            'tenant_default_model_tenant_id_provider_type_idx',
            'tenant_id', 'provider_name', 'model_type',
        ),
    )

    # Identity; the id is generated by the database.
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    # NOTE(review): limited to 40 characters here, while ProviderModel.model_name
    # allows 255 -- confirm whether the shorter limit is intentional.
    model_name = db.Column(db.String(40), nullable=False)
    model_type = db.Column(db.String(40), nullable=False)

    # Timestamps; both default to the database clock at insert time.
    created_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
    updated_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
class TenantPreferredModelProvider(db.Model):
    """
    ORM mapping for the ``tenant_preferred_model_providers`` table, indexed by
    (tenant_id, provider_name).
    """
    __tablename__ = 'tenant_preferred_model_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'),
        db.Index(
            'tenant_preferred_model_provider_tenant_provider_idx',
            'tenant_id', 'provider_name',
        ),
    )

    # Identity; the id is generated by the database.
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    preferred_provider_type = db.Column(db.String(40), nullable=False)

    # Timestamps; both default to the database clock at insert time.
    created_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
    updated_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
class ProviderOrder(db.Model):
    """
    ORM mapping for the ``provider_orders`` table, indexed by
    (tenant_id, provider_name).
    """
    __tablename__ = 'provider_orders'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_order_pkey'),
        db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
    )

    # Identity; the id is generated by the database.
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    account_id = db.Column(UUID, nullable=False)

    # Payment identifiers and amounts; quantity defaults to 1 in the database.
    payment_product_id = db.Column(db.String(191), nullable=False)
    payment_id = db.Column(db.String(191))
    transaction_id = db.Column(db.String(191))
    quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1'))
    currency = db.Column(db.String(40))
    total_amount = db.Column(db.Integer)

    # Payment status defaults to 'wait_pay' in the database; the three
    # *_at columns below have no default and are optional.
    payment_status = db.Column(
        db.String(40),
        nullable=False,
        server_default=db.text("'wait_pay'::character varying"),
    )
    paid_at = db.Column(db.DateTime)
    pay_failed_at = db.Column(db.DateTime)
    refunded_at = db.Column(db.DateTime)

    # Timestamps; both default to the database clock at insert time.
    created_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )
    updated_at = db.Column(
        db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')
    )

View File

@ -8,8 +8,9 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
transformers = "4.31.0"
Flask-SQLAlchemy = "3.0.5"
SQLAlchemy = "1.4.28"
fastapi = "^0.108"
uvicorn = "0.25.0"
sse-starlette = "^1.8.2"
pyyaml = "6.0.1"
pydantic = "1.10.14"
redis = "4.5.4"
@ -18,7 +19,6 @@ openai = "1.13.3"
tiktoken = "0.5.2"
pydub = "0.25.1"
boto3 = "1.28.17"
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, responses).