mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-31 19:33:26 +08:00
Merge pull request #3466 from glide-the/dev_model_providers
model_providers bootstrap
This commit is contained in:
commit
42dc6d18c9
@ -1,387 +1,18 @@
|
||||
import os
|
||||
from typing import cast, Generator
|
||||
|
||||
from model_providers.core.entities.application_entities import ModelConfigEntity, AppOrchestrationConfigEntity, \
|
||||
PromptTemplateEntity, AdvancedChatPromptTemplateEntity, ExternalDataVariableEntity, AgentEntity, AgentToolEntity, \
|
||||
AgentPromptEntity, DatasetEntity, DatasetRetrieveConfigEntity, FileUploadEntity, TextToSpeechEntity, \
|
||||
SensitiveWordAvoidanceEntity, AdvancedCompletionPromptTemplateEntity
|
||||
from model_providers.core.entities.model_entities import ModelStatus
|
||||
from model_providers.core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, \
|
||||
QuotaExceededError
|
||||
from model_providers.core.model_manager import ModelInstance
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage, \
|
||||
AssistantPromptMessage
|
||||
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
from model_providers.core.tools.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
|
||||
|
||||
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
        -> AppOrchestrationConfigEntity:
    """
    Convert app model config dict to entity.

    :param tenant_id: tenant ID
    :param app_model_config_dict: app model config dict (read only, never modified)
    :raises ProviderTokenNotInitError: model credentials are missing or the model is not configured
    :raises ModelCurrentlyNotSupportError: the provider model status is NO_PERMISSION
    :raises QuotaExceededError: the provider model status is QUOTA_EXCEEDED
    :raises ValueError: the provider model or its schema cannot be found
    :return: app orchestration config entity
    """
    properties = {}

    copy_app_model_config_dict = app_model_config_dict.copy()
    model_dict = copy_app_model_config_dict['model']

    provider_manager = ProviderManager()
    provider_model_bundle = provider_manager.get_provider_model_bundle(
        tenant_id=tenant_id,
        provider=model_dict['provider'],
        model_type=ModelType.LLM
    )

    provider_name = provider_model_bundle.configuration.provider.provider
    model_name = model_dict['name']

    model_type_instance = provider_model_bundle.model_type_instance
    model_type_instance = cast(LargeLanguageModel, model_type_instance)

    # check model credentials
    model_credentials = provider_model_bundle.configuration.get_current_credentials(
        model_type=ModelType.LLM,
        model=model_name
    )

    if model_credentials is None:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

    # check model
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM
    )

    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    if provider_model.status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    elif provider_model.status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # model config
    # The top-level copy above is shallow, so 'stop' must be removed from a private copy of
    # the parameters; otherwise the caller's nested dict would lose the key. A missing or
    # None 'completion_params' is treated as "no parameters".
    completion_params = dict(model_dict.get('completion_params') or {})
    stop = completion_params.pop('stop', [])

    # get model mode
    model_mode = model_dict.get('mode')
    if not model_mode:
        mode_enum = model_type_instance.get_model_mode(
            model=model_name,
            credentials=model_credentials
        )

        model_mode = mode_enum.value

    model_schema = model_type_instance.get_model_schema(
        model_name,
        model_credentials
    )

    if not model_schema:
        raise ValueError(f"Model {model_name} not exist.")

    properties['model_config'] = ModelConfigEntity(
        provider=model_dict['provider'],
        model=model_name,
        model_schema=model_schema,
        mode=model_mode,
        provider_model_bundle=provider_model_bundle,
        credentials=model_credentials,
        parameters=completion_params,
        stop=stop,
    )

    # prompt template
    prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
    if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
        simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
        properties['prompt_template'] = PromptTemplateEntity(
            prompt_type=prompt_type,
            simple_prompt_template=simple_prompt_template
        )
    else:
        advanced_chat_prompt_template = None
        chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
        if chat_prompt_config:
            chat_prompt_messages = [
                {
                    "text": message["text"],
                    "role": PromptMessageRole.value_of(message["role"])
                }
                for message in chat_prompt_config.get("prompt", [])
            ]

            advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                messages=chat_prompt_messages
            )

        advanced_completion_prompt_template = None
        completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
        if completion_prompt_config:
            completion_prompt_template_params = {
                'prompt': completion_prompt_config['prompt']['text'],
            }

            if 'conversation_histories_role' in completion_prompt_config:
                histories_role = completion_prompt_config['conversation_histories_role']
                completion_prompt_template_params['role_prefix'] = {
                    'user': histories_role['user_prefix'],
                    'assistant': histories_role['assistant_prefix']
                }

            advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                **completion_prompt_template_params
            )

        properties['prompt_template'] = PromptTemplateEntity(
            prompt_type=prompt_type,
            advanced_chat_prompt_template=advanced_chat_prompt_template,
            advanced_completion_prompt_template=advanced_completion_prompt_template
        )

    # external data variables
    properties['external_data_variables'] = []

    # old external_data_tools
    for external_data_tool in copy_app_model_config_dict.get('external_data_tools') or []:
        if not external_data_tool.get('enabled'):
            continue

        properties['external_data_variables'].append(
            ExternalDataVariableEntity(
                variable=external_data_tool['variable'],
                type=external_data_tool['type'],
                config=external_data_tool['config']
            )
        )

    # current external_data_tools
    for variable in copy_app_model_config_dict.get('user_input_form') or []:
        # Each form entry is keyed by its type; an empty entry simply has no type.
        typ = next(iter(variable), None)
        if typ == 'external_data_tool':
            val = variable[typ]
            properties['external_data_variables'].append(
                ExternalDataVariableEntity(
                    variable=val['variable'],
                    type=val['type'],
                    config=val['config']
                )
            )

    # show retrieve source
    retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
    properties['show_retrieve_source'] = bool(
        retriever_resource_dict and retriever_resource_dict.get('enabled')
    )

    dataset_ids = []
    dataset_configs = copy_app_model_config_dict.get('dataset_configs') or {}
    if 'datasets' in dataset_configs:
        datasets = dataset_configs.get('datasets', {
            'strategy': 'router',
            'datasets': []
        })

        for dataset in datasets.get('datasets', []):
            keys = list(dataset.keys())
            if len(keys) == 0 or keys[0] != 'dataset':
                continue
            dataset = dataset['dataset']

            if not dataset.get('enabled'):
                continue

            dataset_id = dataset.get('id', None)
            if dataset_id:
                dataset_ids.append(dataset_id)
    else:
        datasets = {'strategy': 'router', 'datasets': []}

    agent_dict = copy_app_model_config_dict.get('agent_mode') or {}
    if agent_dict.get('enabled'):
        agent_strategy = agent_dict.get('strategy', 'cot')

        if agent_strategy == 'function_call':
            strategy = AgentEntity.Strategy.FUNCTION_CALLING
        elif agent_strategy == 'cot' or agent_strategy == 'react':
            strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        else:
            # old configs, try to detect default strategy
            if model_dict['provider'] == 'openai':
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            else:
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

        agent_tools = []
        for tool in agent_dict.get('tools', []):
            keys = tool.keys()
            if len(keys) >= 4:
                if not tool.get("enabled"):
                    continue

                agent_tool_properties = {
                    'provider_type': tool['provider_type'],
                    'provider_id': tool['provider_id'],
                    'tool_name': tool['tool_name'],
                    'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
                }

                agent_tools.append(AgentToolEntity(**agent_tool_properties))
            elif len(keys) == 1:
                # old standard
                key = list(tool.keys())[0]

                if key != 'dataset':
                    continue

                tool_item = tool[key]

                if not tool_item.get("enabled"):
                    continue

                dataset_id = tool_item['id']
                dataset_ids.append(dataset_id)

        # The agent entity is only built for non-router strategies; the prompt entity it needs
        # is defined inside this branch only.
        if 'strategy' in agent_dict and agent_dict['strategy'] not in ['react_router', 'router']:
            agent_prompt = agent_dict.get('prompt', None) or {}
            # check model mode
            # NOTE(review): this reads the raw configured mode (default 'completion'), not the
            # mode resolved above via get_model_mode - confirm that is intended.
            agent_model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
            if agent_model_mode == 'completion':
                agent_prompt_entity = AgentPromptEntity(
                    first_prompt=agent_prompt.get(
                        'first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
                    next_iteration=agent_prompt.get(
                        'next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
                )
            else:
                agent_prompt_entity = AgentPromptEntity(
                    first_prompt=agent_prompt.get(
                        'first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
                    next_iteration=agent_prompt.get(
                        'next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
                )

            properties['agent'] = AgentEntity(
                provider=properties['model_config'].provider,
                model=properties['model_config'].model,
                strategy=strategy,
                prompt=agent_prompt_entity,
                tools=agent_tools,
                max_iteration=agent_dict.get('max_iteration', 5)
            )

    if len(dataset_ids) > 0:
        # dataset configs
        query_variable = copy_app_model_config_dict.get('dataset_query_variable')
        # 'single' is the default retrieval model, also when the section exists without the key.
        retrieval_model = dataset_configs.get('retrieval_model', 'single')
        retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(retrieval_model)

        if retrieval_model == 'single':
            properties['dataset'] = DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=retrieve_strategy,
                    single_strategy=datasets.get('strategy', 'router')
                )
            )
        else:
            properties['dataset'] = DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=retrieve_strategy,
                    top_k=dataset_configs.get('top_k'),
                    score_threshold=dataset_configs.get('score_threshold'),
                    reranking_model=dataset_configs.get('reranking_model')
                )
            )

    # file upload
    file_upload_dict = copy_app_model_config_dict.get('file_upload')
    if file_upload_dict:
        image_dict = file_upload_dict.get('image')
        if image_dict and image_dict.get('enabled'):
            properties['file_upload'] = FileUploadEntity(
                image_config={
                    'number_limits': image_dict['number_limits'],
                    'detail': image_dict['detail'],
                    'transfer_methods': image_dict['transfer_methods']
                }
            )

    # opening statement
    properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')

    # suggested questions after answer
    suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
    if suggested_questions_after_answer_dict and suggested_questions_after_answer_dict.get('enabled'):
        properties['suggested_questions_after_answer'] = True

    # more like this
    more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
    if more_like_this_dict and more_like_this_dict.get('enabled'):
        properties['more_like_this'] = True

    # speech to text
    speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
    if speech_to_text_dict and speech_to_text_dict.get('enabled'):
        properties['speech_to_text'] = True

    # text to speech
    text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
    if text_to_speech_dict and text_to_speech_dict.get('enabled'):
        properties['text_to_speech'] = TextToSpeechEntity(
            enabled=text_to_speech_dict.get('enabled'),
            voice=text_to_speech_dict.get('voice'),
            language=text_to_speech_dict.get('language'),
        )

    # sensitive word avoidance
    sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
    if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get('enabled'):
        properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
            type=sensitive_word_avoidance_dict.get('type'),
            config=sensitive_word_avoidance_dict.get('config'),
        )

    return AppOrchestrationConfigEntity(**properties)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 基于配置管理器创建的模型实例
