From 4873d136b517bd7422d1c3a7d6016666cbd204b2 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Fri, 22 Mar 2024 00:43:09 +0800
Subject: [PATCH 1/2] model_providers bootstrap

---
 model_providers/model_providers/__main__.py   | 381 +--------
 .../bootstrap_web/openai_bootstrap_web.py     | 236 ++++++
 .../model_providers/core/__init__.py          |   0
 .../core/bootstrap/__init__.py                |   8 +
 .../model_providers/core/bootstrap/base.py    |  54 ++
 .../core/bootstrap/bootstrap_register.py      |  51 ++
 .../core/bootstrap/openai_protocol.py         | 143 ++++
 .../core/entities/provider_configuration.py   |  44 +-
 .../core/entities/provider_entities.py        |   2 +-
 .../model_providers/core/errors/__init__.py   |   0
 .../model_providers/core/errors/error.py      |  38 -
 .../model_providers/core/file/__init__.py     |   0
 .../model_providers/core/file/file_obj.py     |  90 ---
 .../core/file/message_file_parser.py          | 184 -----
 .../core/file/tool_file_parser.py             |   8 -
 .../core/file/upload_file_parser.py           |  79 --
 .../model_providers/core/helper/__init__.py   |   0
 .../core/helper/model_provider_cache.py       |  51 --
 .../core/hosting_configuration.py             | 250 ------
 .../model_providers/core/model_manager.py     | 257 ------
 .../model_providers/azure_openai/tts/tts.py   |  23 +-
 .../model_providers/openai/tts/tts.py         |  18 +-
 .../model_providers/tongyi/tts/tts.py         |  16 +-
 .../model_providers/core/provider_manager.py  | 761 ------------------
 .../core/tools/prompt/template.py             | 102 ---
 .../model_providers/core/utils/generic.py     |  21 +
 .../model_providers/core/utils/json_dumps.py  |  12 +
 .../model_providers/ext_hosting_provider.py   |   9 -
 .../extensions/ext_database.py                |   7 -
 .../extensions/ext_hosting_provider.py        |   9 -
 .../model_providers/extensions/ext_storage.py |  21 +-
 .../model_providers/models/__init__.py        |   1 -
 .../model_providers/models/model.py           |  28 -
 .../model_providers/models/provider.py        | 160 ----
 model_providers/pyproject.toml                |   6 +-
 35 files changed, 592 insertions(+), 2478 deletions(-)
 create mode 100644 model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
 delete mode 100644 model_providers/model_providers/core/__init__.py
 create mode 100644 model_providers/model_providers/core/bootstrap/__init__.py
 create mode 100644 model_providers/model_providers/core/bootstrap/base.py
 create mode 100644 model_providers/model_providers/core/bootstrap/bootstrap_register.py
 create mode 100644 model_providers/model_providers/core/bootstrap/openai_protocol.py
 delete mode 100644 model_providers/model_providers/core/errors/__init__.py
 delete mode 100644 model_providers/model_providers/core/errors/error.py
 delete mode 100644 model_providers/model_providers/core/file/__init__.py
 delete mode 100644 model_providers/model_providers/core/file/file_obj.py
 delete mode 100644 model_providers/model_providers/core/file/message_file_parser.py
 delete mode 100644 model_providers/model_providers/core/file/tool_file_parser.py
 delete mode 100644 model_providers/model_providers/core/file/upload_file_parser.py
 delete mode 100644 model_providers/model_providers/core/helper/__init__.py
 delete mode 100644 model_providers/model_providers/core/helper/model_provider_cache.py
 delete mode 100644 model_providers/model_providers/core/hosting_configuration.py
 delete mode 100644 model_providers/model_providers/core/model_manager.py
 delete mode 100644 model_providers/model_providers/core/provider_manager.py
 delete mode 100644 model_providers/model_providers/core/tools/prompt/template.py
 create mode 100644 model_providers/model_providers/core/utils/generic.py
 create mode 100644 
model_providers/model_providers/core/utils/json_dumps.py delete mode 100644 model_providers/model_providers/ext_hosting_provider.py delete mode 100644 model_providers/model_providers/extensions/ext_database.py delete mode 100644 model_providers/model_providers/extensions/ext_hosting_provider.py delete mode 100644 model_providers/model_providers/models/__init__.py delete mode 100644 model_providers/model_providers/models/model.py delete mode 100644 model_providers/model_providers/models/provider.py diff --git a/model_providers/model_providers/__main__.py b/model_providers/model_providers/__main__.py index 6906ead3..c08e76c5 100644 --- a/model_providers/model_providers/__main__.py +++ b/model_providers/model_providers/__main__.py @@ -1,387 +1,18 @@ import os from typing import cast, Generator -from model_providers.core.entities.application_entities import ModelConfigEntity, AppOrchestrationConfigEntity, \ - PromptTemplateEntity, AdvancedChatPromptTemplateEntity, ExternalDataVariableEntity, AgentEntity, AgentToolEntity, \ - AgentPromptEntity, DatasetEntity, DatasetRetrieveConfigEntity, FileUploadEntity, TextToSpeechEntity, \ - SensitiveWordAvoidanceEntity, AdvancedCompletionPromptTemplateEntity -from model_providers.core.entities.model_entities import ModelStatus -from model_providers.core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, \ - QuotaExceededError -from model_providers.core.model_manager import ModelInstance from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta -from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage, \ - AssistantPromptMessage +from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.provider_manager import ProviderManager -from model_providers.core.tools.prompt.template import REACT_PROMPT_TEMPLATES - - -def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ - -> AppOrchestrationConfigEntity: - """ - Convert app model config dict to entity. 
- :param tenant_id: tenant ID - :param app_model_config_dict: app model config dict - :raises ProviderTokenNotInitError: provider token not init error - :return: app orchestration config entity - """ - properties = {} - - copy_app_model_config_dict = app_model_config_dict.copy() - - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=copy_app_model_config_dict['model']['provider'], - model_type=ModelType.LLM - ) - - provider_name = provider_model_bundle.configuration.provider.provider - model_name = copy_app_model_config_dict['model']['name'] - - model_type_instance = provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # check model credentials - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=copy_app_model_config_dict['model']['name'] - ) - - if model_credentials is None: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = copy_app_model_config_dict['model'].get('completion_params') - stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] - - # get model mode - model_mode = copy_app_model_config_dict['model'].get('mode') - if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=copy_app_model_config_dict['model']['name'], - credentials=model_credentials - ) - - model_mode = mode_enum.value - - model_schema = model_type_instance.get_model_schema( - copy_app_model_config_dict['model']['name'], - model_credentials - ) - - if not model_schema: - raise ValueError(f"Model {model_name} not exist.") - - properties['model_config'] = ModelConfigEntity( - provider=copy_app_model_config_dict['model']['provider'], - model=copy_app_model_config_dict['model']['name'], - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - # prompt template - prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type']) - if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: - simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "") - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) - else: - advanced_chat_prompt_template = None - chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {}) - if chat_prompt_config: - chat_prompt_messages = [] - for message in chat_prompt_config.get("prompt", []): - 
chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) - - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) - - advanced_completion_prompt_template = None - completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {}) - if completion_prompt_config: - completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], - } - - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] - } - - advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( - **completion_prompt_template_params - ) - - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template - ) - - # external data variables - properties['external_data_variables'] = [] - - # old external_data_tools - external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) - for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: - continue - - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] - ) - ) - - # current external_data_tools - for variable in copy_app_model_config_dict.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] - ) - ) - - # show retrieve source - show_retrieve_source = False - retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource') - if retriever_resource_dict: - if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: - show_retrieve_source = True - - properties['show_retrieve_source'] = show_retrieve_source - - dataset_ids = [] - if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): - datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) - - - for dataset in datasets.get('datasets', []): - keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': - continue - dataset = dataset['dataset'] - - if 'enabled' not in dataset or not dataset['enabled']: - continue - - dataset_id = dataset.get('id', None) - if dataset_id: - dataset_ids.append(dataset_id) - else: - datasets = {'strategy': 'router', 'datasets': []} - - if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ - and 'enabled' in copy_app_model_config_dict['agent_mode'] \ - and copy_app_model_config_dict['agent_mode']['enabled']: - - agent_dict = copy_app_model_config_dict.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to 
detect default strategy - if copy_app_model_config_dict['model']['provider'] == 'openai': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - - agent_tools = [] - for tool in agent_dict.get('tools', []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: - continue - - agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} - } - - agent_tools.append(AgentToolEntity(**agent_tool_properties)) - elif len(keys) == 1: - # old standard - key = list(tool.keys())[0] - - if key != 'dataset': - continue - - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - continue - - dataset_id = tool_item['id'] - dataset_ids.append(dataset_id) - - if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ - copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} - # check model mode - model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), - ) - else: - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), - ) - - properties['agent'] = AgentEntity( - provider=properties['model_config'].provider, - model=properties['model_config'].model, - strategy=strategy, - prompt=agent_prompt_entity, - tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) - ) - - if len(dataset_ids) > 0: - # dataset configs - dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) - query_variable = copy_app_model_config_dict.get('dataset_query_variable') - - if dataset_configs['retrieval_model'] == 'single': - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - single_strategy=datasets.get('strategy', 'router') - ) - ) - else: - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - top_k=dataset_configs.get('top_k'), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') - ) - ) - - # file upload - file_upload_dict = copy_app_model_config_dict.get('file_upload') - if file_upload_dict: - if 'image' in file_upload_dict and file_upload_dict['image']: - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - properties['file_upload'] = FileUploadEntity( - image_config={ - 'number_limits': file_upload_dict['image']['number_limits'], - 'detail': 
file_upload_dict['image']['detail'],
-                        'transfer_methods': file_upload_dict['image']['transfer_methods']
-                    }
-                )
-
-    # opening statement
-    properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
-
-    # suggested questions after answer
-    suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
-    if suggested_questions_after_answer_dict:
-        if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
-            properties['suggested_questions_after_answer'] = True
-
-    # more like this
-    more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
-    if more_like_this_dict:
-        if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
-            properties['more_like_this'] = True
-
-    # speech to text
-    speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
-    if speech_to_text_dict:
-        if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
-            properties['speech_to_text'] = True
-
-    # text to speech
-    text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
-    if text_to_speech_dict:
-        if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
-            properties['text_to_speech'] = TextToSpeechEntity(
-                enabled=text_to_speech_dict.get('enabled'),
-                voice=text_to_speech_dict.get('voice'),
-                language=text_to_speech_dict.get('language'),
-            )
-
-    # sensitive word avoidance
-    sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
-    if sensitive_word_avoidance_dict:
-        if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
-            properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
-                type=sensitive_word_avoidance_dict.get('type'),
-                config=sensitive_word_avoidance_dict.get('config'),
-            )
-
-    return AppOrchestrationConfigEntity(**properties)
-
 
 if __name__ == '__main__':
     # Model instance created from the configuration manager
     # provider_manager = ProviderManager()
-    # provider_model_bundle = provider_manager.get_provider_model_bundle(
-    #     tenant_id="tenant_id",
-    #     provider="copy_app_model_config_dict['model']['provider']",
-    #     model_type=ModelType.LLM
-    # )
-    #
+
+    # ProviderConfigurations must be imported for this demo block to run;
+    # the tenant id is a placeholder string, mirroring the old commented demo.
+    from model_providers.core.entities.provider_configuration import ProviderConfigurations
+
+    provider_configurations = ProviderConfigurations(
+        tenant_id="tenant_id"
+    )
+    #
     # model_instance = ModelInstance(
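The rest of the patch replaces the deleted ProviderManager/ModelInstance plumbing with direct use of the provider factory. For orientation, here is a minimal sketch of that invocation pattern (editorial, not part of the patch); it uses only calls that appear elsewhere in this change, and the model name and OPENAI_API_KEY environment variable are illustrative placeholders:

    import os

    from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
    from model_providers.core.model_runtime.entities.model_entities import ModelType
    from model_providers.core.model_runtime.model_providers import model_provider_factory

    # Resolve the provider, then obtain the LLM handle for it.
    provider_instance = model_provider_factory.get_provider_instance('openai')
    llm = provider_instance.get_model_instance(ModelType.LLM)

    result = llm.invoke(
        model='gpt-3.5-turbo',  # placeholder model name
        credentials={'openai_api_key': os.environ.get('OPENAI_API_KEY')},
        prompt_messages=[UserPromptMessage(content='Hello')],
        model_parameters={'temperature': 0.7},
        stream=False,
        user='demo-user',
    )
    # LLMResult carries the assistant reply (per the runtime entities).
    print(result.message.content)
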
diff --git a/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
new file mode 100644
index 00000000..57fd0add
--- /dev/null
+++ b/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
@@ -0,0 +1,236 @@
+import asyncio
+import json
+import logging
+import logging.config
+import multiprocessing as mp
+import os
+import pprint
+import threading
+from typing import Optional, Any, Dict
+
+import tiktoken
+from fastapi import (APIRouter,
+                     FastAPI,
+                     HTTPException,
+                     Response,
+                     Request,
+                     status
+                     )
+from fastapi.middleware.cors import CORSMiddleware
+from sse_starlette import EventSourceResponse
+from uvicorn import Config, Server
+from zhipuai import ZhipuAI
+
+from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
+from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
+    ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
+from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
+from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.model_runtime.model_providers import model_provider_factory
+from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from model_providers.core.utils.generic import dictify, jsonify
+
+logger = logging.getLogger(__name__)
+
+
+async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
+    try:
+        response = model_type_instance.invoke(
+            model=chat_request.model,
+            credentials={
+                # Credentials are read from the environment; never commit raw API keys.
+                'openai_api_key': os.environ.get('OPENAI_API_KEY'),
+                'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
+                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+            },
+            prompt_messages=[
+                # Map the incoming messages; roles other than user are not yet handled here.
+                UserPromptMessage(content=message.content)
+                for message in chat_request.messages
+            ],
+            model_parameters={
+                **chat_request.to_model_parameters_dict()
+            },
+            stop=chat_request.stop,
+            stream=chat_request.stream,
+            user="abc-123"
+        )
+        # Forward raw result chunks as SSE data; mapping them onto
+        # ChatCompletionStreamResponse is still to be done.
+        for chunk in response:
+            yield jsonify(dictify(chunk))
+    except Exception as e:
+        logger.exception(e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
+    """
+    Bootstrap Server Lifecycle
+    """
+
+    def __init__(self, host: str, port: int):
+        super().__init__()
+        self._host = host
+        self._port = port
+        self._router = APIRouter()
+        self._app = FastAPI()
+        self._server_thread = None
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        host = cfg.get("host", "127.0.0.1")
+        port = cfg.get("port", 20000)
+
+        logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
+        return cls(host=host, port=port)
+
+    def serve(self, logging_conf: Optional[dict] = None):
+        self._app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+        self._router.add_api_route(
+            "/v1/models",
+            self.list_models,
+            response_model=ModelList,
+            methods=["GET"],
+        )
+
+        self._router.add_api_route(
+            "/v1/embeddings",
+            self.create_embeddings,
+            response_model=EmbeddingsResponse,
+            status_code=status.HTTP_200_OK,
+            methods=["POST"],
+        )
+        self._router.add_api_route(
+            "/v1/chat/completions",
+            self.create_chat_completion,
+            response_model=ChatCompletionResponse,
+            status_code=status.HTTP_200_OK,
+            methods=["POST"],
+        )
+
+        self._app.include_router(self._router)
+
+        config = Config(
+            app=self._app, host=self._host, port=self._port, log_config=logging_conf
+        )
+        server = Server(config)
+
+        def run_server():
+            server.run()
+
+        self._server_thread = threading.Thread(target=run_server)
+        self._server_thread.start()
+
+    async def join(self):
+        # Thread.join() is blocking and not awaitable, so hand it to the
+        # default executor instead of awaiting the Thread object directly.
+        await asyncio.get_running_loop().run_in_executor(None, self._server_thread.join)
+
+    def set_app_event(self, started_event: mp.Event = None):
+        @self._app.on_event("startup")
+        async def on_startup():
+            if started_event is not None:
+                started_event.set()
+
+    async def list_models(self, request: Request):
+        pass
+
+    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
+        logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
+        if os.environ.get("API_KEY") is None:
+            authorization = request.headers.get("Authorization")
+            authorization = authorization.split("Bearer ")[-1]
+        else:
+            authorization = os.environ["API_KEY"]
+        client = ZhipuAI(api_key=authorization)
+        # Check whether embeddings_request.input is a list of token arrays;
+        # if so, decode the tokens back to text before embedding.
+        if isinstance(embeddings_request.input, list):
+            input = ""
+            tokens = embeddings_request.input
+            try:
+                encoding = tiktoken.encoding_for_model(embeddings_request.model)
+            except KeyError:
+                logger.warning("Warning: model not found. Using cl100k_base encoding.")
+                encoding = tiktoken.get_encoding("cl100k_base")
+            for i, token in enumerate(tokens):
+                text = encoding.decode(token)
+                input += text
+        else:
+            input = embeddings_request.input
+
+        response = client.embeddings.create(
+            model=embeddings_request.model,
+            input=input,
+        )
+        return EmbeddingsResponse(**dictify(response))
+
+    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
+        logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
+        if os.environ.get("API_KEY") is None:
+            authorization = request.headers.get("Authorization")
+            authorization = authorization.split("Bearer ")[-1]
+        else:
+            authorization = os.environ["API_KEY"]
+        model_provider_factory.get_providers(provider_name='openai')
+        provider_instance = model_provider_factory.get_provider_instance('openai')
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+        if chat_request.stream:
+            generator = create_stream_chat_completion(model_type_instance, chat_request)
+            return EventSourceResponse(generator, media_type="text/event-stream")
+        else:
+            response = model_type_instance.invoke(
+                model=chat_request.model,
+                credentials={
+                    # Credentials are read from the environment; never commit raw API keys.
+                    'openai_api_key': os.environ.get('OPENAI_API_KEY'),
+                    'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
+                    'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+                },
+                prompt_messages=[
+                    # Map the incoming messages; roles other than user are not yet handled here.
+                    UserPromptMessage(content=message.content)
+                    for message in chat_request.messages
+                ],
+                model_parameters={
+                    **chat_request.to_model_parameters_dict()
+                },
+                stop=chat_request.stop,
+                stream=False,
+                user="abc-123"
+            )
+
+            chat_response = ChatCompletionResponse(**dictify(response))
+
+            return chat_response
+
+
+def run(
+        cfg: Dict, logging_conf: Optional[dict] = None,
+        started_event: mp.Event = None,
+):
+    logging.config.dictConfig(logging_conf)  # type: ignore
+    try:
+        import signal
+        # Skip keyboard interrupts; signal handling is delegated to xoscar.
+        signal.signal(signal.SIGINT, lambda *_: None)
+        api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
+        api.set_app_event(started_event=started_event)
+        api.serve(logging_conf=logging_conf)
+
+        async def pool_join_thread():
+            await api.join()
+
+        asyncio.run(pool_join_thread())
+    except SystemExit:
+        logger.info("SystemExit raised, exiting")
+        raise
diff --git a/model_providers/model_providers/core/__init__.py b/model_providers/model_providers/core/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model_providers/model_providers/core/bootstrap/__init__.py b/model_providers/model_providers/core/bootstrap/__init__.py
new file mode 100644
index 00000000..1d45692b
--- /dev/null
+++ b/model_providers/model_providers/core/bootstrap/__init__.py
@@ -0,0 +1,8 @@
+
+from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb
+from model_providers.core.bootstrap.bootstrap_register import bootstrap_register
+__all__ = [
+    "bootstrap_register",
+    "Bootstrap",
+    "OpenAIBootstrapBaseWeb",
+]
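run() above is the intended entry point for this server. A hypothetical launch snippet (editorial, not part of the patch): the cfg keys mirror from_config()'s defaults, and the logging dict is the minimal form dictConfig accepts.

    import multiprocessing as mp

    from model_providers.bootstrap_web.openai_bootstrap_web import run

    if __name__ == "__main__":
        started = mp.Event()
        # Blocks until the server thread exits; a parent process would
        # typically wait on `started`, which the startup hook sets.
        run(
            cfg={"run_openai_api": {"host": "127.0.0.1", "port": 20000}},
            logging_conf={"version": 1, "disable_existing_loggers": False},
            started_event=started,
        )
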
diff --git a/model_providers/model_providers/core/bootstrap/base.py b/model_providers/model_providers/core/bootstrap/base.py
new file mode 100644
index 00000000..406a27ce
--- /dev/null
+++ b/model_providers/model_providers/core/bootstrap/base.py
@@ -0,0 +1,54 @@
+from abc import abstractmethod
+from collections import deque
+
+from fastapi import Request
+
+from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
+
+
+class Bootstrap:
+
+    """Maximum number of concurrently processed tasks"""
+    _MAX_ONGOING_TASKS: int = 1
+
+    """Task queue"""
+    _QUEUE: deque = deque()
+
+    def __init__(self):
+        self._version = "v0.0.1"
+
+    @classmethod
+    @abstractmethod
+    def from_config(cls, cfg=None):
+        return cls()
+
+    @property
+    def version(self):
+        return self._version
+
+    @property
+    def queue(self) -> deque:
+        return self._QUEUE
+
+    @classmethod
+    async def run(cls):
+        raise NotImplementedError
+
+    @classmethod
+    async def destroy(cls):
+        raise NotImplementedError
+
+
+class OpenAIBootstrapBaseWeb(Bootstrap):
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    async def list_models(self, request: Request):
+        pass
+
+    @abstractmethod
+    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
+        pass
+
+    @abstractmethod
+    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
+        pass
diff --git a/model_providers/model_providers/core/bootstrap/bootstrap_register.py b/model_providers/model_providers/core/bootstrap/bootstrap_register.py
new file mode 100644
index 00000000..ef78184a
--- /dev/null
+++ b/model_providers/model_providers/core/bootstrap/bootstrap_register.py
@@ -0,0 +1,51 @@
+from model_providers.core.bootstrap import Bootstrap
+
+
+class BootstrapRegister:
+    """
+    Registry manager for bootstrap implementations
+    """
+    mapping = {
+        "bootstrap": {},
+    }
+
+    @classmethod
+    def register_bootstrap(cls, name):
+        r"""Register a system bootstrap in the registry under the key 'name'
+
+        Args:
+            name: Key with which the bootstrap will be registered.
+
+        Usage:
+
+            from model_providers.core.bootstrap import bootstrap_register
+        """
+
+        print(f"register_bootstrap {name}")
+
+        def wrap(task_cls):
+            assert issubclass(
+                task_cls, Bootstrap
+            ), "All bootstraps must inherit the Bootstrap class"
+            if name in cls.mapping["bootstrap"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["bootstrap"][name]
+                    )
+                )
+            cls.mapping["bootstrap"][name] = task_cls
+            return task_cls
+
+        return wrap
+
+    @classmethod
+    def get_bootstrap_class(cls, name):
+        return cls.mapping["bootstrap"].get(name, None)
+
+    @classmethod
+    def list_bootstrap(cls):
+        return sorted(cls.mapping["bootstrap"].keys())
+
+
+bootstrap_register = BootstrapRegister()
diff --git a/model_providers/model_providers/core/bootstrap/openai_protocol.py b/model_providers/model_providers/core/bootstrap/openai_protocol.py
new file mode 100644
index 00000000..690475fd
--- /dev/null
+++ b/model_providers/model_providers/core/bootstrap/openai_protocol.py
@@ -0,0 +1,143 @@
+import time
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field, root_validator
+from typing_extensions import Literal
+
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+    TOOL = "tool_calls"
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
+
+
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []
+
+
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+
+class FunctionCallDefinition(BaseModel):
+    name: str
+
+
+class FunctionCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: Function
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: Optional[FunctionDefinition] = None
+
+
+class ChatMessage(BaseModel):
+    role: Role
+    content: str
+
+
+class ChatCompletionMessage(BaseModel):
+    role: Optional[Role] = None
+    content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
+    function_call: Optional[Function] = None
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int
+    completion_tokens: Optional[int] = None
+    total_tokens: int
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    tools: Optional[List[FunctionAvailable]] = None
+    functions: Optional[List[FunctionDefinition]] = None
+    function_call: Optional[FunctionCallDefinition] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[float] = None
+    n: int = 1
+    max_tokens: Optional[int] = None
+    stop: Optional[List[str]] = None
+    stream: Optional[bool] = False
+
+    def to_model_parameters_dict(self, *args, **kwargs):
+        # Dump only the sampling parameters; request-level fields are passed
+        # to invoke() separately and must not leak into model_parameters.
+        return self.dict(
+            *args,
+            exclude={"model", "messages", "tools", "functions", "function_call", "stop", "stream", "n"},
+            exclude_none=True,
+            **kwargs
+        )
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatCompletionMessage
+    finish_reason: Finish
+
+
+class ChatCompletionStreamResponseChoice(BaseModel):
+    index: int
+    delta: ChatCompletionMessage
+    finish_reason: Optional[Finish] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionStreamResponseChoice]
+
+
+class EmbeddingsRequest(BaseModel):
+    input: Union[str, List[List[int]], List[int], List[str]]
+    model: str
+    encoding_format: Literal["base64", "float"] = "float"
+
+
+class Embeddings(BaseModel):
+    object: Literal["embedding"] = "embedding"
+    embedding: Union[List[float], bytes]
+    index: int
+
+
+class EmbeddingsResponse(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[Embeddings]
+    model: str
+    usage: UsageInfo
diff --git a/model_providers/model_providers/core/entities/provider_configuration.py b/model_providers/model_providers/core/entities/provider_configuration.py
index b1248d9f..823c7cf9 100644
--- a/model_providers/model_providers/core/entities/provider_configuration.py
+++ b/model_providers/model_providers/core/entities/provider_configuration.py
@@ -9,7 +9,7 @@ from pydantic import BaseModel
 
 from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
 from model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
-# from model_providers.core.helper import encrypter
+from model_providers.core.helper import encrypter
 from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
 from model_providers.core.model_runtime.entities.provider_entities import (
@@ -22,7 +22,7 @@ from model_providers.core.model_runtime.model_providers import model_provider_fa
 from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
 from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from model_providers.extensions.ext_database import db
-from model_providers.models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
+from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
 
 logger = logging.getLogger(__name__)
 
@@ -172,20 +172,20 @@ class ProviderConfiguration(BaseModel):
             original_credentials = {}
 
         # encrypt credentials
-        # for key, value in credentials.items():
-        #     if key in provider_credential_secret_variables:
-        #         # if send [__HIDDEN__] in secret input, it will be same as original value
-        #         if value == '[__HIDDEN__]' and key in original_credentials:
-        #             credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                # if send [__HIDDEN__] in secret input, it will be same as original value
+                if value == '[__HIDDEN__]' and key in original_credentials:
+                    credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
         credentials = model_provider_factory.provider_credentials_validate(
             self.provider.provider,
             credentials
         )
 
-        # for key, value in credentials.items():
-        #     if key in provider_credential_secret_variables:
-        #         credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
 
         return provider_record, credentials
 
@@ -315,11 +315,11 @@ class ProviderConfiguration(BaseModel):
             original_credentials = {}
 
         # decrypt credentials
-        # for key, value in credentials.items():
-        #     if key in provider_credential_secret_variables:
-        #         # if send [__HIDDEN__] in secret input, it will be same as original value
-        #         if value == '[__HIDDEN__]' and key in original_credentials:
-        #             credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                # if send [__HIDDEN__] in secret input, it will be same as original value
+                if value == '[__HIDDEN__]' and key in original_credentials:
+                    credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
         credentials = model_provider_factory.model_credentials_validate(
             provider=self.provider.provider,
@@ -328,9 +328,9 @@
             credentials=credentials
         )
 
-        # for key, value in credentials.items():
-        #     if key in provider_credential_secret_variables:
-        #         credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+        for key, value in credentials.items():
+            if key in provider_credential_secret_variables:
+                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
 
         return provider_model_record, credentials
 
@@ -481,10 +481,10 @@ class ProviderConfiguration(BaseModel):
         )
 
         # Obfuscate provider credentials
-        # copy_credentials = credentials.copy()
-        # for key, value in copy_credentials.items():
-        #     if key in credential_secret_variables:
-        #         copy_credentials[key] = encrypter.obfuscated_token(value)
+        copy_credentials = credentials.copy()
+        for key, value in copy_credentials.items():
+            if key in credential_secret_variables:
+                copy_credentials[key] = encrypter.obfuscated_token(value)
 
         return copy_credentials
 
diff --git a/model_providers/model_providers/core/entities/provider_entities.py 
b/model_providers/model_providers/core/entities/provider_entities.py index 79c96436..08013af1 100644 --- a/model_providers/model_providers/core/entities/provider_entities.py +++ b/model_providers/model_providers/core/entities/provider_entities.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.models.provider import ProviderQuotaType +from models.provider import ProviderQuotaType class QuotaUnit(Enum): diff --git a/model_providers/model_providers/core/errors/__init__.py b/model_providers/model_providers/core/errors/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/model_providers/model_providers/core/errors/error.py b/model_providers/model_providers/core/errors/error.py deleted file mode 100644 index fddfb345..00000000 --- a/model_providers/model_providers/core/errors/error.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - - -class LLMError(Exception): - """Base class for all LLM exceptions.""" - description: Optional[str] = None - - def __init__(self, description: Optional[str] = None) -> None: - self.description = description - - -class LLMBadRequestError(LLMError): - """Raised when the LLM returns bad request.""" - description = "Bad Request" - - -class ProviderTokenNotInitError(Exception): - """ - Custom exception raised when the provider token is not initialized. - """ - description = "Provider Token Not Init" - - def __init__(self, *args, **kwargs): - self.description = args[0] if args else self.description - - -class QuotaExceededError(Exception): - """ - Custom exception raised when the quota for a provider has been exceeded. - """ - description = "Quota Exceeded" - - -class ModelCurrentlyNotSupportError(Exception): - """ - Custom exception raised when the model not support - """ - description = "Model Currently Not Support" diff --git a/model_providers/model_providers/core/file/__init__.py b/model_providers/model_providers/core/file/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/model_providers/model_providers/core/file/file_obj.py b/model_providers/model_providers/core/file/file_obj.py deleted file mode 100644 index ffcd4013..00000000 --- a/model_providers/model_providers/core/file/file_obj.py +++ /dev/null @@ -1,90 +0,0 @@ -import enum -from typing import Optional - -from pydantic import BaseModel - -from model_providers.core.file.upload_file_parser import UploadFileParser -from model_providers.core.model_runtime.entities.message_entities import ImagePromptMessageContent -from model_providers.extensions.ext_database import db -from model_providers.models.model import UploadFile - - -class FileType(enum.Enum): - IMAGE = 'image' - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - -class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - -class 
FileObj(BaseModel): - id: Optional[str] - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - url: Optional[str] - upload_file_id: Optional[str] - file_config: dict - - @property - def data(self) -> Optional[str]: - return self._get_data() - - @property - def preview_url(self) -> Optional[str]: - return self._get_data(force_url=True) - - @property - def prompt_message_content(self) -> ImagePromptMessageContent: - if self.type == FileType.IMAGE: - image_config = self.file_config.get('image') - - return ImagePromptMessageContent( - data=self.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW - ) - - def _get_data(self, force_url: bool = False) -> Optional[str]: - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.upload_file_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url - ) - - return None diff --git a/model_providers/model_providers/core/file/message_file_parser.py b/model_providers/model_providers/core/file/message_file_parser.py deleted file mode 100644 index 97bc1070..00000000 --- a/model_providers/model_providers/core/file/message_file_parser.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Optional, Union - -import requests - -from model_providers.core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType -from model_providers.extensions.ext_database import db -from model_providers.models.account import Account -from model_providers.models.model import AppModelConfig, EndUser, MessageFile, UploadFile -from services.file_service import IMAGE_EXTENSIONS - - -class MessageFileParser: - - def __init__(self, tenant_id: str, app_id: str) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - - def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, - user: Union[Account, EndUser]) -> list[FileObj]: - """ - validate and transform files arg - - :param files: - :param app_model_config: - :param user: - :return: - """ - file_upload_config = app_model_config.file_upload_dict - - for file in files: - if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_config) - - # validate files - new_files = [] - for file_type, file_objs in type_file_objs.items(): - if file_type == FileType.IMAGE: - # parse and validate files - image_config = file_upload_config.get('image') - - # check if image file feature is enabled - if not image_config['enabled']: - continue - - # Validate number of 
files - if len(files) > image_config['number_limits']: - raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") - - for file_obj in file_objs: - # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') - - # Validate file type - if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') - - if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: - # check remote url valid and is image - result, error = self._check_image_remote_url(file_obj.url) - if result is False: - raise ValueError(error) - elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: - # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.upload_file_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) - - # check upload file is belong to tenant and user - if not upload_file: - raise ValueError('Invalid upload file') - - new_files.append(file_obj) - - # return all file objs - return new_files - - def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: - """ - transform message files - - :param files: - :param app_model_config: - :return: - """ - # transform files to file objs - type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict) - - # return all file objs - return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_upload_config: dict) -> dict[FileType, list[FileObj]]: - """ - transform files to file objs - - :param files: - :param file_upload_config: - :return: - """ - type_file_objs: dict[FileType, list[FileObj]] = { - # Currently only support image - FileType.IMAGE: [] - } - - if not files: - return type_file_objs - - # group by file type and convert file args or message files to FileObj - for file in files: - if isinstance(file, MessageFile): - if file.belongs_to == FileBelongsTo.ASSISTANT.value: - continue - - file_obj = self._to_file_obj(file, file_upload_config) - if file_obj.type not in type_file_objs: - continue - - type_file_objs[file_obj.type].append(file_obj) - - return type_file_objs - - def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj: - """ - transform file to file obj - - :param file: - :return: - """ - if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) - return FileObj( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), - transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - file_config=file_upload_config - ) - else: - return FileObj( - id=file.id, - tenant_id=self.tenant_id, - type=FileType.value_of(file.type), - transfer_method=FileTransferMethod.value_of(file.transfer_method), - url=file.url, - upload_file_id=file.upload_file_id or None, - file_config=file_upload_config - ) - - def _check_image_remote_url(self, url): - try: - headers = { - "User-Agent": "Mozilla/5.0 
(Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - - response = requests.head(url, headers=headers, allow_redirects=True) - if response.status_code == 200: - return True, "" - else: - return False, "URL does not exist." - except requests.RequestException as e: - return False, f"Error checking URL: {e}" diff --git a/model_providers/model_providers/core/file/tool_file_parser.py b/model_providers/model_providers/core/file/tool_file_parser.py deleted file mode 100644 index ea8605ac..00000000 --- a/model_providers/model_providers/core/file/tool_file_parser.py +++ /dev/null @@ -1,8 +0,0 @@ -tool_file_manager = { - 'manager': None -} - -class ToolFileParser: - @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file diff --git a/model_providers/model_providers/core/file/upload_file_parser.py b/model_providers/model_providers/core/file/upload_file_parser.py deleted file mode 100644 index de261ded..00000000 --- a/model_providers/model_providers/core/file/upload_file_parser.py +++ /dev/null @@ -1,79 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from typing import Optional - -from flask import current_app - -from model_providers.extensions.ext_storage import storage - -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url: - return cls.get_signed_temp_image_url(upload_file) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') - return None - - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' - - @classmethod - def get_signed_temp_image_url(cls, upload_file) -> str: - """ - get signed url from upload file - - :param upload_file: UploadFile object - :return: - """ - base_url = current_app.config.get('FILES_URL') - image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview' - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = 
int(time.time()) - return current_time - int(timestamp) <= 300 # expired after 5 minutes diff --git a/model_providers/model_providers/core/helper/__init__.py b/model_providers/model_providers/core/helper/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/model_providers/model_providers/core/helper/model_provider_cache.py b/model_providers/model_providers/core/helper/model_provider_cache.py deleted file mode 100644 index 8cd4a055..00000000 --- a/model_providers/model_providers/core/helper/model_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from model_providers.extensions.ext_redis import redis_client - - -class ProviderCredentialsCacheType(Enum): - PROVIDER = "provider" - MODEL = "provider_model" - - -class ProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return cached_provider_credentials - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/model_providers/model_providers/core/hosting_configuration.py b/model_providers/model_providers/core/hosting_configuration.py deleted file mode 100644 index bdc1941e..00000000 --- a/model_providers/model_providers/core/hosting_configuration.py +++ /dev/null @@ -1,250 +0,0 @@ -from typing import Optional - -from flask import Config, Flask -from pydantic import BaseModel - -from model_providers.core.entities.provider_entities import QuotaUnit, RestrictModel -from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.models.provider import ProviderQuotaType - - -class HostingQuota(BaseModel): - quota_type: ProviderQuotaType - restrict_models: list[RestrictModel] = [] - - -class TrialHostingQuota(HostingQuota): - quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL - quota_limit: int = 0 - """Quota limit for the hosting provider models. 
-1 means unlimited.""" - - -class PaidHostingQuota(HostingQuota): - quota_type: ProviderQuotaType = ProviderQuotaType.PAID - - -class FreeHostingQuota(HostingQuota): - quota_type: ProviderQuotaType = ProviderQuotaType.FREE - - -class HostingProvider(BaseModel): - enabled: bool = False - credentials: Optional[dict] = None - quota_unit: Optional[QuotaUnit] = None - quotas: list[HostingQuota] = [] - - -class HostedModerationConfig(BaseModel): - enabled: bool = False - providers: list[str] = [] - - -class HostingConfiguration: - provider_map: dict[str, HostingProvider] = {} - moderation_config: HostedModerationConfig = None - - def init_app(self, app: Flask) -> None: - config = app.config - - if config.get('EDITION') != 'CLOUD': - return - - self.provider_map["azure_openai"] = self.init_azure_openai(config) - self.provider_map["openai"] = self.init_openai(config) - self.provider_map["anthropic"] = self.init_anthropic(config) - self.provider_map["minimax"] = self.init_minimax(config) - self.provider_map["spark"] = self.init_spark(config) - self.provider_map["zhipuai"] = self.init_zhipuai(config) - - self.moderation_config = self.init_moderation_config(config) - - def init_azure_openai(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.TIMES - if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): - credentials = { - "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), - "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" - } - - quotas = [] - hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=[ - RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] - ) - quotas.append(trial_quota) - - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_openai(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.CREDITS - quotas = [] - - if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) - trial_models = 
self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) - quotas.append(trial_quota) - - if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): - paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) - quotas.append(paid_quota) - - if len(quotas) > 0: - credentials = { - "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"), - } - - if app_config.get("HOSTED_OPENAI_API_BASE"): - credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE") - - if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): - credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") - - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_anthropic(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.TOKENS - quotas = [] - - if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) - quotas.append(trial_quota) - - if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): - paid_quota = PaidHostingQuota() - quotas.append(paid_quota) - - if len(quotas) > 0: - credentials = { - "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"), - } - - if app_config.get("HOSTED_ANTHROPIC_API_BASE"): - credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") - - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_minimax(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_MINIMAX_ENABLED"): - quotas = [FreeHostingQuota()] - - return HostingProvider( - enabled=True, - credentials=None, # use credentials from the provider - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_spark(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_SPARK_ENABLED"): - quotas = [FreeHostingQuota()] - - return HostingProvider( - enabled=True, - credentials=None, # use credentials from the provider - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_zhipuai(self, app_config: Config) -> HostingProvider: - quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_ZHIPUAI_ENABLED"): - quotas = [FreeHostingQuota()] - - return HostingProvider( - enabled=True, - credentials=None, # use credentials from the provider - quota_unit=quota_unit, - quotas=quotas - ) - - return HostingProvider( - enabled=False, - quota_unit=quota_unit, - ) - - def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): - return HostedModerationConfig( - enabled=True, - providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') - ) - - return HostedModerationConfig( - enabled=False - ) - - @staticmethod - def parse_restrict_models_from_env(app_config: Config, 
env_var: str) -> list[RestrictModel]: - models_str = app_config.get(env_var) - models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - diff --git a/model_providers/model_providers/core/model_manager.py b/model_providers/model_providers/core/model_manager.py deleted file mode 100644 index f1b78293..00000000 --- a/model_providers/model_providers/core/model_manager.py +++ /dev/null @@ -1,257 +0,0 @@ -from collections.abc import Generator -from typing import IO, Optional, Union, cast - -from model_providers.core.entities.provider_configuration import ProviderModelBundle -from model_providers.core.errors.error import ProviderTokenNotInitError -from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.entities.rerank_entities import RerankResult -from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel -from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel -from model_providers.core.provider_manager import ProviderManager - - -class ModelInstance: - """ - Model instance class - """ - - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: - self._provider_model_bundle = provider_model_bundle - self.model = model - self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) - self.model_type_instance = self._provider_model_bundle.model_type_instance - - def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: - """ - Fetch credentials from provider model bundle - :param provider_model_bundle: provider model bundle - :param model: model name - :return: - """ - credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=provider_model_bundle.model_type_instance.model_type, - model=model - ) - - if credentials is None: - raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") - - return credentials - - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ - -> Union[LLMResult, Generator]: - """ - Invoke large language model - - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream 
response - :param user: unique user id - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - if not isinstance(self.model_type_instance, LargeLanguageModel): - raise Exception("Model type instance is not LargeLanguageModel") - - self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks - ) - - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - """ - Invoke large language model - - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ - if not isinstance(self.model_type_instance, TextEmbeddingModel): - raise Exception("Model type instance is not TextEmbeddingModel") - - self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - texts=texts, - user=user - ) - - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: - """ - Invoke rerank model - - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - if not isinstance(self.model_type_instance, RerankModel): - raise Exception("Model type instance is not RerankModel") - - self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user - ) - - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: - """ - Invoke moderation model - - :param text: text to moderate - :param user: unique user id - :return: false if text is safe, true otherwise - """ - if not isinstance(self.model_type_instance, ModerationModel): - raise Exception("Model type instance is not ModerationModel") - - self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - text=text, - user=user - ) - - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: - """ - Invoke large language model - - :param file: audio file - :param user: unique user id - :return: text for given audio file - """ - if not isinstance(self.model_type_instance, Speech2TextModel): - raise Exception("Model type instance is not Speech2TextModel") - - self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - file=file, - user=user - ) - - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ - -> str: - """ - Invoke large language tts model - - :param content_text: text content to be translated - :param tenant_id: user tenant id - :param user: unique user id - :param voice: model timbre - :param streaming: output is streaming - :return: text for given audio file - """ - if not isinstance(self.model_type_instance, 
TTSModel): - raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - content_text=content_text, - user=user, - tenant_id=tenant_id, - voice=voice, - streaming=streaming - ) - - def get_tts_voices(self, language: str) -> list: - """ - Invoke large language tts model voices - - :param language: tts language - :return: tts model voices - """ - if not isinstance(self.model_type_instance, TTSModel): - raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language - ) - - -class ModelManager: - def __init__(self) -> None: - self._provider_manager = ProviderManager() - - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: - """ - Get model instance - :param tenant_id: tenant id - :param provider: provider name - :param model_type: model type - :param model: model name - :return: - """ - if not provider: - return self.get_default_model_instance(tenant_id, model_type) - provider_model_bundle = self._provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=provider, - model_type=model_type - ) - - return ModelInstance(provider_model_bundle, model) - - def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: - """ - Get default model instance - :param tenant_id: tenant id - :param model_type: model type - :return: - """ - default_model_entity = self._provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type - ) - - if not default_model_entity: - raise ProviderTokenNotInitError(f"Default model not found for {model_type}") - - return self.get_model_instance( - tenant_id=tenant_id, - provider=default_model_entity.provider.provider, - model_type=model_type, - model=default_model_entity.model - ) diff --git a/model_providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py b/model_providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py index e408e07e..4475b16e 100644 --- a/model_providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/model_providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -3,8 +3,7 @@ import copy from functools import reduce from io import BytesIO from typing import Optional - -from flask import Response, stream_with_context +from fastapi.responses import StreamingResponse from openai import AzureOpenAI from pydub import AudioSegment @@ -37,15 +36,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [d['value'] for d in + self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice)), - status=200, mimetype=f'audio/{audio_type}') + return 
StreamingResponse(self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice), media_type='text/event-stream') + else: return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) @@ -68,7 +68,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse: """ _tts_invoke text2speech model @@ -103,7 +103,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + return StreamingResponse(buffer, media_type=f"audio/{audio_type}") except Exception as ex: raise InvokeBadRequestError(str(ex)) @@ -160,7 +160,6 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in TTS_BASE_MODELS: diff --git a/model_providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py b/model_providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py index b21f7e29..c44b6ca7 100644 --- a/model_providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py +++ b/model_providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py @@ -3,10 +3,9 @@ from functools import reduce from io import BytesIO from typing import Optional -from flask import Response, stream_with_context from openai import OpenAI from pydub import AudioSegment - +from fastapi.responses import StreamingResponse from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel @@ -37,12 +36,11 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice)), - status=200, mimetype=f'audio/{audio_type}') + return StreamingResponse(self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice), media_type='text/event-stream') else: return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) @@ -65,7 +63,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse: """ _tts_invoke text2speech model @@ -100,7 +98,7 @@ class 
OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + return StreamingResponse(buffer, media_type=f"audio/{audio_type}") except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/model_providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py b/model_providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py index 76b2a0cb..7818ec3f 100644 --- a/model_providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/model_providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -4,9 +4,8 @@ from io import BytesIO from typing import Optional import dashscope -from flask import Response, stream_with_context from pydub import AudioSegment - +from fastapi.responses import StreamingResponse from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel @@ -37,12 +36,11 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - tenant_id=tenant_id)), - status=200, mimetype=f'audio/{audio_type}') + return StreamingResponse(self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice), media_type='text/event-stream') else: return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) @@ -101,7 +99,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + return StreamingResponse(buffer, media_type=f"audio/{audio_type}") except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/model_providers/model_providers/core/provider_manager.py b/model_providers/model_providers/core/provider_manager.py deleted file mode 100644 index 9bb6591b..00000000 --- a/model_providers/model_providers/core/provider_manager.py +++ /dev/null @@ -1,761 +0,0 @@ -import json -from collections import defaultdict -from json import JSONDecodeError -from typing import Optional - -from sqlalchemy.exc import IntegrityError - -from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity -from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle -from model_providers.core.entities.provider_entities import ( - CustomConfiguration, - CustomModelConfiguration, - CustomProviderConfiguration, - QuotaConfiguration, - SystemConfiguration, -) -# from model_providers.core.helper import encrypter -from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from model_providers.core.model_runtime.entities.model_entities import ModelType -from 
model_providers.core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) -from model_providers.core.model_runtime.model_providers import model_provider_factory -from model_providers.extensions import ext_hosting_provider -from model_providers.extensions.ext_database import db -from model_providers.models.provider import ( - Provider, - ProviderModel, - ProviderQuotaType, - ProviderType, - TenantDefaultModel, - TenantPreferredModelProvider, -) - - -class ProviderManager: - """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. - """ - def __init__(self) -> None: - self.decoding_rsa_key = None - self.decoding_cipher_rsa = None - - def get_configurations(self, tenant_id: str) -> ProviderConfigurations: - """ - Get model provider configurations. - - Construct ProviderConfiguration objects for each provider - Including: - 1. Basic information of the provider - 2. Hosting configuration information, including: - (1. Whether to enable (support) hosting type, if enabled, the following information exists - (2. List of hosting type provider configurations - (including quota type, quota limit, current remaining quota, etc.) - (3. The current hosting type in use (whether there is a quota or not) - paid quotas > provider free quotas > hosting trial quotas - (4. Unified credentials for hosting providers - 3. Custom configuration information, including: - (1. Whether to enable (support) custom type, if enabled, the following information exists - (2. Custom provider configuration (including credentials) - (3. List of custom provider model configurations (including credentials) - 4. Hosting/custom preferred provider type. - Provide methods: - - Get the current configuration (including credentials) - - Get the availability and status of the hosting configuration: active available, - quota_exceeded insufficient quota, unsupported hosting - - Get the availability of custom configuration - Custom provider available conditions: - (1. custom provider credentials available - (2. 
at least one custom model credentials available - - Verify, update, and delete custom provider configuration - - Verify, update, and delete custom provider model configuration - - Get the list of available models (optional provider filtering, model type filtering) - Append custom provider models to the list - - Get provider instance - - Switch selection priority - - :param tenant_id: - :return: - """ - # Get all provider records of the workspace - provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) - - # Initialize trial provider records if not exist - provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict - ) - - # Get all provider model records of the workspace - provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id) - - # Get all provider entities - provider_entities = model_provider_factory.get_providers() - - # Get All preferred provider types of the workspace - provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id - ) - - # Construct ProviderConfiguration objects for each provider - for provider_entity in provider_entities: - provider_name = provider_entity.provider - - provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider) - if not provider_records: - provider_records = [] - - provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider) - if not provider_model_records: - provider_model_records = [] - - # Convert to custom configuration - custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records - ) - - # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) - - # Get preferred provider type - preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) - - if preferred_provider_type_record: - preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) - else: - if custom_configuration.provider or custom_configuration.models: - preferred_provider_type = ProviderType.CUSTOM - elif system_configuration.enabled: - preferred_provider_type = ProviderType.SYSTEM - else: - preferred_provider_type = ProviderType.CUSTOM - - using_provider_type = preferred_provider_type - if preferred_provider_type == ProviderType.SYSTEM: - if not system_configuration.enabled: - using_provider_type = ProviderType.CUSTOM - - has_valid_quota = False - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.is_valid: - has_valid_quota = True - break - - if not has_valid_quota: - using_provider_type = ProviderType.CUSTOM - else: - if not custom_configuration.provider and not custom_configuration.models: - if system_configuration.enabled: - has_valid_quota = False - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.is_valid: - has_valid_quota = True - break - - if has_valid_quota: - using_provider_type = ProviderType.SYSTEM - - provider_configuration = ProviderConfiguration( - tenant_id=tenant_id, - provider=provider_entity, - preferred_provider_type=preferred_provider_type, - using_provider_type=using_provider_type, - 
system_configuration=system_configuration, - custom_configuration=custom_configuration - ) - - provider_configurations[provider_name] = provider_configuration - - # Return the encapsulated object - return provider_configurations - - def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle: - """ - Get provider model bundle. - :param tenant_id: workspace id - :param provider: provider name - :param model_type: model type - :return: - """ - provider_configurations = self.get_configurations(tenant_id) - - # get provider instance - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - provider_instance = provider_configuration.get_provider_instance() - model_type_instance = provider_instance.get_model_instance(model_type) - - return ProviderModelBundle( - configuration=provider_configuration, - provider_instance=provider_instance, - model_type_instance=model_type_instance - ) - - def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: - """ - Get default model. - - :param tenant_id: workspace id - :param model_type: model type - :return: - """ - # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ - .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() - - # If it does not exist, get the first available provider model from get_configurations - # and update the TenantDefaultModel record - if not default_model: - # Get provider configurations - provider_configurations = self.get_configurations(tenant_id) - - # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) - - if available_models: - found = False - for available_model in available_models: - if available_model.model == "gpt-3.5-turbo-1106": - default_model = TenantDefaultModel( - tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), - provider_name=available_model.provider.provider, - model_name=available_model.model - ) - db.session.add(default_model) - db.session.commit() - found = True - break - - if not found: - available_model = available_models[0] - default_model = TenantDefaultModel( - tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), - provider_name=available_model.provider.provider, - model_name=available_model.model - ) - db.session.add(default_model) - db.session.commit() - - if not default_model: - return None - - provider_instance = model_provider_factory.get_provider_instance(default_model.provider_name) - provider_schema = provider_instance.get_provider_schema() - - return DefaultModelEntity( - model=default_model.model_name, - model_type=model_type, - provider=DefaultModelProviderEntity( - provider=provider_schema.provider, - label=provider_schema.label, - icon_small=provider_schema.icon_small, - icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) - ) - - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: - """ - Update default model record. 
- - :param tenant_id: workspace id - :param model_type: model type - :param provider: provider name - :param model: model name - :return: - """ - provider_configurations = self.get_configurations(tenant_id) - if provider not in provider_configurations: - raise ValueError(f"Provider {provider} does not exist.") - - # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) - - # check if the model is exist in available models - model_names = [model.model for model in available_models] - if model not in model_names: - raise ValueError(f"Model {model} does not exist.") - - # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ - .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() - - # create or update TenantDefaultModel record - if default_model: - # update default model - default_model.provider_name = provider - default_model.model_name = model - db.session.commit() - else: - # create default model - default_model = TenantDefaultModel( - tenant_id=tenant_id, - model_type=model_type.value, - provider_name=provider, - model_name=model, - ) - db.session.add(default_model) - db.session.commit() - - return default_model - - def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() - - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - provider_name_to_provider_records_dict[provider.provider_name].append(provider) - - return provider_name_to_provider_records_dict - - def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: - """ - Get all provider model records of the workspace. - - :param tenant_id: workspace id - :return: - """ - # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() - - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - - return provider_name_to_provider_model_records_dict - - def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: - """ - Get All preferred provider types of the workspace. - - :param tenant_id: - :return: - """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - - return provider_name_to_preferred_provider_type_records_dict - - def _init_trial_provider_records(self, tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: - """ - Initialize trial provider records if not exists. 
- - :param tenant_id: workspace id - :param provider_name_to_provider_records_dict: provider name to provider records dict - :return: - """ - # Get hosting configuration - hosting_configuration = ext_hosting_provider.hosting_configuration - - for provider_name, configuration in hosting_configuration.provider_map.items(): - if not configuration.enabled: - continue - - provider_records = provider_name_to_provider_records_dict.get(provider_name) - if not provider_records: - provider_records = [] - - provider_quota_to_provider_record_dict = dict() - for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: - continue - - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record - - for quota in configuration.quotas: - if quota.quota_type == ProviderQuotaType.TRIAL: - # Init trial provider records if not exists - if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: - try: - provider_record = Provider( - tenant_id=tenant_id, - provider_name=provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, - quota_used=0, - is_valid=True - ) - db.session.add(provider_record) - db.session.commit() - except IntegrityError: - db.session.rollback() - provider_record = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() - - if provider_record and not provider_record.is_valid: - provider_record.is_valid = True - db.session.commit() - - provider_name_to_provider_records_dict[provider_name].append(provider_record) - - return provider_name_to_provider_records_dict - - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: - """ - Convert to custom configuration. 
- - :param tenant_id: workspace id - :param provider_entity: provider entity - :param provider_records: provider records - :param provider_model_records: provider model records - :return: - """ - # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( - provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] - ) - - # Get custom provider record - custom_provider_record = None - for provider_record in provider_records: - if provider_record.provider_type == ProviderType.SYSTEM.value: - continue - - if not provider_record.encrypted_config: - continue - - custom_provider_record = provider_record - - # Get custom provider credentials - custom_provider_configuration = None - if custom_provider_record: - provider_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - # Get cached provider credentials - cached_provider_credentials = provider_credentials_cache.get() - - if not cached_provider_credentials: - try: - # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } - else: - provider_credentials = json.loads(custom_provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} - - # # Get decoding rsa key and cipher for decrypting credentials - # if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - # self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - - # for variable in provider_credential_secret_variables: - # if variable in provider_credentials: - # try: - # provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - # provider_credentials.get(variable), - # self.decoding_rsa_key, - # self.decoding_cipher_rsa - # ) - # except ValueError: - # pass - - # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) - else: - provider_credentials = cached_provider_credentials - - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) - - # Get provider model credential secret variables - model_credential_secret_variables = self._extract_secret_variables( - provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] - ) - - # Get custom provider model credentials - custom_model_configurations = [] - for provider_model_record in provider_model_records: - if not provider_model_record.encrypted_config: - continue - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL - ) - - # Get cached provider model credentials - cached_provider_model_credentials = provider_model_credentials_cache.get() - - if not cached_provider_model_credentials: - try: - provider_model_credentials = json.loads(provider_model_record.encrypted_config) - except JSONDecodeError: - continue - - # # Get decoding rsa key and cipher for decrypting credentials - # if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - # self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - # - # for variable in 
model_credential_secret_variables: - # if variable in provider_model_credentials: - # try: - # provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( - # provider_model_credentials.get(variable), - # self.decoding_rsa_key, - # self.decoding_cipher_rsa - # ) - # except ValueError: - # pass - - # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) - else: - provider_model_credentials = cached_provider_model_credentials - - custom_model_configurations.append( - CustomModelConfiguration( - model=provider_model_record.model_name, - model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials - ) - ) - - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) - - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: - """ - Convert to system configuration. - - :param tenant_id: workspace id - :param provider_entity: provider entity - :param provider_records: provider records - :return: - """ - # Get hosting configuration - hosting_configuration = ext_hosting_provider.hosting_configuration - - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) - - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) - - # Convert provider_records to dict - quota_type_to_provider_records_dict = dict() - for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: - continue - - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record - - quota_configurations = [] - for provider_quota in provider_hosting_configuration.quotas: - if provider_quota.quota_type not in quota_type_to_provider_records_dict: - if provider_quota.quota_type == ProviderQuotaType.FREE: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, - quota_used=0, - quota_limit=0, - is_valid=False, - restrict_models=provider_quota.restrict_models - ) - else: - continue - else: - provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] - - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models - ) - - quota_configurations.append(quota_configuration) - - if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) - - current_quota_type = self._choice_current_using_quota_type(quota_configurations) - - current_using_credentials = provider_hosting_configuration.credentials - if current_quota_type == ProviderQuotaType.FREE: - provider_record = quota_type_to_provider_records_dict.get(current_quota_type) - - if provider_record: - provider_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - # Get cached provider credentials - 
cached_provider_credentials = provider_credentials_cache.get() - - if not cached_provider_credentials: - try: - provider_credentials = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} - - # # Get provider credential secret variables - # provider_credential_secret_variables = self._extract_secret_variables( - # provider_entity.provider_credential_schema.credential_form_schemas - # if provider_entity.provider_credential_schema else [] - # ) - - # # Get decoding rsa key and cipher for decrypting credentials - # if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - # self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - # - # for variable in provider_credential_secret_variables: - # if variable in provider_credentials: - # try: - # provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - # provider_credentials.get(variable), - # self.decoding_rsa_key, - # self.decoding_cipher_rsa - # ) - # except ValueError: - # pass - - current_using_credentials = provider_credentials - - # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) - else: - current_using_credentials = cached_provider_credentials - else: - current_using_credentials = {} - quota_configurations = [] - - return SystemConfiguration( - enabled=True, - current_quota_type=current_quota_type, - quota_configurations=quota_configurations, - credentials=current_using_credentials - ) - - def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: - """ - Choice current using quota type. - paid quotas > provider free quotas > hosting trial quotas - If there is still quota for the corresponding quota type according to the sorting, - - :param quota_configurations: - :return: - """ - # convert to dict - quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations - } - - last_quota_configuration = None - for quota_type in [ProviderQuotaType.PAID, ProviderQuotaType.FREE, ProviderQuotaType.TRIAL]: - if quota_type in quota_type_to_quota_configuration_dict: - last_quota_configuration = quota_type_to_quota_configuration_dict[quota_type] - if last_quota_configuration.is_valid: - return quota_type - - if last_quota_configuration: - return last_quota_configuration.quota_type - - raise ValueError('No quota type available') - - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: - """ - Extract secret input form variables. - - :param credential_form_schemas: - :return: - """ - secret_input_form_variables = [] - for credential_form_schema in credential_form_schemas: - if credential_form_schema.type == FormType.SECRET_INPUT: - secret_input_form_variables.append(credential_form_schema.variable) - - return secret_input_form_variables diff --git a/model_providers/model_providers/core/tools/prompt/template.py b/model_providers/model_providers/core/tools/prompt/template.py deleted file mode 100644 index 3d355922..00000000 --- a/model_providers/model_providers/core/tools/prompt/template.py +++ /dev/null @@ -1,102 +0,0 @@ -ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. 
- -{{instruction}} - -You have access to the following tools: - -{{tools}} - -Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: "Final Answer" or {{tool_names}} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{ - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT -} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{ - "action": "Final Answer", - "action_input": "Final response to human" -} -``` - -Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Question: {{query}} -Thought: {{agent_scratchpad}}""" - -ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} -Thought:""" - -ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. - -{{instruction}} - -You have access to the following tools: - -{{tools}} - -Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: "Final Answer" or {{tool_names}} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{ - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT -} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{ - "action": "Final Answer", - "action_input": "Final response to human" -} -``` - -Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
-""" - -ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" - -REACT_PROMPT_TEMPLATES = { - 'english': { - 'chat': { - 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES - }, - 'completion': { - 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES - } - } -} \ No newline at end of file diff --git a/model_providers/model_providers/core/utils/generic.py b/model_providers/model_providers/core/utils/generic.py new file mode 100644 index 00000000..b93b0c57 --- /dev/null +++ b/model_providers/model_providers/core/utils/generic.py @@ -0,0 +1,21 @@ +import json +from typing import TYPE_CHECKING, Any, Dict + + +if TYPE_CHECKING: + from pydantic import BaseModel + + +def dictify(data: "BaseModel") -> Dict[str, Any]: + try: # pydantic v2 + return data.model_dump(exclude_unset=True) + except Exception: # pydantic v1 + return data.dict(exclude_unset=True) + + +def jsonify(data: "BaseModel") -> str: + try: # pydantic v2 + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except Exception: # pydantic v1 + return data.json(exclude_unset=True, ensure_ascii=False) + diff --git a/model_providers/model_providers/core/utils/json_dumps.py b/model_providers/model_providers/core/utils/json_dumps.py new file mode 100644 index 00000000..041615ce --- /dev/null +++ b/model_providers/model_providers/core/utils/json_dumps.py @@ -0,0 +1,12 @@ +import orjson +import os +from pydantic import BaseModel + + +def json_dumps(o): + def _default(obj): + if isinstance(obj, BaseModel): + return obj.dict() + raise TypeError + + return orjson.dumps(o, default=_default) diff --git a/model_providers/model_providers/ext_hosting_provider.py b/model_providers/model_providers/ext_hosting_provider.py deleted file mode 100644 index 0213c2b7..00000000 --- a/model_providers/model_providers/ext_hosting_provider.py +++ /dev/null @@ -1,9 +0,0 @@ -from flask import Flask - -from model_providers.core.hosting_configuration import HostingConfiguration - -hosting_configuration = HostingConfiguration() - - -def init_app(app: Flask): - hosting_configuration.init_app(app) diff --git a/model_providers/model_providers/extensions/ext_database.py b/model_providers/model_providers/extensions/ext_database.py deleted file mode 100644 index 9121c6ea..00000000 --- a/model_providers/model_providers/extensions/ext_database.py +++ /dev/null @@ -1,7 +0,0 @@ -from flask_sqlalchemy import SQLAlchemy - -db = SQLAlchemy() - - -def init_app(app): - db.init_app(app) diff --git a/model_providers/model_providers/extensions/ext_hosting_provider.py b/model_providers/model_providers/extensions/ext_hosting_provider.py deleted file mode 100644 index 0213c2b7..00000000 --- a/model_providers/model_providers/extensions/ext_hosting_provider.py +++ /dev/null @@ -1,9 +0,0 @@ -from flask import Flask - -from model_providers.core.hosting_configuration import HostingConfiguration - -hosting_configuration = HostingConfiguration() - - -def init_app(app: Flask): - hosting_configuration.init_app(app) diff --git a/model_providers/model_providers/extensions/ext_storage.py b/model_providers/model_providers/extensions/ext_storage.py index 3ce9935e..be85290f 100644 --- a/model_providers/model_providers/extensions/ext_storage.py +++ b/model_providers/model_providers/extensions/ext_storage.py @@ -6,7 +6,6 @@ from typing import Union import boto3 from botocore.exceptions import ClientError -from flask import Flask class Storage: @@ 
         self.client = None
         self.folder = None
 
-    def init_app(self, app: Flask):
-        self.storage_type = app.config.get('STORAGE_TYPE')
+    def init_config(self, config: dict):
+        self.storage_type = config.get('STORAGE_TYPE')
         if self.storage_type == 's3':
-            self.bucket_name = app.config.get('S3_BUCKET_NAME')
+            self.bucket_name = config.get('S3_BUCKET_NAME')
             self.client = boto3.client(
                 's3',
-                aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
-                aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
-                endpoint_url=app.config.get('S3_ENDPOINT'),
-                region_name=app.config.get('S3_REGION')
+                aws_secret_access_key=config.get('S3_SECRET_KEY'),
+                aws_access_key_id=config.get('S3_ACCESS_KEY'),
+                endpoint_url=config.get('S3_ENDPOINT'),
+                region_name=config.get('S3_REGION')
             )
         else:
-            self.folder = app.config.get('STORAGE_LOCAL_PATH')
+            self.folder = config.get('STORAGE_LOCAL_PATH')
             if not os.path.isabs(self.folder):
-                self.folder = os.path.join(app.root_path, self.folder)
+                self.folder = os.path.join(config.get('root_path'), self.folder)
 
     def save(self, filename, data):
         if self.storage_type == 's3':
@@ -140,5 +139,3 @@ class Storage:
 storage = Storage()
 
 
-def init_app(app: Flask):
-    storage.init_app(app)
diff --git a/model_providers/model_providers/models/__init__.py b/model_providers/model_providers/models/__init__.py
deleted file mode 100644
index 44d37d30..00000000
--- a/model_providers/model_providers/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# -*- coding:utf-8 -*-
\ No newline at end of file
diff --git a/model_providers/model_providers/models/model.py b/model_providers/model_providers/models/model.py
deleted file mode 100644
index a050561b..00000000
--- a/model_providers/model_providers/models/model.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from sqlalchemy import Float, text
-from sqlalchemy.dialects.postgresql import UUID
-from model_providers.extensions.ext_database import db
-
-
-class UploadFile(db.Model):
-    __tablename__ = 'upload_files'
-    __table_args__ = (
-        db.PrimaryKeyConstraint('id', name='upload_file_pkey'),
-        db.Index('upload_file_tenant_idx', 'tenant_id')
-    )
-
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    storage_type = db.Column(db.String(255), nullable=False)
-    key = db.Column(db.String(255), nullable=False)
-    name = db.Column(db.String(255), nullable=False)
-    size = db.Column(db.Integer, nullable=False)
-    extension = db.Column(db.String(255), nullable=False)
-    mime_type = db.Column(db.String(255), nullable=True)
-    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
-    created_by = db.Column(UUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
-    used_by = db.Column(UUID, nullable=True)
-    used_at = db.Column(db.DateTime, nullable=True)
-    hash = db.Column(db.String(255), nullable=True)
-
diff --git a/model_providers/model_providers/models/provider.py b/model_providers/model_providers/models/provider.py
deleted file mode 100644
index 9adc394a..00000000
--- a/model_providers/model_providers/models/provider.py
+++ /dev/null
@@ -1,160 +0,0 @@
-from enum import Enum
-
-from sqlalchemy.dialects.postgresql import UUID
-
-from model_providers.extensions.ext_database import db
-
-
-class ProviderType(Enum):
-    CUSTOM = 'custom'
-    SYSTEM = 'system'
-
-    @staticmethod
-    def value_of(value):
-        for
member in ProviderType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class ProviderQuotaType(Enum): - PAID = 'paid' - """hosted paid quota""" - - FREE = 'free' - """third-party free quota""" - - TRIAL = 'trial' - """hosted trial quota""" - - @staticmethod - def value_of(value): - for member in ProviderQuotaType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class Provider(db.Model): - """ - Provider model representing the API providers and their configurations. - """ - __tablename__ = 'providers' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_pkey'), - db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') - ) - - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - last_used = db.Column(db.DateTime, nullable=True) - - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) - - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - - def __repr__(self): - return f"" - - @property - def token_is_set(self): - """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ - return self.encrypted_config is not None - - @property - def is_enabled(self): - """ - Returns True if the provider is enabled. - """ - if self.provider_type == ProviderType.SYSTEM.value: - return self.is_valid - else: - return self.is_valid and self.token_is_set - - -class ProviderModel(db.Model): - """ - Provider model representing the API provider_models and their configurations. 
- """ - __tablename__ = 'provider_models' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_pkey'), - db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') - ) - - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - - -class TenantDefaultModel(db.Model): - __tablename__ = 'tenant_default_models' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), - db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), - ) - - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - model_name = db.Column(db.String(40), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - - -class TenantPreferredModelProvider(db.Model): - __tablename__ = 'tenant_preferred_model_providers' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), - db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), - ) - - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - - -class ProviderOrder(db.Model): - __tablename__ = 'provider_orders' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_order_pkey'), - db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), - ) - - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) - account_id = db.Column(UUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
diff --git a/model_providers/pyproject.toml b/model_providers/pyproject.toml
index 2fff65fd..58546f71 100644
--- a/model_providers/pyproject.toml
+++ b/model_providers/pyproject.toml
@@ -8,8 +8,9 @@ readme = "README.md"
 [tool.poetry.dependencies]
 python = "^3.10"
 transformers = "4.31.0"
-Flask-SQLAlchemy = "3.0.5"
-SQLAlchemy = "1.4.28"
+fastapi = "^0.108"
+uvicorn = "0.25.0"
+sse-starlette = "^1.8.2"
 pyyaml = "6.0.1"
 pydantic = "1.10.14"
 redis = "4.5.4"
@@ -18,7 +19,6 @@ openai = "1.13.3"
 tiktoken = "0.5.2"
 pydub = "0.25.1"
 boto3 = "1.28.17"
-
 [tool.poetry.group.test.dependencies]
 # The only dependencies that should be added are
 # dependencies used for running tests (e.g., pytest, freezegun, response).

From 9d954b2b7674724020416d7c1d24ab731e075ad0 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Fri, 22 Mar 2024 00:57:14 +0800
Subject: [PATCH 2/2] model_providers bootstrap

---
 .../model_providers/bootstrap_web/openai_bootstrap_web.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
index 57fd0add..7e5ef088 100644
--- a/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
+++ b/model_providers/model_providers/bootstrap_web/openai_bootstrap_web.py
@@ -40,7 +40,7 @@ async def create_stream_chat_completion(model_type_instance: LargeLanguageModel,
     response = model_type_instance.invoke(
         model=chat_request.model,
         credentials={
-            'openai_api_key': "sk-***REDACTED***",
+            'openai_api_key': "sk-",
             'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
             'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
         },
@@ -189,7 +189,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         response = model_type_instance.invoke(
             model='gpt-4',
             credentials={
-                'openai_api_key': "sk-***REDACTED***",
+                'openai_api_key': "sk-",
                 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
             },
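
Note: with the Flask app object removed, callers are now expected to assemble a plain config dict themselves before touching storage. A minimal usage sketch, assuming only the keys actually read in init_config above; the 'root_path' value, filename, and payload are illustrative placeholders, not values taken from this patch:

    # Usage sketch for the reworked Storage singleton (assumptions noted above).
    from model_providers.extensions.ext_storage import storage

    config = {
        'STORAGE_TYPE': 'local',              # anything other than 's3' takes the local branch
        'STORAGE_LOCAL_PATH': 'storage',      # relative, so init_config joins it with root_path
        'root_path': '/opt/model_providers',  # hypothetical base directory supplied by the caller
    }
    storage.init_config(config)
    storage.save('example.txt', b'hello')     # written under /opt/model_providers/storage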