|
||||
# provider_manager = ProviderManager()
|
||||
# provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
# tenant_id="tenant_id",
|
||||
# provider="copy_app_model_config_dict['model']['provider']",
|
||||
# model_type=ModelType.LLM
|
||||
# )
|
||||
#
|
||||
|
||||
provider_configurations = ProviderConfigurations(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# model_instance = ModelInstance(
|
||||
|
||||
@ -0,0 +1,236 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
from fastapi import (APIRouter,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Response,
|
||||
Request,
|
||||
status
|
||||
)
|
||||
import logging
|
||||
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
||||
import json
|
||||
import pprint
|
||||
import tiktoken
|
||||
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
|
||||
ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
|
||||
from uvicorn import Config, Server
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import multiprocessing as mp
|
||||
import threading
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from model_providers.core.utils.generic import dictify, jsonify
|
||||
|
||||
from model_providers.core.model_runtime.model_providers import model_provider_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
    """
    Invoke the language model for a chat request and return whatever ``invoke`` returns.

    Any exception raised while building the arguments or invoking the model is logged and
    re-raised as an HTTP 500 error carrying the exception text.

    NOTE(review): the prompt message, the credentials and the ``user`` value are hard-coded;
    ``chat_request.messages`` is not used here - looks like prototype code, confirm.

    :param model_type_instance: model instance whose ``invoke`` method is called
    :param chat_request: request supplying model name, model parameters, stop and stream flag
    :raises HTTPException: status 500 when anything in the invocation fails
    :return: the value returned by ``model_type_instance.invoke``
    """
    try:
        invoke_kwargs = dict(
            model=chat_request.model,
            credentials={
                'openai_api_key': "sk-",
                'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            },
            prompt_messages=[
                UserPromptMessage(
                    content='北京今天的天气怎么样'
                )
            ],
            model_parameters={
                **chat_request.to_model_parameters_dict()
            },
            stop=chat_request.stop,
            stream=chat_request.stream,
            user="abc-123",
        )
        return model_type_instance.invoke(**invoke_kwargs)
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
    Bootstrap Server Lifecycle

    Builds a FastAPI application exposing ``/v1/models``, ``/v1/embeddings`` and
    ``/v1/chat/completions`` and runs it with uvicorn in a background thread.
    """

    def __init__(self, host: str, port: int):
        """
        :param host: interface the uvicorn server binds to
        :param port: TCP port the uvicorn server listens on
        """
        super().__init__()
        self._host = host
        self._port = port
        self._router = APIRouter()
        self._app = FastAPI()
        # Set by serve(); stays None until the server thread has been started.
        self._server_thread = None

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build an instance from a config mapping.

        :param cfg: mapping with optional ``host`` (default ``127.0.0.1``) and ``port``
            (default ``20000``); ``None`` means "use the defaults"
        :return: a new instance of this class
        """
        cfg = cfg or {}
        host = cfg.get("host", "127.0.0.1")
        port = cfg.get("port", 20000)

        logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
        return cls(host=host, port=port)

    def serve(self, logging_conf: Optional[dict] = None):
        """
        Register middleware and routes, then start uvicorn in a background thread.

        :param logging_conf: logging configuration passed to uvicorn as ``log_config``
        """
        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._router.add_api_route(
            "/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )

        self._router.add_api_route(
            "/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )

        self._app.include_router(self._router)

        config = Config(
            app=self._app, host=self._host, port=self._port, log_config=logging_conf
        )
        server = Server(config)

        def run_server():
            server.run()

        self._server_thread = threading.Thread(target=run_server)
        self._server_thread.start()

    async def join(self):
        """
        Wait until the server thread started by ``serve`` has finished.

        ``Thread.join`` is a blocking call that returns ``None``, so it cannot be awaited
        directly; it is run in the default executor instead. Does nothing if ``serve`` has
        not been called.
        """
        if self._server_thread is None:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._server_thread.join)

    def set_app_event(self, started_event: mp.Event = None):
        """
        Register a startup hook that sets ``started_event`` once the app has started.

        :param started_event: event to set on application startup; ignored when ``None``
        """
        @self._app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()

    @staticmethod
    def _resolve_api_key(request: Request) -> str:
        """
        Return the API key: the ``API_KEY`` environment variable when set (and non-empty),
        otherwise the token from the request's ``Authorization: Bearer ...`` header.

        :raises HTTPException: status 401 when neither source provides a key
        """
        api_key = os.environ.get("API_KEY")
        if api_key:
            return api_key

        authorization = request.headers.get("Authorization")
        if not authorization:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing API key: set API_KEY or send an Authorization header.",
            )
        return authorization.split("Bearer ")[-1]

    @staticmethod
    def _embedding_input_to_text(items: list, model: str) -> str:
        """
        Collapse a list-valued embeddings input into one string.

        Lists of strings are concatenated; a flat list of token ids is decoded as a whole;
        a list of token-id lists is decoded item by item and concatenated.
        NOTE(review): concatenating everything into a single input follows the original
        loop; confirm that one embedding per request (not per item) is intended.
        """
        if all(isinstance(item, str) for item in items):
            return "".join(items)

        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        if all(isinstance(item, int) for item in items):
            return encoding.decode(items)
        return "".join(encoding.decode(tokens) for tokens in items)

    async def list_models(self, request: Request):
        """Handler for ``GET /v1/models``; not implemented yet (returns ``None``)."""
        pass

    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
        """
        Handler for ``POST /v1/embeddings``.

        :param request: incoming request, used to read the ``Authorization`` header
        :param embeddings_request: parsed request body
        :raises HTTPException: status 401 when no API key is available
        :return: embeddings response built from the client result
        """
        logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
        authorization = self._resolve_api_key(request)

        # NOTE(review): ZhipuAI is not imported anywhere in the visible part of this module;
        # unless it is provided elsewhere this line raises NameError - confirm.
        client = ZhipuAI(api_key=authorization)

        # A list input may hold strings or token ids; the client call below gets a string.
        if isinstance(embeddings_request.input, list):
            input_text = self._embedding_input_to_text(embeddings_request.input, embeddings_request.model)
        else:
            input_text = embeddings_request.input

        response = client.embeddings.create(
            model=embeddings_request.model,
            input=input_text,
        )
        return EmbeddingsResponse(**dictify(response))

    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
        """
        Handler for ``POST /v1/chat/completions``.

        NOTE(review): the provider name, and for the non-streaming path also the model,
        prompt, parameters and credentials, are hard-coded - looks like prototype code.

        :param request: incoming request, used to read the ``Authorization`` header
        :param chat_request: parsed request body
        :raises HTTPException: status 401 when no API key is available
        :return: an event-stream response when ``chat_request.stream`` is set, otherwise a
            ``ChatCompletionResponse``
        """
        logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
        # The key is validated here but not yet forwarded to the provider credentials below.
        self._resolve_api_key(request)

        # Result unused; call kept from the original in case it has side effects (unverified).
        model_provider_factory.get_providers(provider_name='openai')
        provider_instance = model_provider_factory.get_provider_instance('openai')
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        if chat_request.stream:
            # create_stream_chat_completion is a coroutine function: it must be awaited to
            # obtain the stream returned by invoke, rather than handing the coroutine object
            # itself to EventSourceResponse.
            generator = await create_stream_chat_completion(model_type_instance, chat_request)
            return EventSourceResponse(generator, media_type="text/event-stream")
        else:
            response = model_type_instance.invoke(
                model='gpt-4',
                credentials={
                    'openai_api_key': "sk-",
                    'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                    'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
                },
                prompt_messages=[
                    UserPromptMessage(
                        content='北京今天的天气怎么样'
                    )
                ],
                model_parameters={
                    'temperature': 0.7,
                    'top_p': 1.0,
                    'top_k': 1,
                    'plugin_web_search': True,
                },
                stop=['you'],
                stream=False,
                user="abc-123"
            )

            chat_response = ChatCompletionResponse(**dictify(response))

            return chat_response
|
||||
|
||||
|
||||
def run(
        cfg: Dict, logging_conf: Optional[dict] = None,
        started_event: mp.Event = None,
):
    """
    Configure logging, start the OpenAI-style API server and block until it stops.

    :param cfg: configuration mapping; the ``run_openai_api`` entry (default ``{}``) is
        passed to ``RESTFulOpenAIBootstrapBaseWeb.from_config``
    :param logging_conf: ``logging.config.dictConfig`` configuration; when ``None`` the
        logging setup is left untouched
    :param started_event: event set once the application has started; may be ``None``
    :raises SystemExit: re-raised after being logged
    """
    # ``import logging`` alone does not guarantee the ``logging.config`` submodule is loaded.
    import logging.config

    if logging_conf is not None:
        logging.config.dictConfig(logging_conf)  # type: ignore
    try:
        import signal
        # Ignore keyboard interrupts; signal handling is left to xoscar.
        signal.signal(signal.SIGINT, lambda *_: None)
        api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
        api.set_app_event(started_event=started_event)
        api.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await api.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
|
||||
@ -0,0 +1,8 @@
|
||||
|
||||
from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb
|
||||
from model_providers.core.bootstrap.bootstrap_register import bootstrap_register
|
||||
__all__ = [
|
||||
"bootstrap_register",
|
||||
"Bootstrap",
|
||||
"OpenAIBootstrapBaseWeb",
|
||||
]
|
||||
54
model_providers/model_providers/core/bootstrap/base.py
Normal file
54
model_providers/model_providers/core/bootstrap/base.py
Normal file
@ -0,0 +1,54 @@
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
class Bootstrap:
    """Base class for bootstrap components; holds a version string and a task queue."""

    # Maximum number of ongoing tasks.
    _MAX_ONGOING_TASKS: int = 1

    # Task queue. Defined on the class, so every instance sees the same deque object.
    _QUEUE: deque = deque()

    def __init__(self):
        self._version = "v0.0.1"

    @classmethod
    @abstractmethod
    def from_config(cls, cfg=None):
        """Create an instance from ``cfg``; this base implementation ignores ``cfg``."""
        return cls()

    @property
    def version(self):
        """Version string assigned in ``__init__``."""
        return self._version

    @property
    def queue(self) -> deque:
        """The class-level task queue."""
        return self._QUEUE

    @classmethod
    async def run(cls):
        """Start the bootstrap component; must be provided by subclasses."""
        raise NotImplementedError

    @classmethod
    async def destroy(cls):
        """Tear the bootstrap component down; must be provided by subclasses."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAIBootstrapBaseWeb(Bootstrap):
    """
    Interface of a web bootstrap exposing model listing, embeddings and chat completion
    handlers.

    The request-body annotations are written as strings: ``EmbeddingsRequest`` and
    ``ChatCompletionRequest`` are not imported in this module, so unquoted annotations
    would raise ``NameError`` as soon as the class body is executed.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    async def list_models(self, request: Request):
        """Handle a model-listing request."""
        pass

    @abstractmethod
    async def create_embeddings(self, request: Request, embeddings_request: "EmbeddingsRequest"):
        """Handle an embeddings request."""
        pass

    @abstractmethod
    async def create_chat_completion(self, request: Request, chat_request: "ChatCompletionRequest"):
        """Handle a chat completion request."""
        pass
|
||||
@ -0,0 +1,51 @@
|
||||
from model_providers.core.bootstrap import Bootstrap
|
||||
|
||||
|
||||
class BootstrapRegister:
    """
    Registry manager: maps names to registered ``Bootstrap`` subclasses.
    """
    # Class-level registry; "bootstrap" maps a registered name to its class.
    mapping = {
        "bootstrap": {},
    }

    @classmethod
    def register_bootstrap(cls, name):
        r"""Register system bootstrap to registry with key 'name'

        Args:
            name: Key with which the bootstrap class will be registered.

        Returns:
            A class decorator that records the decorated class under ``name`` and
            returns it unchanged.

        Raises (from the returned decorator):
            AssertionError: the decorated class is not a ``Bootstrap`` subclass.
            KeyError: ``name`` is already registered.

        Usage:

            from model_providers.core.bootstrap import bootstrap_register

            @bootstrap_register.register_bootstrap("my_bootstrap")
            class MyBootstrap(Bootstrap):
                ...
        """

        print(f"register_bootstrap {name}")

        def wrap(task_cls):
            # Raised explicitly (instead of an ``assert`` statement) so the check still runs
            # under ``python -O``; the exception type is kept for existing callers.
            if not issubclass(task_cls, Bootstrap):
                raise AssertionError("All tasks must inherit bootstrap class")
            if name in cls.mapping["bootstrap"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["bootstrap"][name]
                    )
                )
            cls.mapping["bootstrap"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def get_bootstrap_class(cls, name):
        """Return the class registered under ``name``, or ``None`` if there is none."""
        return cls.mapping["bootstrap"].get(name, None)

    @classmethod
    def list_bootstrap(cls):
        """Return the registered names as a sorted list."""
        return sorted(cls.mapping["bootstrap"].keys())
|
||||
|
||||
|
||||
bootstrap_register = BootstrapRegister()
|
||||
|
||||
@ -0,0 +1,143 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class Role(str, Enum):
    """Author role of a chat message; members are also plain strings."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"
|
||||
|
||||
|
||||
class Finish(str, Enum):
    """Reason a completion choice finished; members are also plain strings."""

    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
    """One entry of a model list."""

    id: str
    object: Literal["model"] = "model"
    # Integer timestamp (seconds) taken from time.time() when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"
|
||||
|
||||
|
||||
class ModelList(BaseModel):
    """List wrapper around ``ModelCard`` entries."""

    object: Literal["list"] = "list"
    # Empty by default.
    data: List[ModelCard] = []
|
||||
|
||||
|
||||
class Function(BaseModel):
    """A function invocation: the function name plus its arguments as a string."""

    name: str
    arguments: str
|
||||
|
||||
|
||||
class FunctionDefinition(BaseModel):
    """Declaration of a callable function: name, description and a parameters mapping."""

    name: str
    description: str
    parameters: Dict[str, Any]
|
||||
|
||||
|
||||
class FunctionCallDefinition(BaseModel):
    """Reference to a function by name only."""

    name: str
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
    """An identified call of type ``"function"`` wrapping a ``Function`` payload."""

    id: str
    type: Literal["function"] = "function"
    function: Function
|
||||
|
||||
|
||||
class FunctionAvailable(BaseModel):
    """A tool offered in a request: its type and, optionally, a function definition."""

    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A request message: required role and required text content."""

    role: Role
    content: str
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
    """A response message or stream delta; every field is optional and defaults to None."""

    role: Optional[Role] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None
    function_call: Optional[Function] = None
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
    """Token usage counters; only ``completion_tokens`` may be omitted."""

    prompt_tokens: int
    completion_tokens: Optional[int] = None
    total_tokens: int
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Body of a chat completion request."""

    model: str
    messages: List[ChatMessage]
    tools: Optional[List[FunctionAvailable]] = None
    functions: Optional[List[FunctionDefinition]] = None
    function_call: Optional[FunctionCallDefinition] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # NOTE(review): declared as float although it looks like a count - confirm.
    top_k: Optional[float] = None
    n: int = 1
    max_tokens: Optional[int] = None
    # No trailing comma here: "= None," would make the default the tuple (None,).
    stop: Optional[List[str]] = None
    stream: Optional[bool] = False

    def to_model_parameters_dict(self, *args, **kwargs):
        """
        Return the request as a dict without the ``tools``, ``messages``, ``functions``
        and ``function_call`` fields.

        Extra positional and keyword arguments are forwarded to ``BaseModel.dict``.
        NOTE(review): ``model``, ``n``, ``stop`` and ``stream`` remain in the result -
        confirm the consumer expects them among the model parameters.
        """
        return super().dict(exclude={'tools', 'messages', 'functions', 'function_call'}, *args, **kwargs)
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
    """One choice of a non-streaming response; the finish reason is required."""

    index: int
    message: ChatCompletionMessage
    finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
    """One choice of a streaming chunk; carries a delta and an optional finish reason."""

    index: int
    delta: ChatCompletionMessage
    finish_reason: Optional[Finish] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """Complete (non-streaming) chat completion response."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    # Integer timestamp (seconds) taken from time.time() when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
    """One chunk of a streaming chat completion response."""

    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Integer timestamp (seconds) taken from time.time() when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
    """Body of an embeddings request."""

    # A single string, a list of strings, a list of ints, or a list of int lists.
    # NOTE(review): union members are listed with the int variants before List[str];
    # depending on the pydantic version numeric strings may be coerced to ints - confirm.
    input: Union[str, List[List[int]], List[int], List[str]]
    model: str
    encoding_format: Literal["base64", "float"] = "float"
|
||||
|
||||
|
||||
class Embeddings(BaseModel):
    """A single embedding result: the vector (floats or bytes) and its index."""

    object: Literal["embedding"] = "embedding"
    embedding: Union[List[float], bytes]
    index: int
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
    """Response of an embeddings request: the embeddings, the model name and usage."""

    object: Literal["list"] = "list"
    data: List[Embeddings]
    model: str
    usage: UsageInfo
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||
|
||||
from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
|
||||
# from model_providers.core.helper import encrypter
|
||||
from model_providers.core.helper import encrypter
|
||||
from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||
@ -22,7 +22,7 @@ from model_providers.core.model_runtime.model_providers import model_provider_fa
|
||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from model_providers.extensions.ext_database import db
|
||||
from model_providers.models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
|
||||
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -172,20 +172,20 @@ class ProviderConfiguration(BaseModel):
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
# for key, value in credentials.items():
|
||||
# if key in provider_credential_secret_variables:
|
||||
# # if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
# if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
self.provider.provider,
|
||||
credentials
|
||||
)
|
||||
|
||||
# for key, value in credentials.items():
|
||||
# if key in provider_credential_secret_variables:
|
||||
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_record, credentials
|
||||
|
||||
@ -315,11 +315,11 @@ class ProviderConfiguration(BaseModel):
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
# for key, value in credentials.items():
|
||||
# if key in provider_credential_secret_variables:
|
||||
# # if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
# if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider,
|
||||
@ -328,9 +328,9 @@ class ProviderConfiguration(BaseModel):
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
# for key, value in credentials.items():
|
||||
# if key in provider_credential_secret_variables:
|
||||
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_model_record, credentials
|
||||
|
||||
@ -481,10 +481,10 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
# copy_credentials = credentials.copy()
|
||||
# for key, value in copy_credentials.items():
|
||||
# if key in credential_secret_variables:
|
||||
# copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
copy_credentials = credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
|
||||
return copy_credentials
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.models.provider import ProviderQuotaType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMError(Exception):
    """Base class for all LLM exceptions.

    ``description`` holds a human-readable message. Subclasses may define a
    class-level default, which is kept unless the caller passes an explicit
    description.
    """

    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        """Create the error.

        :param description: optional message; when omitted, the class-level
            ``description`` stays in effect
        """
        if description is None:
            # Do not shadow a subclass default (e.g. "Bad Request") with None.
            super().__init__()
        else:
            # Forward the message so ``str(exc)`` shows it even when it was
            # passed as a keyword argument.
            super().__init__(description)
            self.description = description


class LLMBadRequestError(LLMError):
    """Raised when the LLM returns bad request."""

    description = "Bad Request"
|
||||
|
||||
|
||||
class ProviderTokenNotInitError(Exception):
    """Custom exception raised when the provider token is not initialized.

    The first positional argument, when given, replaces the default
    ``description``; keyword arguments are accepted and ignored.
    """

    description = "Provider Token Not Init"

    def __init__(self, *args, **kwargs):
        # Start from the class default and override it only when the caller
        # supplied a message; the result is always stored on the instance.
        message = self.description
        if args:
            message = args[0]
        self.description = message
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
    """Custom exception raised when the quota for a provider has been exceeded."""

    # Class-level message; this class defines no constructor of its own.
    description = "Quota Exceeded"
|
||||
|
||||
|
||||
class ModelCurrentlyNotSupportError(Exception):
    """Custom exception raised when the model is not supported."""

    # Class-level message; this class defines no constructor of its own.
    description = "Model Currently Not Support"
|
||||
@ -1,90 +0,0 @@
|
||||
import enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from model_providers.core.file.upload_file_parser import UploadFileParser
|
||||
from model_providers.core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from model_providers.extensions.ext_database import db
|
||||
from model_providers.models.model import UploadFile
|
||||
|
||||
|
||||
class FileType(enum.Enum):
    """File kinds; ``IMAGE`` is the only member defined."""

    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        :raises ValueError: when no member has that value
        """
        found = next((member for member in FileType if member.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class FileTransferMethod(enum.Enum):
    """Ways a file can be referenced: remote URL, local upload, or tool file."""

    REMOTE_URL = 'remote_url'
    LOCAL_FILE = 'local_file'
    TOOL_FILE = 'tool_file'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        :raises ValueError: when no member has that value
        """
        candidates = [member for member in FileTransferMethod if member.value == value]
        if candidates:
            return candidates[0]
        raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
class FileBelongsTo(enum.Enum):
    """Owner side of a file: the user or the assistant."""

    USER = 'user'
    ASSISTANT = 'assistant'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        :raises ValueError: when no member has that value
        """
        by_value = {member.value: member for member in FileBelongsTo}
        if value not in by_value:
            raise ValueError(f"No matching enum found for value '{value}'")
        return by_value[value]
|
||||
|
||||
class FileObj(BaseModel):
    """A file attached to a message, plus the upload configuration it was created under.

    ``data`` and ``preview_url`` resolve the file to a string via ``_get_data``;
    ``prompt_message_content`` wraps image data for use in a prompt message.
    """

    # Optional fields carry an explicit ``None`` default so that construction
    # without them does not depend on pydantic's implicit Optional handling.
    id: Optional[str] = None
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str] = None
    upload_file_id: Optional[str] = None
    file_config: dict

    @property
    def data(self) -> Optional[str]:
        """Return the file data as resolved by ``_get_data()``."""
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        """Return the file data resolved with ``force_url=True``."""
        return self._get_data(force_url=True)

    @property
    def prompt_message_content(self) -> Optional[ImagePromptMessageContent]:
        """Build the image prompt content for this file.

        :return: an ``ImagePromptMessageContent`` for image files, ``None`` for
            any other file type
        """
        if self.type != FileType.IMAGE:
            return None

        # A missing 'image' section used to raise AttributeError on ``.get``;
        # treat it as an empty config, which selects the LOW detail level.
        image_config = self.file_config.get('image') or {}
        if image_config.get("detail") == "high":
            detail = ImagePromptMessageContent.DETAIL.HIGH
        else:
            detail = ImagePromptMessageContent.DETAIL.LOW

        return ImagePromptMessageContent(data=self.data, detail=detail)

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        """Resolve the file to a string.

        :param force_url: forwarded to ``UploadFileParser.get_image_data`` for
            locally uploaded files
        :return: the remote URL, the value produced by ``UploadFileParser`` for
            a local upload, or ``None`` for any other type / transfer method
        """
        if self.type != FileType.IMAGE:
            return None

        if self.transfer_method == FileTransferMethod.REMOTE_URL:
            return self.url

        if self.transfer_method == FileTransferMethod.LOCAL_FILE:
            # Restrict the lookup to this tenant's uploads.
            upload_file = (db.session.query(UploadFile)
                           .filter(
                UploadFile.id == self.upload_file_id,
                UploadFile.tenant_id == self.tenant_id
            ).first())

            return UploadFileParser.get_image_data(
                upload_file=upload_file,
                force_url=force_url
            )

        return None
|
||||
@ -1,184 +0,0 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from model_providers.core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
|
||||
from model_providers.extensions.ext_database import db
|
||||
from model_providers.models.account import Account
|
||||
from model_providers.models.model import AppModelConfig, EndUser, MessageFile, UploadFile
|
||||
from services.file_service import IMAGE_EXTENSIONS
|
||||
|
||||
|
||||
class MessageFileParser:
    """Validate file arguments and convert them, or stored message files, to ``FileObj``."""

    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
                                         user: Union[Account, EndUser]) -> list[FileObj]:
        """
        validate and transform files arg

        :param files: list of file dicts with 'type', 'transfer_method' and
            either 'url' or 'upload_file_id'
        :param app_model_config: app config providing ``file_upload_dict``
        :param user: account or end user the uploads must belong to
        :return: the validated file objects
        :raises ValueError: on any malformed, disallowed or unknown file
        """
        file_upload_config = app_model_config.file_upload_dict

        for file in files:
            if not isinstance(file, dict):
                raise ValueError('Invalid file format, must be dict')
            if not file.get('type'):
                raise ValueError('Missing file type')
            # value_of raises ValueError for unknown values.
            FileType.value_of(file.get('type'))
            if not file.get('transfer_method'):
                raise ValueError('Missing file transfer method')
            FileTransferMethod.value_of(file.get('transfer_method'))
            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
                if not file.get('url'):
                    raise ValueError('Missing file url')
                if not file.get('url').startswith('http'):
                    raise ValueError('Invalid file url')
            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
                raise ValueError('Missing file upload_file_id')

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type != FileType.IMAGE:
                continue

            # parse and validate files
            image_config = file_upload_config.get('image')

            # A missing image section is treated like a disabled feature
            # instead of failing with TypeError on the subscript below.
            if not image_config or not image_config['enabled']:
                continue

            # Validate number of files
            if len(files) > image_config['number_limits']:
                raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

            for file_obj in file_objs:
                # Validate transfer method
                if file_obj.transfer_method.value not in image_config['transfer_methods']:
                    raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')

                # Validate file type
                if file_obj.type != FileType.IMAGE:
                    raise ValueError(f'Invalid file type: {file_obj.type}')

                if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                    # check remote url valid and is image
                    result, error = self._check_image_remote_url(file_obj.url)
                    if result is False:
                        raise ValueError(error)
                elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                    # get upload file from upload_file_id
                    upload_file = (db.session.query(UploadFile)
                                   .filter(
                        UploadFile.id == file_obj.upload_file_id,
                        UploadFile.tenant_id == self.tenant_id,
                        UploadFile.created_by == user.id,
                        UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
                        UploadFile.extension.in_(IMAGE_EXTENSIONS)
                    ).first())

                    # check upload file is belong to tenant and user
                    if not upload_file:
                        raise ValueError('Invalid upload file')

                new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
        """
        transform message files

        :param files: stored message files
        :param app_model_config: app config providing ``file_upload_dict``; may be None
        :return: file objects for all non-assistant files
        """
        # The signature allows None, which used to fail with AttributeError.
        # NOTE(review): an empty upload config is assumed to be an acceptable
        # stand-in when no app model config is available - confirm with callers.
        file_upload_config = app_model_config.file_upload_dict if app_model_config is not None else {}

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(self, files: list[Union[dict, MessageFile]],
                      file_upload_config: dict) -> dict[FileType, list[FileObj]]:
        """
        transform files to file objs

        :param files: file dicts or stored message files
        :param file_upload_config: upload configuration attached to every FileObj
        :return: file objects grouped by file type
        """
        type_file_objs: dict[FileType, list[FileObj]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            # Files that belong to the assistant are skipped.
            if isinstance(file, MessageFile) and file.belongs_to == FileBelongsTo.ASSISTANT.value:
                continue

            file_obj = self._to_file_obj(file, file_upload_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
        """
        transform file to file obj

        :param file: a file dict or a stored message file
        :param file_upload_config: upload configuration attached to the FileObj
        :return: the corresponding FileObj
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
            # Only the reference matching the transfer method is kept.
            return FileObj(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get('type')),
                transfer_method=transfer_method,
                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                file_config=file_upload_config
            )

        return FileObj(
            id=file.id,
            tenant_id=self.tenant_id,
            type=FileType.value_of(file.type),
            transfer_method=FileTransferMethod.value_of(file.transfer_method),
            url=file.url,
            upload_file_id=file.upload_file_id or None,
            file_config=file_upload_config
        )

    def _check_image_remote_url(self, url, timeout: float = 10):
        """Check with a HEAD request that *url* answers with status 200.

        :param url: remote URL to probe
        :param timeout: seconds to wait for the server; without it an
            unresponsive host would block the request indefinitely
        :return: ``(True, "")`` on status 200, otherwise ``(False, message)``
        """
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            response = requests.head(url, headers=headers, allow_redirects=True, timeout=timeout)
            if response.status_code == 200:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            # Timeouts are RequestException subclasses and are reported here too.
            return False, f"Error checking URL: {e}"
|
||||
@ -1,8 +0,0 @@
|
||||
# Module-level holder; the 'manager' slot starts empty and is read by
# ToolFileParser.get_tool_file_manager().
tool_file_manager = dict(manager=None)


class ToolFileParser:
    """Accessor for the object stored in the module-level ``tool_file_manager``."""

    @staticmethod
    def get_tool_file_manager() -> 'ToolFileManager':
        """Return whatever is currently stored under the ``'manager'`` key."""
        manager = tool_file_manager['manager']
        return manager
|
||||
@ -1,79 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from model_providers.extensions.ext_storage import storage
|
||||
|
||||
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
class UploadFileParser:
    """Turn uploaded image files into base64 data URLs or signed preview URLs."""

    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        """Return the image as a signed URL or a base64 data URL.

        :param upload_file: upload record; ``extension``, ``key``, ``mime_type``
            and ``id`` are read from it
        :param force_url: return a signed URL regardless of the configured format
        :return: the URL / data URL, or ``None`` when the file is missing, is
            not an image extension, or cannot be found in storage
        """
        if not upload_file:
            return None

        if upload_file.extension not in IMAGE_EXTENSIONS:
            return None

        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
            return cls.get_signed_temp_image_url(upload_file)

        # get image file base64
        try:
            data = storage.load(upload_file.key)
        except FileNotFoundError:
            logging.error(f'File not found: {upload_file.key}')
            return None

        encoded_string = base64.b64encode(data).decode('utf-8')
        return f'data:{upload_file.mime_type};base64,{encoded_string}'

    @classmethod
    def _sign_image_preview(cls, upload_file_id, timestamp: str, nonce: str) -> str:
        """Return the URL-safe base64 HMAC-SHA256 signature for an image preview."""
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(sign).decode()

    @classmethod
    def get_signed_temp_image_url(cls, upload_file) -> str:
        """
        get signed url from upload file

        :param upload_file: UploadFile object
        :return: preview URL carrying timestamp, nonce and signature
        """
        base_url = current_app.config.get('FILES_URL')
        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        encoded_sign = cls._sign_image_preview(upload_file.id, timestamp, nonce)

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature

        :param upload_file_id: file id
        :param timestamp: timestamp
        :param nonce: nonce
        :param sign: signature
        :return: True when the signature matches and is at most 5 minutes old
        """
        recalculated_encoded_sign = cls._sign_image_preview(upload_file_id, timestamp, nonce)

        # verify signature; compare as bytes in constant time so the check
        # neither leaks timing information nor fails on non-ASCII input
        if not hmac.compare_digest(sign.encode(), recalculated_encoded_sign.encode()):
            return False

        try:
            issued_at = int(timestamp)
        except ValueError:
            # A non-numeric timestamp cannot be valid; reject instead of raising.
            return False

        current_time = int(time.time())
        return current_time - issued_at <= 300  # expired after 5 minutes
|
||||
@ -1,51 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from model_providers.extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCacheType(Enum):
    """Cache namespaces; the member value is used as the cache-key prefix."""

    PROVIDER = "provider"
    MODEL = "provider_model"
|
||||
|
||||
|
||||
class ProviderCredentialsCache:
    """Cache for provider credentials, stored through ``redis_client`` under one key."""

    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        # Key layout: "<type>_credentials:tenant_id:<tenant>:id:<identity>"
        prefix = cache_type.value
        self.cache_key = f"{prefix}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return: the decoded credentials, or None when nothing is cached or
            the cached value is not valid JSON
        """
        raw = redis_client.get(self.cache_key)
        if not raw:
            return None

        try:
            return json.loads(raw.decode('utf-8'))
        except JSONDecodeError:
            return None

    def set(self, credentials: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials
        :return:
        """
        payload = json.dumps(credentials)
        # 86400 seconds = 24 hours
        redis_client.setex(self.cache_key, 86400, payload)

    def delete(self) -> None:
        """
        Delete cached model provider credentials.

        :return:
        """
        redis_client.delete(self.cache_key)
|
||||
@ -1,250 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from flask import Config, Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from model_providers.core.entities.provider_entities import QuotaUnit, RestrictModel
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class HostingQuota(BaseModel):
    """Base pydantic model for a hosting quota: a quota type plus model restrictions."""

    quota_type: ProviderQuotaType
    # Defaults to an empty list of restricted models.
    restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class TrialHostingQuota(HostingQuota):
    """Hosting quota whose type defaults to ``ProviderQuotaType.TRIAL``."""

    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
    # Quota limit for the hosting provider models. -1 means unlimited.
    quota_limit: int = 0
|
||||
|
||||
|
||||
class PaidHostingQuota(HostingQuota):
    """Hosting quota whose type defaults to ``ProviderQuotaType.PAID``."""

    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
|
||||
|
||||
|
||||
class FreeHostingQuota(HostingQuota):
    """Hosting quota whose type defaults to ``ProviderQuotaType.FREE``."""

    quota_type: ProviderQuotaType = ProviderQuotaType.FREE
|
||||
|
||||
|
||||
class HostingProvider(BaseModel):
    """Pydantic model for one hosted provider: enabled flag, credentials, quota unit, quotas."""

    # Disabled unless explicitly enabled.
    enabled: bool = False
    credentials: Optional[dict] = None
    quota_unit: Optional[QuotaUnit] = None
    quotas: list[HostingQuota] = []
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseModel):
    """Pydantic model for hosted moderation: an enabled flag and a list of provider names."""

    # Disabled unless explicitly enabled.
    enabled: bool = False
    providers: list[str] = []
|
||||
|
||||
|
||||
class HostingConfiguration:
    """Hosted-provider settings derived from a Flask application's config.

    ``init_app`` fills ``provider_map`` (provider name -> ``HostingProvider``)
    and ``moderation_config``; it does nothing unless the ``EDITION`` config
    value is ``'CLOUD'``.
    """

    # NOTE(review): provider_map is a class attribute, so the dict is shared by
    # all instances; kept unchanged to preserve existing behaviour.
    provider_map: dict[str, HostingProvider] = {}
    # None until init_app has run on a CLOUD edition app.
    moderation_config: Optional[HostedModerationConfig] = None

    def init_app(self, app: Flask) -> None:
        """Populate the provider map and moderation config from ``app.config``."""
        config = app.config

        if config.get('EDITION') != 'CLOUD':
            return

        self.provider_map["azure_openai"] = self.init_azure_openai(config)
        self.provider_map["openai"] = self.init_openai(config)
        self.provider_map["anthropic"] = self.init_anthropic(config)
        self.provider_map["minimax"] = self.init_minimax(config)
        self.provider_map["spark"] = self.init_spark(config)
        self.provider_map["zhipuai"] = self.init_zhipuai(config)

        self.moderation_config = self.init_moderation_config(config)

    def init_azure_openai(self, app_config: Config) -> HostingProvider:
        """Build the Azure OpenAI provider: one trial quota with a fixed model list."""
        quota_unit = QuotaUnit.TIMES
        if not app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
            return HostingProvider(
                enabled=False,
                quota_unit=quota_unit,
            )

        credentials = {
            "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
            "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
            "base_model_name": "gpt-35-turbo"
        }

        # Each entry is restricted under its own name as base model.
        llm_models = (
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        )
        embedding_models = (
            "text-embedding-ada-002",
            "text-embedding-3-small",
            "text-embedding-3-large",
        )
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_models
        ] + [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.TEXT_EMBEDDING)
            for name in embedding_models
        ]

        hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
        trial_quota = TrialHostingQuota(
            quota_limit=hosted_quota_limit,
            restrict_models=restrict_models
        )

        return HostingProvider(
            enabled=True,
            credentials=credentials,
            quota_unit=quota_unit,
            quotas=[trial_quota]
        )

    def init_openai(self, app_config: Config) -> HostingProvider:
        """Build the OpenAI provider from the trial / paid config switches."""
        quota_unit = QuotaUnit.CREDITS
        quotas = []

        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=trial_models
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
            paid_quota = PaidHostingQuota(
                restrict_models=paid_models
            )
            quotas.append(paid_quota)

        # Without any quota the provider stays disabled.
        if not quotas:
            return HostingProvider(
                enabled=False,
                quota_unit=quota_unit,
            )

        credentials = {
            "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
        }

        if app_config.get("HOSTED_OPENAI_API_BASE"):
            credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")

        if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
            credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")

        return HostingProvider(
            enabled=True,
            credentials=credentials,
            quota_unit=quota_unit,
            quotas=quotas
        )

    def init_anthropic(self, app_config: Config) -> HostingProvider:
        """Build the Anthropic provider from the trial / paid config switches."""
        quota_unit = QuotaUnit.TOKENS
        quotas = []

        if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
            paid_quota = PaidHostingQuota()
            quotas.append(paid_quota)

        # Without any quota the provider stays disabled.
        if not quotas:
            return HostingProvider(
                enabled=False,
                quota_unit=quota_unit,
            )

        credentials = {
            "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
        }

        if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
            credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")

        return HostingProvider(
            enabled=True,
            credentials=credentials,
            quota_unit=quota_unit,
            quotas=quotas
        )

    @staticmethod
    def _init_free_quota_provider(app_config: Config, enabled_key: str) -> HostingProvider:
        """Build a token-based provider with a single free quota, gated by *enabled_key*."""
        quota_unit = QuotaUnit.TOKENS
        if app_config.get(enabled_key):
            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=[FreeHostingQuota()]
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_minimax(self, app_config: Config) -> HostingProvider:
        """Build the MiniMax provider (free quota when ``HOSTED_MINIMAX_ENABLED``)."""
        return self._init_free_quota_provider(app_config, "HOSTED_MINIMAX_ENABLED")

    def init_spark(self, app_config: Config) -> HostingProvider:
        """Build the Spark provider (free quota when ``HOSTED_SPARK_ENABLED``)."""
        return self._init_free_quota_provider(app_config, "HOSTED_SPARK_ENABLED")

    def init_zhipuai(self, app_config: Config) -> HostingProvider:
        """Build the ZhipuAI provider (free quota when ``HOSTED_ZHIPUAI_ENABLED``)."""
        return self._init_free_quota_provider(app_config, "HOSTED_ZHIPUAI_ENABLED")

    def init_moderation_config(self, app_config: Config) -> HostedModerationConfig:
        """Build the moderation config; enabled only when both config values are set."""
        if app_config.get("HOSTED_MODERATION_ENABLED") \
                and app_config.get("HOSTED_MODERATION_PROVIDERS"):
            return HostedModerationConfig(
                enabled=True,
                providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',')
            )

        return HostedModerationConfig(
            enabled=False
        )

    @staticmethod
    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
        """Parse a comma-separated model list from config into LLM ``RestrictModel`` entries.

        Names are stripped of surrounding whitespace and blank entries are dropped.
        """
        models_str = app_config.get(env_var)
        models_list = models_str.split(",") if models_str else []
        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
                model_name.strip()]
|
||||
|
||||
@ -1,257 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from model_providers.core.entities.provider_configuration import ProviderModelBundle
|
||||
from model_providers.core.errors.error import ProviderTokenNotInitError
|
||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
class ModelInstance:
    """
    One model of a provider, bound to its credentials.

    Each ``invoke_*`` method first checks that the bundled model type instance
    is of the matching model class and then forwards the call to it together
    with the model name and credentials.
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        """
        :param provider_model_bundle: bundle holding the provider configuration
            and the model type instance
        :param model: model name
        :raises ProviderTokenNotInitError: if no credentials exist for the model
        """
        self._provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self._provider_model_bundle.model_type_instance

    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle

        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return: the current credentials for the model
        :raises ProviderTokenNotInitError: if the configuration returns ``None``
        """
        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model
        )

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")

        return credentials

    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None,
                   callbacks: Optional[list[Callback]] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        :raises Exception: if the model type instance is not a LargeLanguageModel
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        # cast() only narrows the static type; the object itself is unchanged.
        llm = cast(LargeLanguageModel, self.model_type_instance)
        return llm.invoke(
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks
        )

    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        :raises Exception: if the model type instance is not a TextEmbeddingModel
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        embedding_model = cast(TextEmbeddingModel, self.model_type_instance)
        return embedding_model.invoke(
            model=self.model,
            credentials=self.credentials,
            texts=texts,
            user=user
        )

    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                      top_n: Optional[int] = None,
                      user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        :raises Exception: if the model type instance is not a RerankModel
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception("Model type instance is not RerankModel")

        rerank_model = cast(RerankModel, self.model_type_instance)
        return rerank_model.invoke(
            model=self.model,
            credentials=self.credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user
        )

    def invoke_moderation(self, text: str, user: Optional[str] = None) \
            -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        :raises Exception: if the model type instance is not a ModerationModel
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception("Model type instance is not ModerationModel")

        moderation_model = cast(ModerationModel, self.model_type_instance)
        return moderation_model.invoke(
            model=self.model,
            credentials=self.credentials,
            text=text,
            user=user
        )

    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        :raises Exception: if the model type instance is not a Speech2TextModel
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception("Model type instance is not Speech2TextModel")

        speech2text_model = cast(Speech2TextModel, self.model_type_instance)
        return speech2text_model.invoke(
            model=self.model,
            credentials=self.credentials,
            file=file,
            user=user
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
            -> str:
        """
        Invoke text-to-speech model

        :param content_text: text content to be converted to speech
        :param tenant_id: user tenant id
        :param voice: model timbre
        :param streaming: output is streaming
        :param user: unique user id
        :return: whatever the TTS model's ``invoke`` returns
            NOTE(review): annotated ``str``, but provider implementations
            appear to return an HTTP response object - confirm and adjust.
        :raises Exception: if the model type instance is not a TTSModel
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        tts_model = cast(TTSModel, self.model_type_instance)
        return tts_model.invoke(
            model=self.model,
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
            voice=voice,
            streaming=streaming
        )

    def get_tts_voices(self, language: str) -> list:
        """
        List the voices of the text-to-speech model

        :param language: tts language
        :return: tts model voices
        :raises Exception: if the model type instance is not a TTSModel
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        tts_model = cast(TTSModel, self.model_type_instance)
        return tts_model.get_tts_model_voices(
            model=self.model,
            credentials=self.credentials,
            language=language
        )
|
||||
|
||||
|
||||
class ModelManager:
    """
    Resolves ``ModelInstance`` objects through a ``ProviderManager``.
    """

    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Get model instance

        :param tenant_id: tenant id
        :param provider: provider name; when falsy the tenant's default model
            of ``model_type`` is used and ``model`` is ignored
        :param model_type: model type
        :param model: model name
        :return: model instance bound to the provider's bundle
        """
        if provider:
            bundle = self._provider_manager.get_provider_model_bundle(
                tenant_id=tenant_id,
                provider=provider,
                model_type=model_type
            )
            return ModelInstance(bundle, model)

        return self.get_default_model_instance(tenant_id, model_type)

    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Get default model instance

        :param tenant_id: tenant id
        :param model_type: model type
        :return: instance of the tenant's default model for ``model_type``
        :raises ProviderTokenNotInitError: if no default model can be determined
        """
        default_entity = self._provider_manager.get_default_model(
            tenant_id=tenant_id,
            model_type=model_type
        )

        if default_entity:
            return self.get_model_instance(
                tenant_id=tenant_id,
                provider=default_entity.provider.provider,
                model_type=model_type,
                model=default_entity.model
            )

        raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
|
||||
@ -3,8 +3,7 @@ import copy
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import AzureOpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
@ -37,15 +36,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
if not voice or voice not in [d['value'] for d in
|
||||
self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
return StreamingResponse(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice), media_type='text/event-stream')
|
||||
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
@ -68,7 +68,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
|
||||
"""
|
||||
_tts_invoke text2speech model
|
||||
|
||||
@ -103,7 +103,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
buffer: BytesIO = BytesIO()
|
||||
combined_segment.export(buffer, format=audio_type)
|
||||
buffer.seek(0)
|
||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
||||
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
@ -160,7 +160,6 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
for ai_model_entity in TTS_BASE_MODELS:
|
||||
|
||||
@ -3,10 +3,9 @@ from functools import reduce
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from openai import OpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
@ -37,12 +36,11 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
return StreamingResponse(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice), media_type='text/event-stream')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
@ -65,7 +63,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
|
||||
"""
|
||||
_tts_invoke text2speech model
|
||||
|
||||
@ -100,7 +98,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
buffer: BytesIO = BytesIO()
|
||||
combined_segment.export(buffer, format=audio_type)
|
||||
buffer.seek(0)
|
||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
||||
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
|
||||
@ -4,9 +4,8 @@ from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
from flask import Response, stream_with_context
|
||||
from pydub import AudioSegment
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
@ -37,12 +36,11 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
tenant_id=tenant_id)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
return StreamingResponse(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice), media_type='text/event-stream')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
@ -101,7 +99,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
buffer: BytesIO = BytesIO()
|
||||
combined_segment.export(buffer, format=audio_type)
|
||||
buffer.seek(0)
|
||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
||||
return StreamingResponse(buffer, media_type=f"audio/{audio_type}")
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
|
||||
@ -1,761 +0,0 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from model_providers.core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
QuotaConfiguration,
|
||||
SystemConfiguration,
|
||||
)
|
||||
# from model_providers.core.helper import encrypter
|
||||
from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from model_providers.core.model_runtime.model_providers import model_provider_factory
|
||||
from model_providers.extensions import ext_hosting_provider
|
||||
from model_providers.extensions.ext_database import db
|
||||
from model_providers.models.provider import (
|
||||
Provider,
|
||||
ProviderModel,
|
||||
ProviderQuotaType,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
    """
    Get model provider configurations.

    Builds one ProviderConfiguration per provider entity reported by the
    model provider factory. Each one combines:

    - the provider entity itself,
    - the custom configuration derived from the workspace's provider and
      provider-model records,
    - the system (hosting) configuration derived from the provider records,
    - the preferred provider type: the stored preference if one exists,
      otherwise CUSTOM when any custom configuration exists, SYSTEM when
      hosting is enabled, and CUSTOM as the last resort,
    - the provider type actually in use: a SYSTEM preference is downgraded
      to CUSTOM when hosting is disabled or no quota is valid; a CUSTOM
      preference is upgraded to SYSTEM when no custom configuration exists
      but hosting is enabled with a valid quota.

    :param tenant_id: workspace id
    :return: configurations keyed by provider name
    """
    # Workspace provider records; missing trial records are created on the way.
    records_by_provider = self._get_all_providers(tenant_id)
    records_by_provider = self._init_trial_provider_records(tenant_id, records_by_provider)

    model_records_by_provider = self._get_all_provider_models(tenant_id)
    all_provider_entities = model_provider_factory.get_providers()
    preferred_records_by_provider = self._get_all_preferred_model_providers(tenant_id)

    configurations = ProviderConfigurations(tenant_id=tenant_id)

    for entity in all_provider_entities:
        name = entity.provider

        records = records_by_provider.get(entity.provider) or []
        model_records = model_records_by_provider.get(entity.provider) or []

        custom = self._to_custom_configuration(tenant_id, entity, records, model_records)
        system = self._to_system_configuration(tenant_id, entity, records)

        # Preferred type: explicit record wins, otherwise derive it.
        preferred_record = preferred_records_by_provider.get(name)
        if preferred_record:
            preferred = ProviderType.value_of(preferred_record.preferred_provider_type)
        elif custom.provider or custom.models:
            preferred = ProviderType.CUSTOM
        elif system.enabled:
            preferred = ProviderType.SYSTEM
        else:
            preferred = ProviderType.CUSTOM

        # Type in use: start from the preference and correct it if unusable.
        in_use = preferred
        if preferred == ProviderType.SYSTEM:
            system_enabled = system.enabled
            quota_available = any(quota.is_valid for quota in system.quota_configurations)
            if not (system_enabled and quota_available):
                in_use = ProviderType.CUSTOM
        elif not (custom.provider or custom.models) and system.enabled:
            if any(quota.is_valid for quota in system.quota_configurations):
                in_use = ProviderType.SYSTEM

        configurations[name] = ProviderConfiguration(
            tenant_id=tenant_id,
            provider=entity,
            preferred_provider_type=preferred,
            using_provider_type=in_use,
            system_configuration=system,
            custom_configuration=custom
        )

    return configurations
|
||||
|
||||
def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle:
    """
    Get provider model bundle.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :return: bundle of the provider's configuration, provider instance and
        the model type instance for ``model_type``
    :raises ValueError: if the workspace has no configuration for ``provider``
    """
    configuration = self.get_configurations(tenant_id).get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    instance = configuration.get_provider_instance()
    type_instance = instance.get_model_instance(model_type)

    return ProviderModelBundle(
        configuration=configuration,
        provider_instance=instance,
        model_type_instance=type_instance
    )
|
||||
|
||||
def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]:
    """
    Get default model.

    Looks up the workspace's TenantDefaultModel record for the model type.
    When none exists, a record is created and committed from the active
    models: "gpt-3.5-turbo-1106" if it is among them, otherwise the first one.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: the default model entity, or ``None`` when no record exists and
        no active model is available
    """
    record = db.session.query(TenantDefaultModel) \
        .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.to_origin_model_type()
        ).first()

    if not record:
        candidates = self.get_configurations(tenant_id).get_models(
            model_type=model_type,
            only_active=True
        )

        if candidates:
            # Prefer the well-known model; otherwise fall back to the first active one.
            chosen = None
            for candidate in candidates:
                if candidate.model == "gpt-3.5-turbo-1106":
                    chosen = candidate
                    break
            if chosen is None:
                chosen = candidates[0]

            record = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.to_origin_model_type(),
                provider_name=chosen.provider.provider,
                model_name=chosen.model
            )
            db.session.add(record)
            db.session.commit()

    if not record:
        return None

    schema = model_provider_factory.get_provider_instance(record.provider_name).get_provider_schema()

    provider_entity = DefaultModelProviderEntity(
        provider=schema.provider,
        label=schema.label,
        icon_small=schema.icon_small,
        icon_large=schema.icon_large,
        supported_model_types=schema.supported_model_types
    )
    return DefaultModelEntity(
        model=record.model_name,
        model_type=model_type,
        provider=provider_entity
    )
|
||||
|
||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
        -> TenantDefaultModel:
    """
    Update default model record.

    Creates the workspace's TenantDefaultModel record for ``model_type`` if it
    does not exist yet, otherwise points the existing record at the given
    provider and model. The change is committed before returning.

    :param tenant_id: workspace id
    :param model_type: model type
    :param provider: provider name
    :param model: model name
    :return: the created or updated record
    :raises ValueError: if the provider has no configuration in the workspace,
        or the model is not among the active models of ``model_type``
    """
    provider_configurations = self.get_configurations(tenant_id)
    if provider not in provider_configurations:
        raise ValueError(f"Provider {provider} does not exist.")

    # get available models from provider_configurations
    available_models = provider_configurations.get_models(
        model_type=model_type,
        only_active=True
    )

    # check if the model is exist in available models
    # NOTE(review): this matches on the model name across all providers, not
    # only the requested one - confirm that is intended.
    available_model_names = [available_model.model for available_model in available_models]
    if model not in available_model_names:
        raise ValueError(f"Model {model} does not exist.")

    # Records are keyed by the origin model type, so use it for lookup AND creation.
    origin_model_type = model_type.to_origin_model_type()

    default_model = db.session.query(TenantDefaultModel) \
        .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == origin_model_type
        ).first()

    # create or update TenantDefaultModel record
    if default_model:
        # update default model
        default_model.provider_name = provider
        default_model.model_name = model
        db.session.commit()
    else:
        # create default model; previously stored ``model_type.value``, which
        # the lookup above (and get_default_model) would never match.
        default_model = TenantDefaultModel(
            tenant_id=tenant_id,
            model_type=origin_model_type,
            provider_name=provider,
            model_name=model,
        )
        db.session.add(default_model)
        db.session.commit()

    return default_model
|
||||
|
||||
def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]:
    """
    Get all provider records of the workspace.

    :param tenant_id: workspace id
    :return: valid Provider records grouped by provider name, as a
        ``defaultdict(list)``
    """
    valid_records_query = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.is_valid == True
    )

    grouped = defaultdict(list)
    for record in valid_records_query.all():
        grouped[record.provider_name].append(record)

    return grouped
|
||||
|
||||
def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]:
    """
    Get all provider model records of the workspace.

    :param tenant_id: workspace id
    :return: valid ProviderModel records grouped by provider name, as a
        ``defaultdict(list)``
    """
    valid_records_query = db.session.query(ProviderModel).filter(
        ProviderModel.tenant_id == tenant_id,
        ProviderModel.is_valid == True
    )

    grouped = defaultdict(list)
    for record in valid_records_query.all():
        grouped[record.provider_name].append(record)

    return grouped
|
||||
|
||||
def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
    """
    Get All preferred provider types of the workspace.

    :param tenant_id: workspace id
    :return: TenantPreferredModelProvider records keyed by provider name; if
        several records share a name, the last one returned by the query wins
    """
    records = db.session.query(TenantPreferredModelProvider).filter(
        TenantPreferredModelProvider.tenant_id == tenant_id
    ).all()

    by_provider_name = {}
    for record in records:
        by_provider_name[record.provider_name] = record

    return by_provider_name
|
||||
|
||||
def _init_trial_provider_records(self, tenant_id: str,
                                 provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
    """
    Initialize trial provider records if not exists.

    For every enabled hosted provider that offers a TRIAL quota, makes sure the
    workspace has a SYSTEM/TRIAL Provider record: one is inserted when the
    passed-in records contain none, and the record is appended to the mapping.

    :param tenant_id: workspace id
    :param provider_name_to_provider_records_dict: provider name to provider records dict;
        mutated in place. The final append indexes it by provider name without
        a prior membership check, so it relies on ``defaultdict``-like behaviour
        for providers that have no records yet.
    :return: the same mapping that was passed in
    """
    # Get hosting configuration
    hosting_configuration = ext_hosting_provider.hosting_configuration

    for provider_name, configuration in hosting_configuration.provider_map.items():
        if not configuration.enabled:
            continue

        provider_records = provider_name_to_provider_records_dict.get(provider_name)
        if not provider_records:
            provider_records = []

        # Index the existing SYSTEM records of this provider by quota type.
        provider_quota_to_provider_record_dict = dict()
        for provider_record in provider_records:
            if provider_record.provider_type != ProviderType.SYSTEM.value:
                continue

            provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \
                = provider_record

        for quota in configuration.quotas:
            if quota.quota_type == ProviderQuotaType.TRIAL:
                # Init trial provider records if not exists
                if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
                    try:
                        provider_record = Provider(
                            tenant_id=tenant_id,
                            provider_name=provider_name,
                            provider_type=ProviderType.SYSTEM.value,
                            quota_type=ProviderQuotaType.TRIAL.value,
                            quota_limit=quota.quota_limit,
                            quota_used=0,
                            is_valid=True
                        )
                        db.session.add(provider_record)
                        db.session.commit()
                    except IntegrityError:
                        # The insert was rejected; roll back and look for an
                        # existing trial record instead.
                        db.session.rollback()
                        provider_record = db.session.query(Provider) \
                            .filter(
                                Provider.tenant_id == tenant_id,
                                Provider.provider_name == provider_name,
                                Provider.provider_type == ProviderType.SYSTEM.value,
                                Provider.quota_type == ProviderQuotaType.TRIAL.value
                            ).first()

                        # Re-enable a trial record that had been invalidated.
                        if provider_record and not provider_record.is_valid:
                            provider_record.is_valid = True
                            db.session.commit()

                    # NOTE(review): if the re-query above finds nothing,
                    # provider_record is None and None gets appended - confirm
                    # this cannot happen or guard it.
                    provider_name_to_provider_records_dict[provider_name].append(provider_record)

    return provider_name_to_provider_records_dict
|
||||
|
||||
def _to_custom_configuration(self,
                             tenant_id: str,
                             provider_entity: ProviderEntity,
                             provider_records: list[Provider],
                             provider_model_records: list[ProviderModel]) -> CustomConfiguration:
    """
    Build the custom (non-system) configuration of a provider.

    Provider-level credentials are taken from the last non-system record in
    ``provider_records`` that has a non-empty ``encrypted_config``; model-level
    credentials are taken from every record in ``provider_model_records`` that
    has a non-empty ``encrypted_config``. Both go through ``ProviderCredentialsCache``:
    a cache hit is used as-is, a miss is parsed from the record and written back.

    :param tenant_id: workspace id
    :param provider_entity: provider entity
    :param provider_records: provider records
    :param provider_model_records: provider model records
    :return: custom configuration holding the provider credentials (or ``None``)
        and the list of per-model credentials
    """
    # Secret variable names of the provider credential form.
    # NOTE: the RSA decryption step that consumed this list is disabled, so the
    # value is currently computed but not used.
    provider_schema = provider_entity.provider_credential_schema
    provider_credential_secret_variables = self._extract_secret_variables(
        provider_schema.credential_form_schemas if provider_schema else []
    )

    # Pick the custom provider record; there is no early exit, so the last match wins.
    custom_provider_record = None
    for candidate in provider_records:
        is_system = candidate.provider_type == ProviderType.SYSTEM.value
        if not is_system and candidate.encrypted_config:
            custom_provider_record = candidate

    # Resolve provider-level credentials.
    custom_provider_configuration = None
    if custom_provider_record:
        credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id,
            identity_id=custom_provider_record.id,
            cache_type=ProviderCredentialsCacheType.PROVIDER
        )

        provider_credentials = credentials_cache.get()
        if not provider_credentials:
            raw_config = custom_provider_record.encrypted_config
            try:
                if raw_config and not raw_config.startswith("{"):
                    # Legacy data: the column holds a bare key instead of a JSON object.
                    provider_credentials = {
                        "openai_api_key": raw_config
                    }
                else:
                    provider_credentials = json.loads(raw_config)
            except JSONDecodeError:
                provider_credentials = {}

            # NOTE: decryption of secret variables is disabled here; the parsed
            # values are cached and returned unchanged.
            credentials_cache.set(
                credentials=provider_credentials
            )

        custom_provider_configuration = CustomProviderConfiguration(
            credentials=provider_credentials
        )

    # Secret variable names of the model credential form (unused for the same
    # reason as the provider-level list above).
    model_schema = provider_entity.model_credential_schema
    model_credential_secret_variables = self._extract_secret_variables(
        model_schema.credential_form_schemas if model_schema else []
    )

    # Resolve model-level credentials.
    custom_model_configurations = []
    for model_record in provider_model_records:
        if not model_record.encrypted_config:
            continue

        model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id,
            identity_id=model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL
        )

        model_credentials = model_credentials_cache.get()
        if not model_credentials:
            try:
                model_credentials = json.loads(model_record.encrypted_config)
            except JSONDecodeError:
                # An unparsable model config drops the model entirely.
                continue

            # NOTE: decryption of secret variables is disabled here as well.
            model_credentials_cache.set(
                credentials=model_credentials
            )

        custom_model_configurations.append(
            CustomModelConfiguration(
                model=model_record.model_name,
                model_type=ModelType.value_of(model_record.model_type),
                credentials=model_credentials
            )
        )

    return CustomConfiguration(
        provider=custom_provider_configuration,
        models=custom_model_configurations
    )
|
||||
|
||||
def _to_system_configuration(self,
                             tenant_id: str,
                             provider_entity: ProviderEntity,
                             provider_records: list[Provider]) -> SystemConfiguration:
    """
    Build the system (hosted) configuration of a provider.

    Returns a disabled configuration when the provider is absent from, or
    disabled in, the hosting configuration, or when no quota configuration can
    be derived. Otherwise the quota type to use is chosen with
    ``_choice_current_using_quota_type`` and the matching credentials are
    resolved: the hosting credentials by default, or the credentials stored on
    the provider record when the chosen quota type is FREE.

    :param tenant_id: workspace id
    :param provider_entity: provider entity
    :param provider_records: provider records
    :return: system configuration
    """
    # Get hosting configuration
    hosting_configuration = ext_hosting_provider.hosting_configuration

    if provider_entity.provider not in hosting_configuration.provider_map \
            or not hosting_configuration.provider_map.get(provider_entity.provider).enabled:
        return SystemConfiguration(
            enabled=False
        )

    provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)

    # Index the system provider records by quota type (a later record of the
    # same quota type overwrites an earlier one).
    quota_type_to_provider_records_dict = {}
    for provider_record in provider_records:
        if provider_record.provider_type != ProviderType.SYSTEM.value:
            continue

        quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \
            = provider_record

    quota_configurations = []
    for provider_quota in provider_hosting_configuration.quotas:
        if provider_quota.quota_type not in quota_type_to_provider_records_dict:
            if provider_quota.quota_type != ProviderQuotaType.FREE:
                continue

            # A FREE quota without a record is still listed, but as invalid.
            quota_configuration = QuotaConfiguration(
                quota_type=provider_quota.quota_type,
                quota_unit=provider_hosting_configuration.quota_unit,
                quota_used=0,
                quota_limit=0,
                is_valid=False,
                restrict_models=provider_quota.restrict_models
            )
        else:
            provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]

            quota_configuration = QuotaConfiguration(
                quota_type=provider_quota.quota_type,
                quota_unit=provider_hosting_configuration.quota_unit,
                quota_used=provider_record.quota_used,
                quota_limit=provider_record.quota_limit,
                # A limit of -1 is accepted as valid regardless of usage.
                is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
                restrict_models=provider_quota.restrict_models
            )

        quota_configurations.append(quota_configuration)

    if len(quota_configurations) == 0:
        return SystemConfiguration(
            enabled=False
        )

    current_quota_type = self._choice_current_using_quota_type(quota_configurations)

    current_using_credentials = provider_hosting_configuration.credentials
    if current_quota_type == ProviderQuotaType.FREE:
        provider_record = quota_type_to_provider_records_dict.get(current_quota_type)

        if provider_record:
            provider_credentials_cache = ProviderCredentialsCache(
                tenant_id=tenant_id,
                identity_id=provider_record.id,
                cache_type=ProviderCredentialsCacheType.PROVIDER
            )

            # Get cached provider credentials
            cached_provider_credentials = provider_credentials_cache.get()

            if not cached_provider_credentials:
                # encrypted_config may be NULL/empty on a system record; json.loads(None)
                # would raise TypeError, which the JSONDecodeError handler does not
                # cover, so a missing config is treated as empty credentials.
                raw_config = provider_record.encrypted_config
                try:
                    provider_credentials = json.loads(raw_config) if raw_config else {}
                except JSONDecodeError:
                    provider_credentials = {}

                # NOTE: decryption of secret variables is disabled here; the parsed
                # values are used and cached unchanged.
                current_using_credentials = provider_credentials

                # cache provider credentials
                provider_credentials_cache.set(
                    credentials=current_using_credentials
                )
            else:
                current_using_credentials = cached_provider_credentials
        else:
            # FREE was chosen but no record backs it: expose neither credentials
            # nor quota configurations.
            current_using_credentials = {}
            quota_configurations = []

    return SystemConfiguration(
        enabled=True,
        current_quota_type=current_quota_type,
        quota_configurations=quota_configurations,
        credentials=current_using_credentials
    )
|
||||
|
||||
def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType:
    """
    Choose the quota type that is currently in use.

    Priority order: paid quotas > provider free quotas > hosting trial quotas.
    The first quota type in that order whose configuration is valid is returned.
    If none is valid, the quota type of the lowest-priority configuration that
    is present is returned.

    :param quota_configurations: available quota configurations
    :return: the chosen quota type
    :raises ValueError: if none of the three quota types is present
    """
    # Index by quota type; a later configuration of the same type overwrites an earlier one.
    configuration_by_type = {}
    for configuration in quota_configurations:
        configuration_by_type[configuration.quota_type] = configuration

    priority = (ProviderQuotaType.PAID, ProviderQuotaType.FREE, ProviderQuotaType.TRIAL)

    fallback_configuration = None
    for candidate_type in priority:
        if candidate_type not in configuration_by_type:
            continue

        fallback_configuration = configuration_by_type[candidate_type]
        if fallback_configuration.is_valid:
            return candidate_type

    if not fallback_configuration:
        raise ValueError('No quota type available')

    return fallback_configuration.quota_type
|
||||
|
||||
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
    """
    Collect the variable names of all secret-input fields of a credential form.

    :param credential_form_schemas: credential form schemas to inspect
    :return: variable names of the schemas whose type is ``FormType.SECRET_INPUT``,
        in input order
    """
    return [
        form_schema.variable
        for form_schema in credential_form_schemas
        if form_schema.type == FormType.SECRET_INPUT
    ]
|
||||
@ -1,102 +0,0 @@
|
||||
ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {{query}}
|
||||
Thought: {{agent_scratchpad}}"""
|
||||
|
||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||
Thought:"""
|
||||
|
||||
ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
"""
|
||||
|
||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||
|
||||
REACT_PROMPT_TEMPLATES = {
|
||||
'english': {
|
||||
'chat': {
|
||||
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
|
||||
},
|
||||
'completion': {
|
||||
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
|
||||
}
|
||||
}
|
||||
}
|
||||
21
model_providers/model_providers/core/utils/generic.py
Normal file
21
model_providers/model_providers/core/utils/generic.py
Normal file
@ -0,0 +1,21 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
    """
    Convert a pydantic model to a dict, leaving out fields that were never set.

    Works with both pydantic major versions: ``model_dump`` is used when the
    object provides it (pydantic v2), otherwise the v1 ``dict`` method is used.

    :param data: pydantic model instance
    :return: dict of the explicitly set fields
    """
    # Dispatch on the available API instead of catching every exception:
    # a genuine failure inside model_dump() must surface, not be retried
    # silently through the legacy method.
    if hasattr(data, "model_dump"):  # pydantic v2
        return data.model_dump(exclude_unset=True)
    return data.dict(exclude_unset=True)  # pydantic v1
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
    """
    Serialize a pydantic model to a JSON string, leaving out fields that were never set.

    Non-ASCII characters are emitted as-is (``ensure_ascii=False``). Works with
    both pydantic major versions: ``model_dump`` + ``json.dumps`` when the
    object provides ``model_dump`` (pydantic v2), otherwise the v1 ``json`` method.

    :param data: pydantic model instance
    :return: JSON text of the explicitly set fields
    """
    # Dispatch on the available API instead of catching every exception:
    # a serialization error on the v2 path must surface rather than trigger
    # a second attempt through the legacy method.
    if hasattr(data, "model_dump"):  # pydantic v2
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    return data.json(exclude_unset=True, ensure_ascii=False)  # pydantic v1
|
||||
|
||||
12
model_providers/model_providers/core/utils/json_dumps.py
Normal file
12
model_providers/model_providers/core/utils/json_dumps.py
Normal file
@ -0,0 +1,12 @@
|
||||
import orjson
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def json_dumps(o):
    """
    Serialize ``o`` with ``orjson.dumps``.

    Pydantic models that orjson cannot handle natively are converted through
    their ``dict()`` method; any other unsupported object raises ``TypeError``
    from the fallback hook.

    :param o: object to serialize
    :return: whatever ``orjson.dumps`` returns for ``o``
    """
    def _pydantic_fallback(value):
        # Hook passed to orjson as ``default``.
        if not isinstance(value, BaseModel):
            raise TypeError
        return value.dict()

    return orjson.dumps(o, default=_pydantic_fallback)
|
||||
@ -1,9 +0,0 @@
|
||||
from flask import Flask
|
||||
|
||||
from model_providers.core.hosting_configuration import HostingConfiguration
|
||||
|
||||
hosting_configuration = HostingConfiguration()
|
||||
|
||||
|
||||
def init_app(app: Flask):
    """Initialise the module-level ``hosting_configuration`` with the given Flask app."""
    hosting_configuration.init_app(app)
|
||||
@ -1,7 +0,0 @@
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
|
||||
db = SQLAlchemy()
|
||||
|
||||
|
||||
def init_app(app):
    """Bind the module-level ``db`` (SQLAlchemy) object to the given application."""
    db.init_app(app)
|
||||
@ -1,9 +0,0 @@
|
||||
from flask import Flask
|
||||
|
||||
from model_providers.core.hosting_configuration import HostingConfiguration
|
||||
|
||||
hosting_configuration = HostingConfiguration()
|
||||
|
||||
|
||||
def init_app(app: Flask):
    """Initialise the module-level ``hosting_configuration`` with the given Flask app."""
    hosting_configuration.init_app(app)
|
||||
@ -6,7 +6,6 @@ from typing import Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from flask import Flask
|
||||
|
||||
|
||||
class Storage:
|
||||
@ -16,21 +15,21 @@ class Storage:
|
||||
self.client = None
|
||||
self.folder = None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
self.storage_type = app.config.get('STORAGE_TYPE')
|
||||
def init_config(self, config: dict):
|
||||
self.storage_type = config.get('STORAGE_TYPE')
|
||||
if self.storage_type == 's3':
|
||||
self.bucket_name = app.config.get('S3_BUCKET_NAME')
|
||||
self.bucket_name = config.get('S3_BUCKET_NAME')
|
||||
self.client = boto3.client(
|
||||
's3',
|
||||
aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
|
||||
aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
|
||||
endpoint_url=app.config.get('S3_ENDPOINT'),
|
||||
region_name=app.config.get('S3_REGION')
|
||||
aws_secret_access_key=config.get('S3_SECRET_KEY'),
|
||||
aws_access_key_id=config.get('S3_ACCESS_KEY'),
|
||||
endpoint_url=config.get('S3_ENDPOINT'),
|
||||
region_name=config.get('S3_REGION')
|
||||
)
|
||||
else:
|
||||
self.folder = app.config.get('STORAGE_LOCAL_PATH')
|
||||
self.folder = config.get('STORAGE_LOCAL_PATH')
|
||||
if not os.path.isabs(self.folder):
|
||||
self.folder = os.path.join(app.root_path, self.folder)
|
||||
self.folder = os.path.join(config.get('root_path'), self.folder)
|
||||
|
||||
def save(self, filename, data):
|
||||
if self.storage_type == 's3':
|
||||
@ -140,5 +139,3 @@ class Storage:
|
||||
storage = Storage()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
storage.init_app(app)
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
@ -1,28 +0,0 @@
|
||||
from sqlalchemy import Float, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from model_providers.extensions.ext_database import db
|
||||
|
||||
|
||||
class UploadFile(db.Model):
    """ORM model for the ``upload_files`` table."""

    __tablename__ = 'upload_files'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='upload_file_pkey'),
        db.Index('upload_file_tenant_idx', 'tenant_id')
    )

    # Identity
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)

    # Where and how the file is stored
    storage_type = db.Column(db.String(255), nullable=False)
    key = db.Column(db.String(255), nullable=False)

    # File metadata
    name = db.Column(db.String(255), nullable=False)
    size = db.Column(db.Integer, nullable=False)
    extension = db.Column(db.String(255), nullable=False)
    mime_type = db.Column(db.String(255), nullable=True)

    # Creation audit fields
    created_by_role = db.Column(
        db.String(255),
        nullable=False,
        server_default=db.text("'account'::character varying")
    )
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.text('CURRENT_TIMESTAMP(0)')
    )

    # Usage tracking
    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    used_by = db.Column(UUID, nullable=True)
    used_at = db.Column(db.DateTime, nullable=True)

    hash = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@ -1,160 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from model_providers.extensions.ext_database import db
|
||||
|
||||
|
||||
class ProviderType(Enum):
    """Kind of provider record: custom or system."""

    CUSTOM = 'custom'
    SYSTEM = 'system'

    @staticmethod
    def value_of(value):
        """
        Return the member whose value equals ``value``.

        :param value: raw enum value
        :return: matching ``ProviderType`` member
        :raises ValueError: if no member has that value
        """
        found = next((member for member in ProviderType if member.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
    """Quota type of a provider record."""

    # hosted paid quota
    PAID = 'paid'

    # third-party free quota
    FREE = 'free'

    # hosted trial quota
    TRIAL = 'trial'

    @staticmethod
    def value_of(value):
        """
        Return the member whose value equals ``value``.

        :param value: raw enum value
        :return: matching ``ProviderQuotaType`` member
        :raises ValueError: if no member has that value
        """
        found = next((member for member in ProviderQuotaType if member.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return found
|
||||
|
||||
|
||||
class Provider(db.Model):
    """
    ORM model for the ``providers`` table: API providers and their configurations.

    A row is unique per (tenant_id, provider_name, provider_type, quota_type).
    """
    __tablename__ = 'providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_pkey'),
        db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
        db.UniqueConstraint(
            'tenant_id', 'provider_name', 'provider_type', 'quota_type',
            name='unique_provider_name_type_quota'
        )
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    provider_type = db.Column(
        db.String(40),
        nullable=False,
        server_default=db.text("'custom'::character varying")
    )
    # Serialized credentials; nullable, so readers must cope with None.
    encrypted_config = db.Column(db.Text, nullable=True)
    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    last_used = db.Column(db.DateTime, nullable=True)

    # Quota bookkeeping
    quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
    quota_limit = db.Column(db.BigInteger, nullable=True)
    quota_used = db.Column(db.BigInteger, default=0)

    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    def __repr__(self):
        return f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>"

    @property
    def token_is_set(self):
        """
        Whether a configuration is stored, i.e. ``encrypted_config`` is not None.
        """
        return self.encrypted_config is not None

    @property
    def is_enabled(self):
        """
        Whether the provider is enabled.

        System providers only need to be valid; all other providers must also
        have their token set.
        """
        if self.provider_type != ProviderType.SYSTEM.value:
            return self.is_valid and self.token_is_set
        return self.is_valid
|
||||
|
||||
|
||||
class ProviderModel(db.Model):
    """
    ORM model for the ``provider_models`` table: provider models and their configurations.

    A row is unique per (tenant_id, provider_name, model_name, model_type).
    """
    __tablename__ = 'provider_models'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_model_pkey'),
        db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
        db.UniqueConstraint(
            'tenant_id', 'provider_name', 'model_name', 'model_type',
            name='unique_provider_model_name'
        )
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    model_type = db.Column(db.String(40), nullable=False)
    # Serialized credentials; nullable.
    encrypted_config = db.Column(db.Text, nullable=True)
    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
class TenantDefaultModel(db.Model):
    """ORM model for the ``tenant_default_models`` table."""

    __tablename__ = 'tenant_default_models'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'),
        db.Index(
            'tenant_default_model_tenant_id_provider_type_idx',
            'tenant_id', 'provider_name', 'model_type'
        ),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    model_type = db.Column(db.String(40), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(db.Model):
    """ORM model for the ``tenant_preferred_model_providers`` table."""

    __tablename__ = 'tenant_preferred_model_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'),
        db.Index(
            'tenant_preferred_model_provider_tenant_provider_idx',
            'tenant_id', 'provider_name'
        ),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    preferred_provider_type = db.Column(db.String(40), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
class ProviderOrder(db.Model):
    """ORM model for the ``provider_orders`` table."""

    __tablename__ = 'provider_orders'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_order_pkey'),
        db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
    )

    # Identity and ownership
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    provider_name = db.Column(db.String(40), nullable=False)
    account_id = db.Column(UUID, nullable=False)

    # Payment details
    payment_product_id = db.Column(db.String(191), nullable=False)
    payment_id = db.Column(db.String(191))
    transaction_id = db.Column(db.String(191))
    quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1'))
    currency = db.Column(db.String(40))
    total_amount = db.Column(db.Integer)
    payment_status = db.Column(
        db.String(40),
        nullable=False,
        server_default=db.text("'wait_pay'::character varying")
    )

    # Payment lifecycle timestamps
    paid_at = db.Column(db.DateTime)
    pay_failed_at = db.Column(db.DateTime)
    refunded_at = db.Column(db.DateTime)

    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
@ -8,8 +8,9 @@ readme = "README.md"
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
transformers = "4.31.0"
|
||||
Flask-SQLAlchemy = "3.0.5"
|
||||
SQLAlchemy = "1.4.28"
|
||||
fastapi = "^0.108"
|
||||
uvicorn = "0.25.0"
|
||||
sse-starlette = "^1.8.2"
|
||||
pyyaml = "6.0.1"
|
||||
pydantic = "1.10.14"
|
||||
redis = "4.5.4"
|
||||
@ -18,7 +19,6 @@ openai = "1.13.3"
|
||||
tiktoken = "0.5.2"
|
||||
pydub = "0.25.1"
|
||||
boto3 = "1.28.17"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
# The only dependencies that should be added are
|
||||
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user