dify model_providers configuration

This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
This commit is contained in:
glide-the 2024-03-18 01:19:14 +08:00
parent 4bdb69baf3
commit e977e2ff73
507 changed files with 166249 additions and 0 deletions

View File

@ -0,0 +1,429 @@
import os
from typing import cast, Generator
from model_providers.core.entities.application_entities import ModelConfigEntity, AppOrchestrationConfigEntity, \
PromptTemplateEntity, AdvancedChatPromptTemplateEntity, ExternalDataVariableEntity, AgentEntity, AgentToolEntity, \
AgentPromptEntity, DatasetEntity, DatasetRetrieveConfigEntity, FileUploadEntity, TextToSpeechEntity, \
SensitiveWordAvoidanceEntity, AdvancedCompletionPromptTemplateEntity
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, \
QuotaExceededError
from model_providers.core.model_manager import ModelInstance
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage, \
AssistantPromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.tools.prompt.template import REACT_PROMPT_TEMPLATES
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
"""
Convert app model config dict to entity.
:param tenant_id: tenant ID
:param app_model_config_dict: app model config dict
:raises ProviderTokenNotInitError: provider token not init error
:return: app orchestration config entity
"""
properties = {}
copy_app_model_config_dict = app_model_config_dict.copy()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=copy_app_model_config_dict['model']['provider'],
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
model_name = copy_app_model_config_dict['model']['name']
model_type_instance = provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=copy_app_model_config_dict['model']['name']
)
if model_credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=copy_app_model_config_dict['model']['name'],
model_type=ModelType.LLM
)
if provider_model is None:
model_name = copy_app_model_config_dict['model']['name']
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = copy_app_model_config_dict['model'].get('completion_params')
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = copy_app_model_config_dict['model'].get('mode')
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=copy_app_model_config_dict['model']['name'],
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
copy_app_model_config_dict['model']['name'],
model_credentials
)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
properties['model_config'] = ModelConfigEntity(
provider=copy_app_model_config_dict['model']['provider'],
model=copy_app_model_config_dict['model']['name'],
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
# prompt template
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
**completion_prompt_template_params
)
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
# external data variables
properties['external_data_variables'] = []
# old external_data_tools
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
properties['external_data_variables'].append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# current external_data_tools
for variable in copy_app_model_config_dict.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
properties['external_data_variables'].append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
)
)
# show retrieve source
show_retrieve_source = False
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
if retriever_resource_dict:
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
show_retrieve_source = True
properties['show_retrieve_source'] = show_retrieve_source
dataset_ids = []
if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
for dataset in datasets.get('datasets', []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset':
continue
dataset = dataset['dataset']
if 'enabled' not in dataset or not dataset['enabled']:
continue
dataset_id = dataset.get('id', None)
if dataset_id:
dataset_ids.append(dataset_id)
else:
datasets = {'strategy': 'router', 'datasets': []}
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
and 'enabled' in copy_app_model_config_dict['agent_mode'] \
and copy_app_model_config_dict['agent_mode']['enabled']:
agent_dict = copy_app_model_config_dict.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if copy_app_model_config_dict['model']['provider'] == 'openai':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
elif len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != 'dataset':
continue
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
# check model mode
model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
)
properties['agent'] = AgentEntity(
provider=properties['model_config'].provider,
model=properties['model_config'].model,
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5)
)
if len(dataset_ids) > 0:
# dataset configs
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
single_strategy=datasets.get('strategy', 'router')
)
)
else:
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
)
)
# file upload
file_upload_dict = copy_app_model_config_dict.get('file_upload')
if file_upload_dict:
if 'image' in file_upload_dict and file_upload_dict['image']:
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
properties['file_upload'] = FileUploadEntity(
image_config={
'number_limits': file_upload_dict['image']['number_limits'],
'detail': file_upload_dict['image']['detail'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
}
)
# opening statement
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
# suggested questions after answer
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
properties['suggested_questions_after_answer'] = True
# more like this
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
properties['more_like_this'] = True
# speech to text
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
if speech_to_text_dict:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
properties['speech_to_text'] = True
# text to speech
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
properties['text_to_speech'] = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
if sensitive_word_avoidance_dict:
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
return AppOrchestrationConfigEntity(**properties)
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
# provider_manager = ProviderManager()
# provider_model_bundle = provider_manager.get_provider_model_bundle(
# tenant_id="tenant_id",
# provider="copy_app_model_config_dict['model']['provider']",
# model_type=ModelType.LLM
# )
#
#
# model_instance = ModelInstance(
# provider_model_bundle=provider_model_bundle,
# model=model_config.model,
# )
# 直接通过模型加载器创建的模型实例
from model_providers.core.model_runtime.model_providers import model_provider_factory
model_provider_factory.get_providers(provider_name='openai')
provider_instance = model_provider_factory.get_provider_instance('openai')
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
print(model_type_instance)
response = model_type_instance.invoke(
model='gpt-4',
credentials={
'openai_api_key': "sk-",
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'plugin_web_search': True,
},
stop=['you'],
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
total_message = ''
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
total_message += chunk.delta.message.content
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
print(total_message)
assert '参考资料' in total_message

View File

@ -0,0 +1,8 @@
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'

View File

@ -0,0 +1,309 @@
from enum import Enum
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.file.file_obj import FileObj
from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""
class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""
class PromptType(Enum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = 'simple'
ADVANCED = 'advanced'
@classmethod
def value_of(cls, value: str) -> 'PromptType':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt type value {value}')
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = 'single'
MULTIPLE = 'multiple'
@classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid retrieve strategy value {value}')
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
single_strategy: Optional[str] = None # for temp
top_k: Optional[int] = None
score_threshold: Optional[float] = None
reranking_model: Optional[dict] = None
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
class FileUploadEntity(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent First Prompt Entity.
"""
class Action(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] = None
max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: list[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
dataset: Optional[DatasetEntity] = None
file_upload: Optional[FileUploadEntity] = None
opening_statement: Optional[str] = None
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = 'service-api'
WEB_APP = 'web-app'
EXPLORE = 'explore'
DEBUGGER = 'debugger'
@classmethod
def value_of(cls, value: str) -> 'InvokeFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
def to_source(self) -> str:
"""
Get source of invoke from.
:return: source
"""
if self == InvokeFrom.WEB_APP:
return 'web_app'
elif self == InvokeFrom.DEBUGGER:
return 'dev'
elif self == InvokeFrom.EXPLORE:
return 'explore_app'
elif self == InvokeFrom.SERVICE_API:
return 'api'
return 'dev'
class ApplicationGenerateEntity(BaseModel):
"""
Application Generate Entity.
"""
task_id: str
tenant_id: str
app_id: str
app_model_config_id: str
# for save
app_model_config_dict: dict
app_model_config_override: bool
# Converted from app_model_config to Entity object, or directly covered by external input
app_orchestration_config_entity: AppOrchestrationConfigEntity
conversation_id: Optional[str] = None
inputs: dict[str, str]
query: Optional[str] = None
files: list[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}

View File

@ -0,0 +1,135 @@
import enum
from typing import Any, cast
from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, list[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
if isinstance(message, LCHumanMessageWithFiles):
file_prompt_message_contents = []
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
file_prompt_message_contents.append(ImagePromptMessageContent(
data=file.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
))
prompt_message_contents = [TextPromptMessageContent(data=message.content)]
prompt_message_contents.extend(file_prompt_message_contents)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.content))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content
}
if 'function_call' in message.additional_kwargs:
message_kwargs['tool_calls'] = [
AssistantPromptMessage.ToolCall(
id=message.additional_kwargs['function_call']['id'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.additional_kwargs['function_call']['name'],
arguments=message.additional_kwargs['function_call']['arguments']
)
)
]
prompt_messages.append(AssistantPromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=message.content))
elif isinstance(message, FunctionMessage):
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
return prompt_messages
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, str):
messages.append(HumanMessage(content=prompt_message.content))
else:
message_contents = []
for content in prompt_message.content:
if isinstance(content, TextPromptMessageContent):
message_contents.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
message_contents.append({
'type': 'image',
'data': content.data,
'detail': content.detail.value
})
messages.append(HumanMessage(content=message_contents))
elif isinstance(prompt_message, AssistantPromptMessage):
message_kwargs = {
'content': prompt_message.content
}
if prompt_message.tool_calls:
message_kwargs['additional_kwargs'] = {
'function_call': {
'id': prompt_message.tool_calls[0].id,
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
}
messages.append(AIMessage(**message_kwargs))
elif isinstance(prompt_message, SystemPromptMessage):
messages.append(SystemMessage(content=prompt_message.content))
elif isinstance(prompt_message, ToolPromptMessage):
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
return messages

View File

@ -0,0 +1,71 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import ModelType, ProviderModel
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
class ModelStatus(Enum):
"""
Enum class for model status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
class SimpleModelProviderEntity(BaseModel):
"""
Simple provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
"""
Init simple provider.
:param provider_entity: provider entity
"""
super().__init__(
provider=provider_entity.provider,
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types
)
class ModelWithProviderEntity(ProviderModel):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
status: ModelStatus
class DefaultModelProviderEntity(BaseModel):
"""
Default model provider entity.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
class DefaultModelEntity(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: DefaultModelProviderEntity

View File

@ -0,0 +1,798 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
# from model_providers.core.helper import encrypter
from model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.extensions.ext_database import db
from model_providers.models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
tenant_id: str
provider: ProviderEntity
preferred_provider_type: ProviderType
using_provider_type: ProviderType
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any([len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations])
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
:param model_type: model type
:param model: model name
:return:
"""
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_models = quota_configuration.restrict_models
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
if (restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name):
copy_credentials['base_model_name'] = restrict_model.base_model_name
return copy_credentials
else:
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
return None
def get_system_configuration_status(self) -> SystemConfigurationStatus:
"""
Get system configuration status.
:return:
"""
if self.system_configuration.enabled is False:
return SystemConfigurationStatus.UNSUPPORTED
current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
None
)
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
SystemConfigurationStatus.QUOTA_EXCEEDED
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Get custom credentials.
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if self.custom_configuration.provider is None:
return None
credentials = self.custom_configuration.provider.credentials
if not obfuscated:
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
if provider_record:
try:
# fix origin data
if provider_record.encrypted_config:
if not provider_record.encrypted_config.startswith("{"):
original_credentials = {
"openai_api_key": provider_record.encrypted_config
}
else:
original_credentials = json.loads(provider_record.encrypted_config)
else:
original_credentials = {}
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# # if send [__HIDDEN__] in secret input, it will be same as original value
# if value == '[__HIDDEN__]' and key in original_credentials:
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate(
self.provider.provider,
credentials
)
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_record, credentials
def add_or_update_custom_credentials(self, credentials: dict) -> None:
"""
Add or update custom provider credentials.
:param credentials:
:return:
"""
# validate custom provider config
provider_record, credentials = self.custom_credentials_validate(credentials)
# save provider
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_record:
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(ProviderType.CUSTOM)
def delete_custom_credentials(self) -> None:
"""
Delete custom provider credentials.
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# delete provider
if provider_record:
self.switch_preferred_provider_type(ProviderType.SYSTEM)
db.session.delete(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
"""
Get custom model credentials.
:param model_type: model type
:param model: model name
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if not self.custom_configuration.models:
return None
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
if not obfuscated:
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
if provider_model_record:
try:
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# # if send [__HIDDEN__] in secret input, it will be same as original value
# if value == '[__HIDDEN__]' and key in original_credentials:
# credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
)
# for key, value in credentials.items():
# if key in provider_credential_secret_variables:
# credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_model_record, credentials
def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
"""
Add or update custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# validate custom model config
provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
# save provider model
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model credentials.
:param model_type: model type
:param model: model name
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# delete provider model
if provider_model_record:
db.session.delete(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
Switch preferred provider type.
:param provider_type:
:return:
"""
if provider_type == self.preferred_provider_type:
return
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return
# get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider
).first()
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value
)
db.session.add(preferred_model_provider)
db.session.commit()
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
"""
Obfuscated credentials.
:param credentials: credentials
:param credential_form_schemas: credential form schemas
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self._extract_secret_variables(
credential_form_schemas
)
# Obfuscate provider credentials
# copy_credentials = credentials.copy()
# for key, value in copy_credentials.items():
# if key in credential_secret_variables:
# copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
:param model: model name
:param only_active: return active model only
:return:
"""
provider_models = self.get_provider_models(model_type, only_active)
for provider_model in provider_models:
if provider_model.model == model:
return provider_model
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_types = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
if only_active:
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
for model_type in model_types:
provider_models.extend(
[
ModelWithProviderEntity(
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
for m in provider_instance.models(model_type)
]
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_models = quota_configuration.restrict_models
if len(restrict_models) == 0:
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
credentials = None
if self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
for model_type in model_types:
if model_type not in self.provider.supported_model_types:
continue
models = provider_instance.models(model_type)
for m in models:
provider_models.append(
ModelWithProviderEntity(
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)
)
# custom models
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
return provider_models
class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
"""
Get available models.
If preferred provider type is `system`:
Get the current **system mode** if provider supported,
if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
If there is no model configured in custom mode, it is treated as no_configure.
system > custom > no_configure
If preferred provider type is `custom`:
If custom credentials are configured, it is treated as custom mode.
Otherwise, get the current **system mode** if supported,
If all system modes are not available (no quota), it is treated as no_configure.
custom > system > no_configure
If real mode is `system`, use system credentials to get models,
paid quotas > provider free quotas > system free quotas
include pre-defined models (exclude GPT-4, status marked as `no_permission`).
If real mode is `custom`, use workspace custom credentials to get models,
include pre-defined models, custom models(manual append).
If real mode is `no_configure`, only return pre-defined models from `model runtime`.
(model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
model status marked as `active` is available.
:param provider: provider name
:param model_type: model type
:param only_active: only active models
:return:
"""
all_models = []
for provider_configuration in self.values():
if provider and provider_configuration.provider.provider != provider:
continue
all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
return all_models
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.
:return:
"""
return list(self.values())
def __getitem__(self, key):
return self.configurations[key]
def __setitem__(self, key, value):
self.configurations[key] = value
def __iter__(self):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
def get(self, key, default=None):
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True

View File

@ -0,0 +1,74 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.models.provider import ProviderQuotaType
class QuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
CREDITS = 'credits'
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = 'active'
QUOTA_EXCEEDED = 'quota-exceeded'
UNSUPPORTED = 'unsupported'
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
credentials: dict
class CustomModelConfiguration(BaseModel):
"""
Model class for provider custom model configuration.
"""
model: str
model_type: ModelType
credentials: dict
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []

View File

@ -0,0 +1,133 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
class QueueEvent(Enum):
"""
QueueEvent enum
"""
MESSAGE = "message"
AGENT_MESSAGE = "agent_message"
MESSAGE_REPLACE = "message-replace"
MESSAGE_END = "message-end"
RETRIEVER_RESOURCES = "retriever-resources"
ANNOTATION_REPLY = "annotation-reply"
AGENT_THOUGHT = "agent-thought"
MESSAGE_FILE = "message-file"
ERROR = "error"
PING = "ping"
STOP = "stop"
class AppQueueEvent(BaseModel):
"""
QueueEvent entity
"""
event: QueueEvent
class QueueMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event = QueueEvent.MESSAGE
chunk: LLMResultChunk
class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event = QueueEvent.MESSAGE_REPLACE
text: str
class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
class AnnotationReplyEvent(AppQueueEvent):
"""
AnnotationReplyEvent entity
"""
event = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
llm_result: LLMResult
class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
class QueueMessageFileEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.MESSAGE_FILE
message_file_id: str
class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event = QueueEvent.ERROR
error: Any
class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event = QueueEvent.PING
class QueueStopEvent(AppQueueEvent):
"""
QueueStopEvent entity
"""
class StopBy(Enum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
event = QueueEvent.STOP
stopped_by: StopBy
class QueueMessage(BaseModel):
"""
QueueMessage entity
"""
task_id: str
message_id: str
conversation_id: str
app_mode: str
event: AppQueueEvent

View File

@ -0,0 +1,38 @@
from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
"""
description = "Model Currently Not Support"

View File

@ -0,0 +1,90 @@
import enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.file.upload_file_parser import UploadFileParser
from model_providers.core.model_runtime.entities.message_entities import ImagePromptMessageContent
from model_providers.extensions.ext_database import db
from model_providers.models.model import UploadFile
class FileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url'
LOCAL_FILE = 'local_file'
TOOL_FILE = 'tool_file'
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = 'user'
ASSISTANT = 'assistant'
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileObj(BaseModel):
id: Optional[str]
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str]
upload_file_id: Optional[str]
file_config: dict
@property
def data(self) -> Optional[str]:
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
return self._get_data(force_url=True)
@property
def prompt_message_content(self) -> ImagePromptMessageContent:
if self.type == FileType.IMAGE:
image_config = self.file_config.get('image')
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == self.upload_file_id,
UploadFile.tenant_id == self.tenant_id
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
)
return None

View File

@ -0,0 +1,184 @@
from typing import Optional, Union
import requests
from model_providers.core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
from model_providers.extensions.ext_database import db
from model_providers.models.account import Account
from model_providers.models.model import AppModelConfig, EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> list[FileObj]:
"""
validate and transform files arg
:param files:
:param app_model_config:
:param user:
:return:
"""
file_upload_config = app_model_config.file_upload_dict
for file in files:
if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict')
if not file.get('type'):
raise ValueError('Missing file type')
FileType.value_of(file.get('type'))
if not file.get('transfer_method'):
raise ValueError('Missing file transfer method')
FileTransferMethod.value_of(file.get('transfer_method'))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'):
raise ValueError('Missing file url')
if not file.get('url').startswith('http'):
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_upload_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_upload_config.get('image')
# check if image file feature is enabled
if not image_config['enabled']:
continue
# Validate number of files
if len(files) > image_config['number_limits']:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']:
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f'Invalid file type: {file_obj.type}')
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.upload_file_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
UploadFile.extension.in_(IMAGE_EXTENSIONS)
).first())
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError('Invalid upload file')
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
"""
transform message files
:param files:
:param app_model_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
"""
transform files to file objs
:param files:
:param file_upload_config:
:return:
"""
type_file_objs: dict[FileType, list[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
if isinstance(file, MessageFile):
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
continue
file_obj = self._to_file_obj(file, file_upload_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
return FileObj(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
file_config=file_upload_config
)
else:
return FileObj(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
upload_file_id=file.upload_file_id or None,
file_config=file_upload_config
)
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code == 200:
return True, ""
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

View File

@ -0,0 +1,8 @@
tool_file_manager = {
'manager': None
}
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> 'ToolFileManager':
return tool_file_manager['manager']

View File

@ -0,0 +1,79 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from flask import current_app
from model_providers.extensions.ext_storage import storage
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
return cls.get_signed_temp_image_url(upload_file)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}')
return None
encoded_string = base64.b64encode(data).decode('utf-8')
return f'data:{upload_file.mime_type};base64,{encoded_string}'
@classmethod
def get_signed_temp_image_url(cls, upload_file) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = current_app.config.get('FILES_URL')
image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= 300 # expired after 5 minutes

View File

@ -0,0 +1,51 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from model_providers.extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@ -0,0 +1,250 @@
from typing import Optional
from flask import Config, Flask
from pydantic import BaseModel
from model_providers.core.entities.provider_entities import QuotaUnit, RestrictModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.models.provider import ProviderQuotaType
class HostingQuota(BaseModel):
quota_type: ProviderQuotaType
restrict_models: list[RestrictModel] = []
class TrialHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
quota_limit: int = 0
"""Quota limit for the hosting provider models. -1 means unlimited."""
class PaidHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.PAID
class FreeHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.FREE
class HostingProvider(BaseModel):
enabled: bool = False
credentials: Optional[dict] = None
quota_unit: Optional[QuotaUnit] = None
quotas: list[HostingQuota] = []
class HostedModerationConfig(BaseModel):
enabled: bool = False
providers: list[str] = []
class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask) -> None:
config = app.config
if config.get('EDITION') != 'CLOUD':
return
self.provider_map["azure_openai"] = self.init_azure_openai(config)
self.provider_map["openai"] = self.init_openai(config)
self.provider_map["anthropic"] = self.init_anthropic(config)
self.provider_map["minimax"] = self.init_minimax(config)
self.provider_map["spark"] = self.init_spark(config)
self.provider_map["zhipuai"] = self.init_zhipuai(config)
self.moderation_config = self.init_moderation_config(config)
def init_azure_openai(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
credentials = {
"openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
"base_model_name": "gpt-35-turbo"
}
quotas = []
hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_openai(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas = []
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=trial_models
)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(
restrict_models=paid_models
)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
}
if app_config.get("HOSTED_OPENAI_API_BASE"):
credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas = []
if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit
)
quotas.append(trial_quota)
if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
}
if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_minimax(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_MINIMAX_ENABLED"):
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_spark(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_SPARK_ENABLED"):
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_zhipuai(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
quotas = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
credentials=None, # use credentials from the provider
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_moderation_config(self, app_config: Config) -> HostedModerationConfig:
if app_config.get("HOSTED_MODERATION_ENABLED") \
and app_config.get("HOSTED_MODERATION_PROVIDERS"):
return HostedModerationConfig(
enabled=True,
providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',')
)
return HostedModerationConfig(
enabled=False
)
@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
models_list = models_str.split(",") if models_str else []
return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
model_name.strip()]

View File

@ -0,0 +1,257 @@
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.errors.error import ProviderTokenNotInitError
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from model_providers.core.provider_manager import ProviderManager
class ModelInstance:
"""
Model instance class
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
self._provider_model_bundle = provider_model_bundle
self.model = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self._provider_model_bundle.model_type_instance
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle
:param model: model name
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
return credentials
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
)
def invoke_moderation(self, text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
text=text,
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
file=file,
user=user
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
-> str:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
)
def get_tts_voices(self, language: str) -> list:
"""
Invoke large language tts model voices
:param language: tts language
:return: tts model voices
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model,
credentials=self.credentials,
language=language
)
class ModelManager:
def __init__(self) -> None:
self._provider_manager = ProviderManager()
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
if not provider:
return self.get_default_model_instance(tenant_id, model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=provider,
model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
:param tenant_id: tenant id
:param model_type: model type
:return:
"""
default_model_entity = self._provider_manager.get_default_model(
tenant_id=tenant_id,
model_type=model_type
)
if not default_model_entity:
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
return self.get_model_instance(
tenant_id=tenant_id,
provider=default_model_entity.provider.provider,
model_type=model_type,
model=default_model_entity.model
)

View File

@ -0,0 +1,70 @@
# Model Runtime
This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
## Features
- Supports capability invocation for 5 types of models
- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
- `Rerank Model` - Segment Rerank capability
- `Speech-to-text Model` - Speech to text capability
- `Text-to-speech Model` - Text to speech capability
- `Moderation` - Moderation capability
- Model provider display
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md).
- Selectable model list display
![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
- Provider/model credential authentication
![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
## Structure
![](./docs/en_US/images/index/image-20231210165243632.png)
Model Runtime is divided into three layers:
- The outermost layer is the factory method
It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
- The second layer is the provider layer
It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
- The bottom layer is the model layer
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
- Implement interface methods: [Link](./docs/en_US/interfaces.md)

View File

@ -0,0 +1,89 @@
# Model Runtime
该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。
- 一方面将模型和上下游解耦,方便开发者对模型横向扩展,
- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。
## 功能介绍
- 支持 5 种模型类型的能力调用
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
- `Text Embedidng Model` - 文本 Embedding ,预计算 tokens 能力
- `Rerank Model` - 分段 Rerank 能力
- `Speech-to-text Model` - 语音转文本能力
- `Text-to-speech Model` - 文本转语音能力
- `Moderation` - Moderation 能力
- 模型供应商展示
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
- 可选择的模型列表展示
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
- 供应商/模型凭据鉴权
![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO上图 2 为模型凭据 DEMO。
## 结构
![](./docs/zh_Hans/images/index/image-20231210165243632.png)
Model Runtime 分三层:
- 最外层为工厂方法
提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。
- 第二层为供应商层
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
对于供应商/模型凭据,有两种情况
- 如OpenAI这类中心化供应商需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png)
当配置好凭据后就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
- 最底层为模型层
提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。
在这里我们需要先区分模型参数与模型凭据。
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等这些参数是由用户在前端页面上进行调整的因此需要在后端定义参数的规则以便前端页面进行展示和调整。在DifyRuntime中他们的参数名一般为**model_parameters: dict[str, any]**。
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中他们的参数名一般为**credentials: dict[str, any]**Provider层的credentials会直接被传递到这一层不需要再单独定义。
## 下一步
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
当添加后,这里将会出现一个新的供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
当添加后对应供应商的模型列表中将会出现一个新的预定义模型供用户选择如GPT-3.5 GPT-4 ChatGLM3-6b等而对于支持自定义模型的供应商则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。

View File

@ -0,0 +1,113 @@
from abc import ABC
from typing import Optional
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class Callback(ABC):
"""
Base class for callbacks.
Only for LLM.
"""
raise_error: bool = False
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
def print_text(
self, text: str, color: Optional[str] = None, end: str = ""
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

View File

@ -0,0 +1,133 @@
import json
import logging
import sys
from typing import Optional
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_before_invoke]\n", color='blue')
self.print_text(f"Model: {model}\n", color='blue')
self.print_text("Parameters:\n", color='blue')
for key, value in model_parameters.items():
self.print_text(f"\t{key}: {value}\n", color='blue')
if stop:
self.print_text(f"\tstop: {stop}\n", color='blue')
if tools:
self.print_text("\tTools:\n", color='blue')
for tool in tools:
self.print_text(f"\t\t{tool.name}\n", color='blue')
self.print_text(f"Stream: {stream}\n", color='blue')
if user:
self.print_text(f"User: {user}\n", color='blue')
self.print_text("Prompt messages:\n", color='blue')
for prompt_message in prompt_messages:
if prompt_message.name:
self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
if stream:
self.print_text("\n[on_llm_new_chunk]")
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
sys.stdout.write(chunk.delta.message.content)
sys.stdout.flush()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
self.print_text(f"Content: {result.message.content}\n", color='yellow')
if result.message.tool_calls:
self.print_text("Tool calls:\n", color='yellow')
for tool_call in result.message.tool_calls:
self.print_text(f"\t{tool_call.id}\n", color='yellow')
self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
self.print_text(f"Model: {result.model}\n", color='yellow')
self.print_text(f"Usage: {result.usage}\n", color='yellow')
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_invoke_error]\n", color='red')
logger.exception(ex)

Binary file not shown.

After

Width:  |  Height:  |  Size: 370 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 541 KiB

View File

@ -0,0 +1,706 @@
# Interface Methods
This section describes the interface methods and parameter explanations that need to be implemented by providers and various model types.
## Provider
Inherit the `__base.model_provider.ModelProvider` base class and implement the following interfaces:
```python
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any validate_credentials method of model type or implement validate method by yourself,
such as: get model list api
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
```
- `credentials` (object) Credential information
The parameters of credential information are defined by the `provider_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
## Model
Models are divided into 5 different types, each inheriting from different base classes and requiring the implementation of different methods.
All models need to uniformly implement the following 2 methods:
- Model Credential Verification
Similar to provider credential verification, this step involves verification for an individual model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
- Invocation Error Mapping Table
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
Runtime Errors:
- `InvokeConnectionError` Invocation connection error
- `InvokeServerUnavailableError` Invocation service provider unavailable
- `InvokeRateLimitError` Invocation reached rate limit
- `InvokeAuthorizationError` Invocation authorization failure
- `InvokeBadRequestError` Invocation parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
You can refer to OpenAI's `_invoke_error_mapping` for an example.
### LLM
Inherit the `__base.large_language_model.LargeLanguageModel` base class and implement the following interfaces:
- LLM Invocation
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) List of prompts
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
If the model is of the `Chat` type, it requires a list of elements such as [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) depending on the message.
- `model_parameters` (object) Model parameters
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] List of tools, equivalent to the `function` in `function calling`.
That is, the tool list for tool calling.
- `stop` (array[string]) [optional] Stop sequences
The model output will stop before the string defined by the stop sequence.
- `stream` (bool) Whether to output in a streaming manner, default is True
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
- Pre-calculating Input Tokens
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
For parameter explanations, refer to the above section on `LLM Invocation`.
- Fetch Custom Model Schema [Optional]
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
```
When the provider supports adding custom LLMs, this method can be implemented to allow custom models to fetch model schema. The default return null.
### TextEmbedding
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
- Embedding Invocation
```python
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `texts` (array[string]) List of texts, capable of batch processing
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
[TextEmbeddingResult](#TextEmbeddingResult) entity.
- Pre-calculating Tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
```
For parameter explanations, refer to the above section on `Embedding Invocation`.
### Rerank
Inherit the `__base.rerank_model.RerankModel` base class and implement the following interfaces:
- Rerank Invocation
```python
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `query` (string) Query request content
- `docs` (array[string]) List of segments to be reranked
- `score_threshold` (float) [optional] Score threshold
- `top_n` (int) [optional] Select the top n segments
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
[RerankResult](#RerankResult) entity.
### Speech2text
Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement the following interfaces:
- Invoke Invocation
```python
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `file` (File) File stream
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
The string after speech-to-text conversion.
### Text2speech
Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
- Invoke Invocation
```python
def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
```
- Parameters
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `content_text` (string) The text content that needs to be converted
- `streaming` (bool) Whether to stream output
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns
Text converted speech stream。
### Moderation
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
- Invoke Invocation
```python
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `text` (string) Text content
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
False indicates that the input text is safe, True indicates otherwise.
## Entities
### PromptMessageRole
Message role
```python
class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
```
### PromptMessageContentType
Message content types, divided into text and image.
```python
class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
```
### PromptMessageContent
Message content base class, used only for parameter declaration and cannot be initialized.
```python
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str
```
Currently, two types are supported: text and image. It's possible to simultaneously input text and multiple images.
You need to initialize `TextPromptMessageContent` and `ImagePromptMessageContent` separately for input.
### TextPromptMessageContent
```python
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
```
If inputting a combination of text and images, the text needs to be constructed into this entity as part of the `content` list.
### ImagePromptMessageContent
```python
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW # Resolution
```
If inputting a combination of text and images, the images need to be constructed into this entity as part of the `content` list.
`data` can be either a `url` or a `base64` encoded string of the image.
### PromptMessage
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
```python
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
name: Optional[str] = None
```
### UserPromptMessage
UserMessage message body, representing a user's message.
```python
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
```
### AssistantPromptMessage
Represents a message returned by the model, typically used for `few-shots` or inputting chat history.
```python
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str # tool name
arguments: str # tool arguments
id: str # Tool ID, effective only in OpenAI tool calls. It's the unique ID for tool invocation and the same tool can be called multiple times.
type: str # default: function
function: ToolCallFunction # tool call information
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
```
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
### SystemPromptMessage
Represents system messages, usually used for setting system commands given to the model.
```python
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
```
### ToolPromptMessage
Represents tool messages, used for conveying the results of a tool execution to the model for the next step of processing.
```python
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str # Tool invocation ID. If OpenAI tool call is not supported, the name of the tool can also be inputted.
```
The base class's `content` takes in the results of tool execution.
### PromptMessageTool
```python
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
```
---
### LLMResult
```python
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str # Actual used modele
prompt_messages: list[PromptMessage] # prompt messages
message: AssistantPromptMessage # response message
usage: LLMUsage # usage info
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
```
### LLMResultChunkDelta
In streaming returns, each iteration contains the `delta` entity.
```python
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage # response message
usage: Optional[LLMUsage] = None # usage info
finish_reason: Optional[str] = None # finish reason, only the last one returns
```
### LLMResultChunk
Each iteration entity in streaming returns.
```python
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str # Actual used modele
prompt_messages: list[PromptMessage] # prompt messages
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
delta: LLMResultChunkDelta
```
### LLMUsage
```python
class LLMUsage(ModelUsage):
"""
Model class for LLM usage.
"""
prompt_tokens: int # Tokens used for prompt
prompt_unit_price: Decimal # Unit price for prompt
prompt_price_unit: Decimal # Price unit for prompt, i.e., the unit price based on how many tokens
prompt_price: Decimal # Cost for prompt
completion_tokens: int # Tokens used for response
completion_unit_price: Decimal # Unit price for response
completion_price_unit: Decimal # Price unit for response, i.e., the unit price based on how many tokens
completion_price: Decimal # Cost for response
total_tokens: int # Total number of tokens used
total_price: Decimal # Total cost
currency: str # Currency unit
latency: float # Request latency (s)
```
---
### TextEmbeddingResult
```python
class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str # Actual model used
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
usage: EmbeddingUsage # Usage information
```
### EmbeddingUsage
```python
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int # Number of tokens used
total_tokens: int # Total number of tokens used
unit_price: Decimal # Unit price
price_unit: Decimal # Price unit, i.e., the unit price based on how many tokens
total_price: Decimal # Total cost
currency: str # Currency unit
latency: float # Request latency (s)
```
---
### RerankResult
```python
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str # Actual model used
docs: list[RerankDocument] # Reranked document list
```
### RerankDocument
```python
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int # original index
text: str
score: float
```

View File

@ -0,0 +1,265 @@
## Adding a New Provider
Providers support three types of model configuration methods:
- `predefined-model` Predefined model
This indicates that users only need to configure the unified provider credentials to use the predefined models under the provider.
- `customizable-model` Customizable model
Users need to add credential configurations for each model.
- `fetch-from-remote` Fetch from remote
This is consistent with the `predefined-model` configuration method. Only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
These three configuration methods **can coexist**, meaning a provider can support `predefined-model` + `customizable-model` or `predefined-model` + `fetch-from-remote`, etc. In other words, configuring the unified provider credentials allows the use of predefined and remotely fetched models, and if new models are added, they can be used in addition to the custom models.
## Getting Started
Adding a new provider starts with determining the English identifier of the provider, such as `anthropic`, and using this identifier to create a `module` in `model_providers`.
Under this `module`, we first need to prepare the provider's YAML configuration.
### Preparing Provider YAML
Here, using `Anthropic` as an example, we preset the provider's basic information, supported model types, configuration methods, and credential rules.
```YAML
provider: anthropic # Provider identifier
label: # Provider display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
en_US: Anthropic
icon_small: # Small provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
en_US: icon_s_en.png
icon_large: # Large provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
en_US: icon_l_en.png
supported_model_types: # Supported model types, Anthropic only supports LLM
- llm
configurate_methods: # Supported configuration methods, Anthropic only supports predefined models
- predefined-model
provider_credential_schema: # Provider credential rules, as Anthropic only supports predefined models, unified provider credential rules need to be defined
credential_form_schemas: # List of credential form items
- variable: anthropic_api_key # Credential parameter variable name
label: # Display name
en_US: API Key
type: secret-input # Form type, here secret-input represents an encrypted information input box, showing masked information when editing.
required: true # Whether required
placeholder: # Placeholder information
zh_Hans: Enter your API Key here
en_US: Enter your API Key
- variable: anthropic_api_url
label:
en_US: API URL
type: text-input # Form type, here text-input represents a text input box
required: false
placeholder:
zh_Hans: Enter your API URL here
en_US: Enter your API URL
```
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
### Implementing Provider Code
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
> If the provider is the type of `customizable-model`, there is no need to implement the `validate_provider_credentials` method.
```python
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any validate_credentials method of model type or implement validate method by yourself,
such as: get model list api
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
```
Of course, you can also preliminarily reserve the implementation of `validate_provider_credentials` and directly reuse it after the model credential verification method is implemented.
---
### Adding Models
After the provider integration is complete, the next step is to integrate models under the provider.
First, we need to determine the type of the model to be integrated and create a `module` for the corresponding model type in the provider's directory.
The currently supported model types are as follows:
- `llm` Text generation model
- `text_embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech to text
- `tts` Text to speech
- `moderation` Moderation
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` in `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model, such as `claude-2.1.yaml`, under the `llm` `module`.
#### Preparing Model YAML
```yaml
model: claude-2.1 # Model identifier
# Model display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
# Alternatively, if the label is not set, use the model identifier content.
label:
en_US: claude-2.1
model_type: llm # Model type, claude-2.1 is an LLM
features: # Supported features, agent-thought for Agent reasoning, vision for image understanding
- agent-thought
model_properties: # Model properties
mode: chat # LLM mode, complete for text completion model, chat for dialogue model
context_size: 200000 # Maximum supported context size
parameter_rules: # Model invocation parameter rules, only required for LLM
- name: temperature # Invocation parameter variable name
# Default preset with 5 variable content configuration templates: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
# Directly set the template variable name in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
# If additional configuration parameters are set, they will override the default configuration
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label: # Invocation parameter display name
zh_Hans: Sampling quantity
en_US: Top k
type: int # Parameter type, supports float/int/string/boolean
help: # Help information, describing the role of the parameter
zh_Hans: Only sample from the top K options for each subsequent token.
en_US: Only sample from the top K options for each subsequent token.
required: false # Whether required, can be left unset
- name: max_tokens_to_sample
use_template: max_tokens
default: 4096 # Default parameter value
min: 1 # Minimum parameter value, only applicable for float/int
max: 4096 # Maximum parameter value, only applicable for float/int
pricing: # Pricing information
input: '8.00' # Input price, i.e., Prompt price
output: '24.00' # Output price, i.e., returned content price
unit: '0.000001' # Pricing unit, i.e., the above prices are per 100K
currency: USD # Currency
```
It is recommended to prepare all model configurations before starting the implementation of the model code.
Similarly, you can also refer to the YAML configuration information for corresponding model types of other providers in the `model_providers` directory. The complete YAML rules can be found at: [Schema](schema.md#AIModel).
#### Implementing Model Invocation Code
Next, you need to create a python file named `llm.py` under the `llm` `module` to write the implementation code.
In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguageModel` (arbitrarily), inheriting the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Invocation
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
- Pre-calculating Input Tokens
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
- Model Credential Verification
Similar to provider credential verification, this step involves verification for an individual model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
- Invocation Error Mapping Table
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
Runtime Errors:
- `InvokeConnectionError` Invocation connection error
- `InvokeServerUnavailableError` Invocation service provider unavailable
- `InvokeRateLimitError` Invocation reached rate limit
- `InvokeAuthorizationError` Invocation authorization failure
- `InvokeBadRequestError` Invocation parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
For details on the interface methods, see: [Interfaces](interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
### Testing
To ensure the availability of integrated providers/models, each method written needs corresponding integration test code in the `tests` directory.
Continuing with `Anthropic` as an example:
Before writing test code, you need to first add the necessary credential environment variables for the test provider in `.env.example`, such as: `ANTHROPIC_API_KEY`.
Before execution, copy `.env.example` to `.env` and then execute.
#### Writing Test Code
Create a `module` with the same name as the provider in the `tests` directory: `anthropic`, and continue to create `test_provider.py` and test py files for the corresponding model types within this module, as shown below:
```shell
.
├── __init__.py
├── anthropic
│   ├── __init__.py
│   ├── test_llm.py # LLM Testing
│   └── test_provider.py # Provider Testing
```
Write test code for all the various cases implemented above and submit the code after passing the tests.

View File

@ -0,0 +1,203 @@
# Configuration Rules
- Provider rules are based on the [Provider](#Provider) entity.
- Model rules are based on the [AIModelEntity](#AIModelEntity) entity.
> All entities mentioned below are based on `Pydantic BaseModel` and can be found in the `entities` module.
### Provider
- `provider` (string) Provider identifier, e.g., `openai`
- `label` (object) Provider display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
- `zh_Hans` (string) [optional] Chinese label name, if `zh_Hans` is not set, `en_US` will be used by default.
- `en_US` (string) English label name
- `description` (object) Provider description, i18n
- `zh_Hans` (string) [optional] Chinese description
- `en_US` (string) English description
- `icon_small` (string) [optional] Small provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
- `zh_Hans` (string) Chinese ICON
- `en_US` (string) English ICON
- `icon_large` (string) [optional] Large provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
- `zh_Hans` (string) Chinese ICON
- `en_US` (string) English ICON
- `background` (string) [optional] Background color value, e.g., #FFFFFF, if empty, the default frontend color value will be displayed.
- `help` (object) [optional] help information
- `title` (object) help title, i18n
- `zh_Hans` (string) [optional] Chinese title
- `en_US` (string) English title
- `url` (object) help link, i18n
- `zh_Hans` (string) [optional] Chinese link
- `en_US` (string) English link
- `supported_model_types` (array[[ModelType](#ModelType)]) Supported model types
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) Configuration methods
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
### AIModelEntity
- `model` (string) Model identifier, e.g., `gpt-3.5-turbo`
- `label` (object) [optional] Model display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
- `zh_Hans` (string) [optional] Chinese label name
- `en_US` (string) English label name
- `model_type` ([ModelType](#ModelType)) Model type
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] Supported feature list
- `model_properties` (object) Model properties
- `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
- `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) default voice, e.g.alloy,echo,fable,onyx,nova,shimmeravailable for model type `tts`
- `voices` (list) List of available voice.available for model type `tts`
- `mode` (string) voice model.available for model type `tts`
- `name` (string) voice model display name.available for model type `tts`
- `lanuage` (string) the voice model supports languages.available for model type `tts`
- `word_limit` (int) Single conversion word limit, paragraphwise by defaultavailable for model type `tts`
- `audio_type` (string) Support audio file extension format, e.g.mp3,wavavailable for model type `tts`
- `max_workers` (int) Number of concurrent workers supporting text and audio conversionavailable for model type`tts`
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
### ModelType
- `llm` Text generation model
- `text-embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech to text
- `tts` Text to speech
- `moderation` Moderation
### ConfigurateMethod
- `predefined-model` Predefined model
Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
- `customizable-model` Customizable model
Users need to add credential configuration for each model.
- `fetch-from-remote` Fetch from remote
Consistent with the `predefined-model` configuration method, only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
### ModelFeature
- `agent-thought` Agent reasoning, generally over 70B with thought chain capability.
- `vision` Vision, i.e., image understanding.
### FetchFrom
- `predefined-model` Predefined model
- `fetch-from-remote` Remote model
### LLMMode
- `complete` Text completion
- `chat` Dialogue
### ParameterRule
- `name` (string) Actual model invocation parameter name
- `use_template` (string) [optional] Using template
By default, 5 variable content configuration templates are preset:
- `temperature`
- `top_p`
- `frequency_penalty`
- `presence_penalty`
- `max_tokens`
In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration.
Refer to `openai/llm/gpt-3.5-turbo.yaml`.
- `label` (object) [optional] Label, i18n
- `zh_Hans`(string) [optional] Chinese label name
- `en_US` (string) English label name
- `type`(string) [optional] Parameter type
- `int` Integer
- `float` Float
- `string` String
- `boolean` Boolean
- `help` (string) [optional] Help information
- `zh_Hans` (string) [optional] Chinese help information
- `en_US` (string) English help information
- `required` (bool) Required, default False.
- `default`(int/float/string/bool) [optional] Default value
- `min`(int/float) [optional] Minimum value, applicable only to numeric types
- `max`(int/float) [optional] Maximum value, applicable only to numeric types
- `precision`(int) [optional] Precision, number of decimal places to keep, applicable only to numeric types
- `options` (array[string]) [optional] Dropdown option values, applicable only when `type` is `string`, if not set or null, option values are not restricted
### PriceConfig
- `input` (float) Input price, i.e., Prompt price
- `output` (float) Output price, i.e., returned content price
- `unit` (float) Pricing unit, e.g., per 100K price is `0.000001`
- `currency` (string) Currency unit
### ProviderCredentialSchema
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
### ModelCredentialSchema
- `model` (object) Model identifier, variable name defaults to `model`
- `label` (object) Model form item display name
- `en_US` (string) English
- `zh_Hans`(string) [optional] Chinese
- `placeholder` (object) Model prompt content
- `en_US`(string) English
- `zh_Hans`(string) [optional] Chinese
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
### CredentialFormSchema
- `variable` (string) Form item variable name
- `label` (object) Form item label name
- `en_US`(string) English
- `zh_Hans` (string) [optional] Chinese
- `type` ([FormType](#FormType)) Form item type
- `required` (bool) Whether required
- `default`(string) Default value
- `options` (array[[FormOption](#FormOption)]) Specific property of form items of type `select` or `radio`, defining dropdown content
- `placeholder`(object) Specific property of form items of type `text-input`, placeholder content
- `en_US`(string) English
- `zh_Hans` (string) [optional] Chinese
- `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit.
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
### FormType
- `text-input` Text input component
- `secret-input` Password input component
- `select` Single-choice dropdown
- `radio` Radio component
- `switch` Switch component, only supports `true` and `false` values
### FormOption
- `label` (object) Label
- `en_US`(string) English
- `zh_Hans`(string) [optional] Chinese
- `value` (string) Dropdown option value
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
### FormShowOnObject
- `variable` (string) Variable name of other form items
- `value` (string) Variable value of other form items

View File

@ -0,0 +1,297 @@
## 自定义预定义模型接入
### 介绍
供应商集成完成后,接下来为供应商下模型的接入,为了帮助理解整个接入过程,我们以`Xinference`为例,逐步完成一个完整的供应商接入。
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
而不同于预定义模型自定义供应商接入时永远会拥有如下两个参数不需要在供应商yaml中定义。
![Alt text](images/index/image-3.png)
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`Runtime会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
### 编写供应商yaml
我们首先要确定,接入的这个供应商支持哪些类型的模型。
当前支持模型类型如下:
- `llm` 文本生成模型
- `text_embedding` 文本 Embedding 模型
- `rerank` Rerank 模型
- `speech2text` 语音转文字
- `tts` 文字转语音
- `moderation` 审查
`Xinference`支持`LLM``Text Embedding`和Rerank那么我们开始编写`xinference.yaml`
```yaml
provider: xinference #确定供应商标识
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言zh_Hans 不设置将默认使用 en_US。
en_US: Xorbits Inference
icon_small: # 小图标,可以参考其他供应商的图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
en_US: icon_s_en.svg
icon_large: # 大图标
en_US: icon_l_en.svg
help: # 帮助
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # 支持的模型类型Xinference同时支持LLM/Text Embedding/Rerank
- llm
- text-embedding
- rerank
configurate_methods: # 因为Xinference为本地部署的供应商并且没有预定义模型需要用什么模型需要根据Xinference的文档自己部署所以这里只支持自定义模型
- customizable-model
provider_credential_schema:
credential_form_schemas:
```
随后我们需要思考在Xinference中定义一个模型需要哪些凭据
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
```yaml
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: 模型类型
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 语言模型
- value: embeddings
label:
en_US: Text Embedding
- value: reranking
label:
en_US: Rerank
```
- 每一个模型都有自己的名称`model_name`,因此需要在这里定义
```yaml
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: 模型名称
required: true
placeholder:
zh_Hans: 填写模型名称
en_US: Input model name
```
- 填写Xinference本地部署的地址
```yaml
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- 每个模型都有唯一的model_uid因此需要在这里定义
```yaml
- variable: model_uid
label:
zh_Hans: 模型UID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
```
现在,我们就完成了供应商的基础定义。
### 编写模型代码
然后我们以`llm`类型为例,编写`xinference.llm.llm.py`
`llm.py` 中创建一个 Xinference LLM 类,我们取名为 `XinferenceAILargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
- LLM 调用
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- 预计算输入 tokens
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
有时候也许你不需要直接返回0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens这个方法位于`AIModel`基类中它会使用GPT2的Tokenizer进行计算但是只能作为替代方法并不完全准确。
- 模型凭据校验
与供应商凭据校验类似,这里针对单个模型进行校验。
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
- 模型参数Schema
与自定义类型不同由于没有在yaml文件中定义一个模型支持哪些参数因此我们需要动态时间模型参数的Schema。
如Xinference支持`max_tokens` `temperature` `top_p` 这三个模型参数。
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`我们这里举例A模型支持`top_k`B模型不支持`top_k`那么我们需要在这里动态生成模型参数的Schema如下所示
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
```
- 调用异常错误映射表
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
Runtime Errors:
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 205 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 385 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 541 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 262 KiB

View File

@ -0,0 +1,746 @@
# 接口方法
这里介绍供应商和各模型类型需要实现的接口方法和参数说明。
## 供应商
继承 `__base.model_provider.ModelProvider` 基类,实现以下接口:
```python
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any validate_credentials method of model type or implement validate method by yourself,
such as: get model list api
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
```
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 定义,传入如:`api_key` 等。
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
**注:预定义模型需完整实现该接口,自定义模型供应商只需要如下简单实现即可**
```python
class XinferenceProvider(Provider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass
```
## 模型
模型分为 5 种不同的模型类型,不同模型类型继承的基类不同,需要实现的方法也不同。
### 通用接口
所有模型均需要统一实现下面 2 个方法:
- 模型凭据校验
与供应商凭据校验类似,这里针对单个模型进行校验。
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
- 调用异常错误映射表
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
Runtime Errors:
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
也可以直接抛出对应Erros并做如下定义这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError
],
}
```
可参考 OpenAI `_invoke_error_mapping`
### LLM
继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下接口:
- LLM 调用
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) Prompt 列表
若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可;
若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表
- `model_parameters` (object) 模型参数
模型参数由模型 YAML 配置的 `parameter_rules` 定义。
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] 工具列表,等同于 `function calling` 中的 `function`
即传入 tool calling 的工具列表。
- `stop` (array[string]) [optional] 停止序列
模型返回将在停止序列定义的字符串之前停止输出。
- `stream` (bool) 是否流式输出,默认 True
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
- 预计算输入 tokens
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
参数说明见上述 `LLM 调用`
该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
- 获取自定义模型规则 [可选]
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
```
​当供应商支持增加自定义 LLM 时,可实现此方法让自定义模型可获取模型规则,默认返回 None。
对于`OpenAI`供应商下的大部分微调模型,可以通过其微调模型名称获取到其基类模型,如`gpt-3.5-turbo-1106`,然后返回基类模型的预定义参数规则,参考[openai](https://github.com/langgenius/dify/blob/feat/model-runtime/api/core/model_runtime/model_providers/openai/llm/llm.py#L801)
的具体实现
### TextEmbedding
继承 `__base.text_embedding_model.TextEmbeddingModel` 基类,实现以下接口:
- Embedding 调用
```python
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `texts` (array[string]) 文本列表,可批量处理
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回:
[TextEmbeddingResult](#TextEmbeddingResult) 实体。
- 预计算 tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
```
参数说明见上述 `Embedding 调用`
同上述`LargeLanguageModel`,该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
### Rerank
继承 `__base.rerank_model.RerankModel` 基类,实现以下接口:
- rerank 调用
```python
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `query` (string) 查询请求内容
- `docs` (array[string]) 需要重排的分段列表
- `score_threshold` (float) [optional] Score 阈值
- `top_n` (int) [optional] 取前 n 个分段
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回:
[RerankResult](#RerankResult) 实体。
### Speech2text
继承 `__base.speech2text_model.Speech2TextModel` 基类,实现以下接口:
- Invoke 调用
```python
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `file` (File) 文件流
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回:
语音转换后的字符串。
### Text2speech
继承 `__base.text2speech_model.Text2SpeechModel` 基类,实现以下接口:
- Invoke 调用
```python
def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `content_text` (string) 需要转换的文本内容
- `streaming` (bool) 是否进行流式输出
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回:
文本转换后的语音流。
### Moderation
继承 `__base.moderation_model.ModerationModel` 基类,实现以下接口:
- Invoke 调用
```python
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
```
- 参数:
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema``model_credential_schema` 定义,传入如:`api_key` 等。
- `text` (string) 文本内容
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回:
False 代表传入的文本安全True 则反之。
## 实体
### PromptMessageRole
消息角色
```python
class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
```
### PromptMessageContentType
消息内容类型,分为纯文本和图片。
```python
class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
```
### PromptMessageContent
消息内容基类,仅作为参数声明用,不可初始化。
```python
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str # 内容数据
```
当前支持文本和图片两种类型,可支持同时传入文本和多图。
需要分别初始化 `TextPromptMessageContent``ImagePromptMessageContent` 传入。
### TextPromptMessageContent
```python
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
```
若传入图文,其中文字需要构造此实体作为 `content` 列表中的一部分。
### ImagePromptMessageContent
```python
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW # 分辨率
```
若传入图文,其中图片需要构造此实体作为 `content` 列表中的一部分
`data` 可以为 `url` 或者图片 `base64` 加密后的字符串。
### PromptMessage
所有 Role 消息体的基类,仅作为参数声明用,不可初始化。
```python
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole # 消息角色
content: Optional[str | list[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
name: Optional[str] = None # 名称,可选。
```
### UserPromptMessage
UserMessage 消息体,代表用户消息。
```python
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
```
### AssistantPromptMessage
代表模型返回消息,通常用于 `few-shots` 或聊天历史传入。
```python
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str # 工具名称
arguments: str # 工具参数
id: str # 工具 ID仅在 OpenAI tool call 生效,为工具调用的唯一 ID同一个工具可以调用多次
type: str # 默认 function
function: ToolCallFunction # 工具调用信息
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools并且模型认为需要调用工具时返回
```
其中 `tool_calls` 为调用模型传入 `tools` 后,由模型返回的 `tool call` 列表。
### SystemPromptMessage
代表系统消息,通常用于设定给模型的系统指令。
```python
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
```
### ToolPromptMessage
代表工具消息,用于工具执行后将结果交给模型进行下一步计划。
```python
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str # 工具调用 ID若不支持 OpenAI tool call也可传入工具名称
```
基类的 `content` 传入工具执行结果。
### PromptMessageTool
```python
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str # 工具名称
description: str # 工具描述
parameters: dict # 工具参数 dict
```
---
### LLMResult
```python
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str # 实际使用模型
prompt_messages: list[PromptMessage] # prompt 消息列表
message: AssistantPromptMessage # 回复消息
usage: LLMUsage # 使用的 tokens 及费用信息
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
```
### LLMResultChunkDelta
流式返回中每个迭代内部 `delta` 实体
```python
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int # 序号
message: AssistantPromptMessage # 回复消息
usage: Optional[LLMUsage] = None # 使用的 tokens 及费用信息,仅最后一条返回
finish_reason: Optional[str] = None # 结束原因,仅最后一条返回
```
### LLMResultChunk
流式返回中每个迭代实体
```python
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str # 实际使用模型
prompt_messages: list[PromptMessage] # prompt 消息列表
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
delta: LLMResultChunkDelta # 每个迭代存在变化的内容
```
### LLMUsage
```python
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int # prompt 使用 tokens
prompt_unit_price: Decimal # prompt 单价
prompt_price_unit: Decimal # prompt 价格单位,即单价基于多少 tokens
prompt_price: Decimal # prompt 费用
completion_tokens: int # 回复使用 tokens
completion_unit_price: Decimal # 回复单价
completion_price_unit: Decimal # 回复价格单位,即单价基于多少 tokens
completion_price: Decimal # 回复费用
total_tokens: int # 总使用 token 数
total_price: Decimal # 总费用
currency: str # 货币单位
latency: float # 请求耗时(s)
```
---
### TextEmbeddingResult
```python
class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str # 实际使用模型
embeddings: list[list[float]] # embedding 向量列表,对应传入的 texts 列表
usage: EmbeddingUsage # 使用信息
```
### EmbeddingUsage
```python
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int # 使用 token 数
total_tokens: int # 总使用 token 数
unit_price: Decimal # 单价
price_unit: Decimal # 价格单位,即单价基于多少 tokens
total_price: Decimal # 总费用
currency: str # 货币单位
latency: float # 请求耗时(s)
```
---
### RerankResult
```python
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str # 实际使用模型
docs: list[RerankDocument] # 重排后的分段列表
```
### RerankDocument
```python
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int # 原序号
text: str # 分段文本内容
score: float # 分数
```

View File

@ -0,0 +1,172 @@
## 预定义模型接入
供应商集成完成后,接下来为供应商下模型的接入。
我们首先需要确定接入模型的类型,并在对应供应商的目录下创建对应模型类型的 `module`
当前支持模型类型如下:
- `llm` 文本生成模型
- `text_embedding` 文本 Embedding 模型
- `rerank` Rerank 模型
- `speech2text` 语音转文字
- `tts` 文字转语音
- `moderation` 审查
依旧以 `Anthropic` 为例,`Anthropic` 仅支持 LLM因此在 `model_providers.anthropic` 创建一个 `llm` 为名称的 `module`
对于预定义的模型,我们首先需要在 `llm` `module` 下创建以模型名为文件名称的 YAML 文件,如:`claude-2.1.yaml`
### 准备模型 YAML
```yaml
model: claude-2.1 # 模型标识
# 模型展示名称,可设置 en_US 英文、zh_Hans 中文两种语言zh_Hans 不设置将默认使用 en_US。
# 也可不设置 label则使用 model 标识内容。
label:
en_US: claude-2.1
model_type: llm # 模型类型claude-2.1 为 LLM
features: # 支持功能agent-thought 为支持 Agent 推理vision 为支持图片理解
- agent-thought
model_properties: # 模型属性
mode: chat # LLM 模式complete 文本补全模型chat 对话模型
context_size: 200000 # 支持最大上下文大小
parameter_rules: # 模型调用参数规则,仅 LLM 需要提供
- name: temperature # 调用参数变量名
# 默认预置了 5 种变量内容配置模板temperature/top_p/max_tokens/presence_penalty/frequency_penalty
# 可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
# 若设置了额外的配置参数,将覆盖默认配置
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label: # 调用参数展示名称
zh_Hans: 取样数量
en_US: Top k
type: int # 参数类型,支持 float/int/string/boolean
help: # 帮助信息,描述参数作用
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false # 是否必填,可不设置
- name: max_tokens_to_sample
use_template: max_tokens
default: 4096 # 参数默认值
min: 1 # 参数最小值,仅 float/int 可用
max: 4096 # 参数最大值,仅 float/int 可用
pricing: # 价格信息
input: '8.00' # 输入单价,即 Prompt 单价
output: '24.00' # 输出单价,即返回内容单价
unit: '0.000001' # 价格单位,即上述价格为每 100K 的单价
currency: USD # 价格货币
```
建议将所有模型配置都准备完毕后再开始模型代码的实现。
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#AIModel)。
### 实现模型调用代码
接下来需要在 `llm` `module` 下创建一个同名的 python 文件 `llm.py` 来编写代码实现。
`llm.py` 中创建一个 Anthropic LLM 类,我们取名为 `AnthropicLargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
- LLM 调用
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- 预计算输入 tokens
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
- 模型凭据校验
与供应商凭据校验类似,这里针对单个模型进行校验。
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
- 调用异常错误映射表
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
Runtime Errors:
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。

View File

@ -0,0 +1,188 @@
## 增加新供应商
供应商支持三种模型配置方式:
- `predefined-model ` 预定义模型
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
- `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置如Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
- `fetch-from-remote` 从远程获取
`predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
如OpenAI我们可以基于gpt-turbo-3.5来Fine Tune多个模型而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让DifyRuntime获取到开发者所有的微调模型并接入Dify。
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model``predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
## 开始
### 介绍
#### 名词解释
- `module`: 一个`module`即为一个Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
#### 步骤
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
- 创建供应商yaml文件根据[ProviderSchema](./schema.md#provider)编写
- 创建供应商代码,实现一个`class`
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm``text_embedding`
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`
- 如果有预定义模型根据模型名称创建同名的yaml文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
- 编写测试代码,确保功能可用。
### 开始吧
增加一个新的供应商需要先确定供应商的英文标识,如 `anthropic`,使用该标识在 `model_providers` 创建以此为名称的 `module`
在此 `module` 下,我们需要先准备供应商的 YAML 配置。
#### 准备供应商 YAML
此处以 `Anthropic` 为例,预设了供应商基础信息、支持的模型类型、配置方式、凭据规则。
```YAML
provider: anthropic # 供应商标识
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言zh_Hans 不设置将默认使用 en_US。
en_US: Anthropic
icon_small: # 供应商小图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
en_US: icon_s_en.png
icon_large: # 供应商大图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
en_US: icon_l_en.png
supported_model_types: # 支持的模型类型Anthropic 仅支持 LLM
- llm
configurate_methods: # 支持的配置方式Anthropic 仅支持预定义模型
- predefined-model
provider_credential_schema: # 供应商凭据规则,由于 Anthropic 仅支持预定义模型,则需要定义统一供应商凭据规则
credential_form_schemas: # 凭据表单项列表
- variable: anthropic_api_key # 凭据参数变量名
label: # 展示名称
en_US: API Key
type: secret-input # 表单类型,此处 secret-input 代表加密信息输入框,编辑时只展示屏蔽后的信息。
required: true # 是否必填
placeholder: # PlaceHolder 信息
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: anthropic_api_url
label:
en_US: API URL
type: text-input # 表单类型,此处 text-input 代表文本输入框
required: false
placeholder:
zh_Hans: 在此输入您的 API URL
en_US: Enter your API URL
```
如果接入的供应商提供自定义模型,比如`OpenAI`提供微调模型,那么我们就需要添加[`model_credential_schema`](./schema.md#modelcredentialschema),以`OpenAI`为例:
```yaml
model_credential_schema:
model: # 微调模型名称
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: openai_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: openai_organization
label:
zh_Hans: 组织 ID
en_US: Organization
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的组织 ID
en_US: Enter your Organization ID
- variable: openai_api_base
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base
```
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
#### 实现供应商代码
我们需要在`model_providers`下创建一个同名的python文件`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`
##### 自定义模型供应商
当供应商为Xinference等自定义模型供应商时可跳过该步骤仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
```python
class XinferenceProvider(Provider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass
```
##### 预定义模型供应商
供应商需要继承 `__base.model_provider.ModelProvider` 基类,实现 `validate_provider_credentials` 供应商统一凭据校验方法即可,可参考 [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py)。
```python
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any validate_credentials method of model type or implement validate method by yourself,
such as: get model list api
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
```
当然也可以先预留 `validate_provider_credentials` 实现,在模型凭据校验方法实现后直接复用。
#### 增加模型
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
对于预定义模型我们可以通过简单定义一个yaml并通过实现调用代码来接入。
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
---
### 测试
为了保证接入供应商/模型的可用性,编写后的每个方法均需要在 `tests` 目录中编写对应的集成测试代码。
依旧以 `Anthropic` 为例。
在编写测试代码前,需要先在 `.env.example` 新增测试供应商所需要的凭据环境变量,如:`ANTHROPIC_API_KEY`
在执行前需要将 `.env.example` 复制为 `.env` 再执行。
#### 编写测试代码
`tests` 目录下创建供应商同名的 `module`: `anthropic`,继续在此模块中创建 `test_provider.py` 以及对应模型类型的 test py 文件,如下所示:
```shell
.
├── __init__.py
├── anthropic
│   ├── __init__.py
│   ├── test_llm.py # LLM 测试
│   └── test_provider.py # 供应商测试
```
针对上面实现的代码的各种情况进行测试代码编写,并测试通过后提交代码。

View File

@ -0,0 +1,205 @@
# 配置规则
- 供应商规则基于 [Provider](#Provider) 实体。
- 模型规则基于 [AIModelEntity](#AIModelEntity) 实体。
> 以下所有实体均基于 `Pydantic BaseModel`,可在 `entities` 模块中找到对应实体。
### Provider
- `provider` (string) 供应商标识,如:`openai`
- `label` (object) 供应商展示名称i18n可设置 `en_US` 英文、`zh_Hans` 中文两种语言
- `zh_Hans ` (string) [optional] 中文标签名,`zh_Hans` 不设置将默认使用 `en_US`
- `en_US` (string) 英文标签名
- `description` (object) [optional] 供应商描述i18n
- `zh_Hans` (string) [optional] 中文描述
- `en_US` (string) 英文描述
- `icon_small` (string) [optional] 供应商小 ICON存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label`
- `zh_Hans` (string) [optional] 中文 ICON
- `en_US` (string) 英文 ICON
- `icon_large` (string) [optional] 供应商大 ICON存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
- `zh_Hans `(string) [optional] 中文 ICON
- `en_US` (string) 英文 ICON
- `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。
- `help` (object) [optional] 帮助信息
- `title` (object) 帮助标题i18n
- `zh_Hans` (string) [optional] 中文标题
- `en_US` (string) 英文标题
- `url` (object) 帮助链接i18n
- `zh_Hans` (string) [optional] 中文链接
- `en_US` (string) 英文链接
- `supported_model_types` (array[[ModelType](#ModelType)]) 支持的模型类型
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) 配置方式
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格
### AIModelEntity
- `model` (string) 模型标识,如:`gpt-3.5-turbo`
- `label` (object) [optional] 模型展示名称i18n可设置 `en_US` 英文、`zh_Hans` 中文两种语言
- `zh_Hans `(string) [optional] 中文标签名
- `en_US` (string) 英文标签名
- `model_type` ([ModelType](#ModelType)) 模型类型
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] 支持功能列表
- `model_properties` (object) 模型属性
- `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用)
- `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用)
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
- `file_upload_limit` (int) 文件最大上传限制单位MB。模型类型 `speech2text` 可用)
- `supported_file_extensions` (string) 支持文件扩展格式mp3,mp4模型类型 `speech2text` 可用)
- `default_voice` (string) 缺省音色必选alloy,echo,fable,onyx,nova,shimmer模型类型 `tts` 可用)
- `voices` (list) 可选音色列表。
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
- `lanuage` (string) 音色模型支持语言。(模型类型 `tts` 可用)
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
- `audio_type` (string) 支持音频文件扩展格式mp3,wav模型类型 `tts` 可用)
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
- `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。
### ModelType
- `llm` 文本生成模型
- `text-embedding` 文本 Embedding 模型
- `rerank` Rerank 模型
- `speech2text` 语音转文字
- `tts` 文字转语音
- `moderation` 审查
### ConfigurateMethod
- `predefined-model ` 预定义模型
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
- `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置。
- `fetch-from-remote` 从远程获取
`predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
### ModelFeature
- `agent-thought` Agent 推理,一般超过 70B 有思维链能力。
- `vision` 视觉,即:图像理解。
### FetchFrom
- `predefined-model` 预定义模型
- `fetch-from-remote` 远程模型
### LLMMode
- `completion` 文本补全
- `chat` 对话
### ParameterRule
- `name` (string) 调用模型实际参数名
- `use_template` (string) [optional] 使用模板
默认预置了 5 种变量内容配置模板:
- `temperature`
- `top_p`
- `frequency_penalty`
- `presence_penalty`
- `max_tokens`
可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
不用设置除 `name``use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。
可参考 `openai/llm/gpt-3.5-turbo.yaml`
- `label` (object) [optional] 标签i18n
- `zh_Hans`(string) [optional] 中文标签名
- `en_US` (string) 英文标签名
- `type`(string) [optional] 参数类型
- `int` 整数
- `float` 浮点数
- `string` 字符串
- `boolean` 布尔型
- `help` (string) [optional] 帮助信息
- `zh_Hans` (string) [optional] 中文帮助信息
- `en_US` (string) 英文帮助信息
- `required` (bool) 是否必填,默认 False。
- `default`(int/float/string/bool) [optional] 默认值
- `min`(int/float) [optional] 最小值,仅数字类型适用
- `max`(int/float) [optional] 最大值,仅数字类型适用
- `precision`(int) [optional] 精度,保留小数位数,仅数字类型适用
- `options` (array[string]) [optional] 下拉选项值,仅当 `type``string` 时适用,若不设置或为 null 则不限制选项值
### PriceConfig
- `input` (float) 输入单价,即 Prompt 单价
- `output` (float) 输出单价,即返回内容单价
- `unit` (float) 价格单位,如:每 100K 的单价为 `0.000001`
- `currency` (string) 货币单位
### ProviderCredentialSchema
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
### ModelCredentialSchema
- `model` (object) 模型标识,变量名默认 `model`
- `label` (object) 模型表单项展示名称
- `en_US` (string) 英文
- `zh_Hans`(string) [optional] 中文
- `placeholder` (object) 模型提示内容
- `en_US`(string) 英文
- `zh_Hans`(string) [optional] 中文
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
### CredentialFormSchema
- `variable` (string) 表单项变量名
- `label` (object) 表单项标签名
- `en_US`(string) 英文
- `zh_Hans` (string) [optional] 中文
- `type` ([FormType](#FormType)) 表单项类型
- `required` (bool) 是否必填
- `default`(string) 默认值
- `options` (array[[FormOption](#FormOption)]) 表单项为 `select``radio` 专有属性,定义下拉内容
- `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder
- `en_US`(string) 英文
- `zh_Hans` (string) [optional] 中文
- `max_length` (int) 表单项为`text-input`专有属性定义输入最大长度0 为不限制。
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
### FormType
- `text-input` 文本输入组件
- `secret-input` 密码输入组件
- `select` 单选下拉
- `radio` Radio 组件
- `switch` 开关组件,仅支持 `true``false`
### FormOption
- `label` (object) 标签
- `en_US`(string) 英文
- `zh_Hans`(string) [optional] 中文
- `value` (string) 下拉选项值
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
### FormShowOnObject
- `variable` (string) 其他表单项变量名
- `value` (string) 其他表单项变量值

View File

@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
if not self.zh_Hans:
self.zh_Hans = self.en_US

View File

@ -0,0 +1,98 @@
from model_providers.core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
'label': {
'en_US': 'Temperature',
'zh_Hans': '温度',
},
'type': 'float',
'help': {
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
},
DefaultParameterName.TOP_P: {
'label': {
'en_US': 'Top P',
'zh_Hans': 'Top P',
},
'type': 'float',
'help': {
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
'zh_Hans': '通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。',
},
'required': False,
'default': 1.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
},
DefaultParameterName.PRESENCE_PENALTY: {
'label': {
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
'label': {
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
},
DefaultParameterName.MAX_TOKENS: {
'label': {
'en_US': 'Max Tokens',
'zh_Hans': '最大标记',
},
'type': 'int',
'help': {
'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记这些标记在提示和完成之间共享。',
},
'required': False,
'default': 64,
'min': 1,
'max': 2048,
'precision': 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
},
'type': 'string',
'help': {
'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
'zh_Hans': '设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等',
},
'required': False,
'options': ['JSON', 'XML'],
}
}

View File

@ -0,0 +1,102 @@
from decimal import Decimal
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
class LLMMode(Enum):
"""
Enum class for large language model mode.
"""
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> 'LLMMode':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal
prompt_price: Decimal
completion_tokens: int
completion_unit_price: Decimal
completion_price_unit: Decimal
completion_price: Decimal
total_tokens: int
total_price: Decimal
currency: str
latency: float
@classmethod
def empty_usage(cls):
return cls(
prompt_tokens=0,
prompt_unit_price=Decimal('0.0'),
prompt_price_unit=Decimal('0.0'),
prompt_price=Decimal('0.0'),
completion_tokens=0,
completion_unit_price=Decimal('0.0'),
completion_price_unit=Decimal('0.0'),
completion_price=Decimal('0.0'),
total_tokens=0,
total_price=Decimal('0.0'),
currency='USD',
latency=0.0
)
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str
prompt_messages: list[PromptMessage]
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: Optional[str] = None
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage
usage: Optional[LLMUsage] = None
finish_reason: Optional[str] = None
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str
prompt_messages: list[PromptMessage]
system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta
class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
tokens: int

View File

@ -0,0 +1,134 @@
from abc import ABC
from enum import Enum
from typing import Optional
from pydantic import BaseModel
class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@classmethod
def value_of(cls, value: str) -> 'PromptMessageRole':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt message type value {value}')
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
class PromptMessageFunction(BaseModel):
"""
Model class for prompt message function.
"""
type: str = 'function'
function: PromptMessageTool
class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
name: Optional[str] = None
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str
arguments: str
id: str
type: str
function: ToolCallFunction
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = []
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str

View File

@ -0,0 +1,210 @@
from decimal import Decimal
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum):
"""
Enum class for model type.
"""
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
TEXT2IMG = "text2img"
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
"""
Get model type from origin model type.
:return: model type
"""
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
return cls.LLM
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
return cls.RERANK
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
raise ValueError(f'invalid origin model type {origin_model_type}')
def to_origin_model_type(self) -> str:
"""
Get origin model type from model type.
:return: origin model type
"""
if self == self.LLM:
return 'text-generation'
elif self == self.TEXT_EMBEDDING:
return 'embeddings'
elif self == self.RERANK:
return 'reranking'
elif self == self.SPEECH2TEXT:
return 'speech2text'
elif self == self.TTS:
return 'tts'
elif self == self.MODERATION:
return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else:
raise ValueError(f'invalid model type {self}')
class FetchFrom(Enum):
"""
Enum class for fetch from.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(Enum):
"""
Enum class for llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
"""
Get parameter name from value.
:param value: parameter value
:return: parameter name
"""
for name in cls:
if name.value == value:
return name
raise ValueError(f'invalid parameter name {value}')
class ParameterType(Enum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
class ModelPropertyKey(Enum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
class ProviderModel(BaseModel):
"""
Model class for provider model.
"""
model: str
label: I18nObject
model_type: ModelType
features: Optional[list[ModelFeature]] = None
fetch_from: FetchFrom
model_properties: dict[ModelPropertyKey, Any]
deprecated: bool = False
class Config:
protected_namespaces = ()
class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""
name: str
use_template: Optional[str] = None
label: I18nObject
type: ParameterType
help: Optional[I18nObject] = None
required: bool = False
default: Optional[Any] = None
min: Optional[float] = None
max: Optional[float] = None
precision: Optional[int] = None
options: list[str] = []
class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""
input: Decimal
output: Optional[Decimal] = None
unit: Decimal
currency: str
class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None
class ModelUsage(BaseModel):
pass
class PriceType(Enum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
class PriceInfo(BaseModel):
"""
Model class for price info.
"""
unit_price: Decimal
unit: Decimal
total_amount: Decimal
currency: str

View File

@ -0,0 +1,149 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
class ConfigurateMethod(Enum):
"""
Enum class for configurate method of provider model.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(Enum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
RADIO = "radio"
SWITCH = "switch"
class FormShowOnObject(BaseModel):
"""
Model class for form show on.
"""
variable: str
value: str
class FormOption(BaseModel):
"""
Model class for form option.
"""
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
if not self.label:
self.label = I18nObject(
en_US=self.value
)
class CredentialFormSchema(BaseModel):
"""
Model class for credential form schema.
"""
variable: str
label: I18nObject
type: FormType
required: bool = True
default: Optional[str] = None
options: Optional[list[FormOption]] = None
placeholder: Optional[I18nObject] = None
max_length: int = 0
show_on: list[FormShowOnObject] = []
class ProviderCredentialSchema(BaseModel):
"""
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
class FieldModelSchema(BaseModel):
label: I18nObject
placeholder: Optional[I18nObject] = None
class ModelCredentialSchema(BaseModel):
"""
Model class for model credential schema.
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
models: list[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
"""
Model class for provider help.
"""
title: I18nObject
url: I18nObject
class ProviderEntity(BaseModel):
"""
Model class for provider.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
class Config:
protected_namespaces = ()
def to_simple_provider(self) -> SimpleProviderEntity:
"""
Convert to simple provider.
:return: simple provider
"""
return SimpleProviderEntity(
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models
)
class ProviderConfig(BaseModel):
"""
Model class for provider config.
"""
provider: str
credentials: dict

View File

@ -0,0 +1,18 @@
from pydantic import BaseModel
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int
text: str
score: float
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str
docs: list[RerankDocument]

View File

@ -0,0 +1,28 @@
from decimal import Decimal
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelUsage
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int
total_tokens: int
unit_price: Decimal
price_unit: Decimal
total_price: Decimal
currency: str
latency: float
class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -0,0 +1,37 @@
from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
description = "Connection Error"
class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
description = "Server Unavailable Error"
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
description = "Incorrect model credentials provided, please check and try again. "
class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
description = "Bad Request Error"

View File

@ -0,0 +1,5 @@
class CredentialsValidateFailedError(Exception):
"""
Credentials validate failed error
"""
pass

View File

@ -0,0 +1,318 @@
import decimal
import os
from abc import ABC, abstractmethod
from typing import Optional
import yaml
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelType,
PriceConfig,
PriceInfo,
PriceType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from model_providers.core.utils.position_helper import get_position_map, sort_by_position_map
class AIModel(ABC):
"""
Base class for all models.
"""
model_type: ModelType
model_schemas: list[AIModelEntity] = None
started_at: float = 0
@abstractmethod
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
raise NotImplementedError
@property
@abstractmethod
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
raise NotImplementedError
def _transform_invoke_error(self, error: Exception) -> InvokeError:
"""
Transform invoke error to unified error
:param error: model invoke error
:return: unified error
"""
provider_name = self.__class__.__module__.split('.')[-3]
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
"""
Get price for given model and tokens
:param model: model name
:param credentials: model credentials
:param price_type: price type
:param tokens: number of tokens
:return: price info
"""
# get model schema
model_schema = self.get_model_schema(model, credentials)
# get price info from predefined model schema
price_config: Optional[PriceConfig] = None
if model_schema:
price_config: PriceConfig = model_schema.pricing
# get unit price
unit_price = None
if price_config:
if price_type == PriceType.INPUT:
unit_price = price_config.input
elif price_type == PriceType.OUTPUT and price_config.output is not None:
unit_price = price_config.output
if unit_price is None:
return PriceInfo(
unit_price=decimal.Decimal('0.0'),
unit=decimal.Decimal('0.0'),
total_amount=decimal.Decimal('0.0'),
currency="USD",
)
# calculate total amount
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
return PriceInfo(
unit_price=unit_price,
unit=price_config.unit,
total_amount=total_amount,
currency=price_config.currency,
)
def predefined_models(self) -> list[AIModelEntity]:
"""
Get all predefined models for given provider.
:return:
"""
if self.model_schemas:
return self.model_schemas
model_schemas = []
# get module name
model_type = self.__class__.__module__.split('.')[-1]
# get provider name
provider_name = self.__class__.__module__.split('.')[-3]
# get the path of current classes
current_path = os.path.abspath(__file__)
# get parent path of the current path
provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
# get all yaml files path under provider_model_type_path that do not start with __
model_schema_yaml_paths = [
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
]
# get _position.yaml file path
position_map = get_position_map(provider_model_type_path)
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
if 'use_template' in parameter_rule:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
copy_default_parameter_rule = default_parameter_rule.copy()
copy_default_parameter_rule.update(parameter_rule)
parameter_rule = copy_default_parameter_rule
except ValueError:
pass
if 'label' not in parameter_rule:
parameter_rule['label'] = {
'zh_Hans': parameter_rule['name'],
'en_US': parameter_rule['name']
}
new_parameter_rules.append(parameter_rule)
yaml_data['parameter_rules'] = new_parameter_rules
if 'label' not in yaml_data:
yaml_data['label'] = {
'zh_Hans': yaml_data['model'],
'en_US': yaml_data['model']
}
yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
try:
# yaml_data to entity
model_schema = AIModelEntity(**yaml_data)
except Exception as e:
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
f' {str(e)}')
# cache model schema
model_schemas.append(model_schema)
# resort model schemas by position
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
# cache model schemas
self.model_schemas = model_schemas
return model_schemas
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
:param model: model name
:param credentials: model credentials
:return: model schema
"""
# get predefined models (predefined_models)
models = self.predefined_models()
model_map = {model.model: model for model in models}
if model in model_map:
return model_map[model]
if credentials:
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
if model_schema:
return model_schema
return None
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
schema = self.get_customizable_model_schema(model, credentials)
if not schema:
return None
# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:
if parameter_rule.use_template:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError:
pass
new_parameter_rules.append(parameter_rule)
schema.parameter_rules = new_parameter_rules
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return None
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
"""
Get default parameter rule for given name
:param name: parameter name
:return: parameter rule
"""
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
raise Exception(f'Invalid model parameter rule name {name}')
return default_parameter_rule
def _get_num_tokens_by_gpt2(self, text: str) -> int:
"""
Get number of tokens for given prompt messages by gpt2
Some provider models do not provide an interface for obtaining the number of tokens.
Here, the gpt2 tokenizer is used to calculate the number of tokens.
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
:param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens
"""
return GPT2Tokenizer.get_num_tokens(text)

View File

@ -0,0 +1,819 @@
import logging
import os
import re
import time
from abc import abstractmethod
from collections.abc import Generator
from typing import Optional, Union
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.callbacks.logging_callback import LoggingCallback
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceType,
)
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LargeLanguageModel(AIModel):
"""
Model class for large language model.
"""
model_type: ModelType = ModelType.LLM
def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
# validate and filter model parameters
if model_parameters is None:
model_parameters = {}
model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
self.started_at = time.perf_counter()
callbacks = callbacks or []
if bool(os.environ.get("DEBUG")):
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
self._trigger_before_invoke_callbacks(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
try:
if "response_format" in model_parameters:
result = self._code_block_mode_wrapper(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
ex=e,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
raise self._transform_invoke_error(e)
if stream and isinstance(result, Generator):
return self._invoke_result_generator(
model=model,
result=result,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
else:
self._trigger_after_invoke_callbacks(
model=model,
result=result,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
return result
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
else:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "normal"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
else:
yield piece
continue
new_piece = ""
for char in piece:
if state == "normal":
if char == "`":
state = "in_backticks"
backtick_count = 1
else:
new_piece += char
elif state == "in_backticks":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_content"
backtick_count = 0
else:
new_piece += "`" * backtick_count + char
state = "normal"
backtick_count = 0
elif state == "skip_content":
if char.isspace():
state = "normal"
if new_piece:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "search_start"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
# Reset content to ensure we're only processing and yielding the relevant parts
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
else:
# Yield pieces without content directly
yield piece
continue
if state == "done":
continue
new_piece = ""
for char in piece:
if state == "search_start":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_language"
backtick_count = 0
else:
backtick_count = 0
elif state == "skip_language":
# Skip everything until the first newline, marking the end of the language identifier
if char == "\n":
state = "in_code_block"
elif state == "in_code_block":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "done"
break
else:
if backtick_count > 0:
# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
new_piece += char
elif state == "done":
break
if new_piece:
# Only yield content collected within the code block
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
prompt_message = AssistantPromptMessage(
content=""
)
usage = None
system_fingerprint = None
real_model = model
try:
for chunk in result:
yield chunk
self._trigger_new_chunk_callbacks(
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
prompt_message.content += chunk.delta.message.content
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint
),
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
raise NotImplementedError
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
"""
Transform llm result to stream
:param result: llm result
:return: stream
"""
index = 0
tool_calls = result.message.tool_calls
for word in result.message.content:
assistant_prompt_message = AssistantPromptMessage(
content=word,
tool_calls=tool_calls if index == (len(result.message.content) - 1) else []
)
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
)
index += 1
time.sleep(0.01)
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
"""
Get parameter rules
:param model: model name
:param credentials: model credentials
:return: parameter rules
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema:
return model_schema.parameter_rules
return []
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
"""
Get model mode
:param model: model name
:param credentials: model credentials
:return: model mode
"""
model_schema = self.get_model_schema(model, credentials)
mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
return mode
def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param prompt_tokens: prompt tokens
:param completion_tokens: completion tokens
:return: usage
"""
# get prompt price info
prompt_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=prompt_tokens,
)
# get completion price info
completion_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.OUTPUT,
tokens=completion_tokens
)
# transform usage
usage = LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=prompt_price_info.unit_price,
prompt_price_unit=prompt_price_info.unit,
prompt_price=prompt_price_info.total_amount,
completion_tokens=completion_tokens,
completion_unit_price=completion_price_info.unit_price,
completion_price_unit=completion_price_info.unit,
completion_price=completion_price_info.total_amount,
total_tokens=prompt_tokens + completion_tokens,
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
currency=prompt_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger before invoke callbacks
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
if callbacks:
for callback in callbacks:
try:
callback.on_before_invoke(
llm_instance=self,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger new chunk callbacks
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
if callbacks:
for callback in callbacks:
try:
callback.on_new_chunk(
llm_instance=self,
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger after invoke callbacks
:param model: model name
:param result: result
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
if callbacks:
for callback in callbacks:
try:
callback.on_after_invoke(
llm_instance=self,
result=result,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger invoke error callbacks
:param model: model name
:param ex: exception
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
if callbacks:
for callback in callbacks:
try:
callback.on_invoke_error(
llm_instance=self,
ex=ex,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
"""
Validate model parameters
:param model: model name
:param model_parameters: model parameters
:param credentials: model credentials
:return:
"""
parameter_rules = self.get_parameter_rules(model, credentials)
# validate model parameters
filtered_model_parameters = {}
for parameter_rule in parameter_rules:
parameter_name = parameter_rule.name
parameter_value = model_parameters.get(parameter_name)
if parameter_value is None:
if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
# if parameter value is None, use template value variable name instead
parameter_value = model_parameters[parameter_rule.use_template]
else:
if parameter_rule.required:
if parameter_rule.default is not None:
filtered_model_parameters[parameter_name] = parameter_rule.default
continue
else:
raise ValueError(f"Model Parameter {parameter_name} is required.")
else:
continue
# validate parameter value type
if parameter_rule.type == ParameterType.INT:
if not isinstance(parameter_value, int):
raise ValueError(f"Model Parameter {parameter_name} should be int.")
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, float | int):
raise ValueError(f"Model Parameter {parameter_name} should be float.")
# validate parameter value precision
if parameter_rule.precision is not None:
if parameter_rule.precision == 0:
if parameter_value != int(parameter_value):
raise ValueError(f"Model Parameter {parameter_name} should be int.")
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
elif parameter_rule.type == ParameterType.BOOLEAN:
if not isinstance(parameter_value, bool):
raise ValueError(f"Model Parameter {parameter_name} should be bool.")
elif parameter_rule.type == ParameterType.STRING:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
else:
raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
filtered_model_parameters[parameter_name] = parameter_value
return filtered_model_parameters

View File

@ -0,0 +1,124 @@
import importlib
import os
from abc import ABC, abstractmethod
import yaml
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class ModelProvider(ABC):
provider_schema: ProviderEntity = None
model_instance_map: dict[str, AIModel] = {}
@abstractmethod
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any validate_credentials method of model type or implement validate method by yourself,
such as: get model list api
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
raise NotImplementedError
def get_provider_schema(self) -> ProviderEntity:
"""
Get provider schema
:return: provider schema
"""
if self.provider_schema:
return self.provider_schema
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
# get the path of the model_provider classes
base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
try:
# yaml_data to entity
provider_schema = ProviderEntity(**yaml_data)
except Exception as e:
raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
# cache schema
self.provider_schema = provider_schema
return provider_schema
def models(self, model_type: ModelType) -> list[AIModelEntity]:
"""
Get all models for given model type
:param model_type: model type defined in `ModelType`
:return: list of models
"""
provider_schema = self.get_provider_schema()
if model_type not in provider_schema.supported_model_types:
return []
# get model instance of the model type
model_instance = self.get_model_instance(model_type)
# get predefined models (predefined_models)
models = model_instance.predefined_models()
# return models
return models
def get_model_instance(self, model_type: ModelType) -> AIModel:
"""
Get model instance
:param model_type: model type defined in `ModelType`
:return:
"""
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
# get the path of the model type classes
base_path = os.path.abspath(__file__)
model_type_name = model_type.value.replace('-', '_')
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
return model_instance_map

View File

@ -0,0 +1,48 @@
import time
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class ModerationModel(AIModel):
"""
Model class for moderation model.
"""
model_type: ModelType = ModelType.MODERATION
def invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
self.started_at = time.perf_counter()
try:
return self._invoke(model, credentials, text, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
raise NotImplementedError

View File

@ -0,0 +1,56 @@
import time
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class RerankModel(AIModel):
"""
Base Model class for rerank model.
"""
model_type: ModelType = ModelType.RERANK
def invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
self.started_at = time.perf_counter()
try:
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
raise NotImplementedError

View File

@ -0,0 +1,57 @@
import os
from abc import abstractmethod
from typing import IO, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class Speech2TextModel(AIModel):
"""
Model class for speech2text model.
"""
model_type: ModelType = ModelType.SPEECH2TEXT
def invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
try:
return self._invoke(model, credentials, file, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
raise NotImplementedError
def _get_demo_file_path(self) -> str:
"""
Get demo file for given model
:return: demo file
"""
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the audio file
return os.path.join(current_dir, 'audio.mp3')

View File

@ -0,0 +1,48 @@
from abc import abstractmethod
from typing import IO, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
raise NotImplementedError

View File

@ -0,0 +1,90 @@
import time
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class TextEmbeddingModel(AIModel):
"""
Model class for text embedding model.
"""
model_type: ModelType = ModelType.TEXT_EMBEDDING
def invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
self.started_at = time.perf_counter()
try:
return self._invoke(model, credentials, texts, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
raise NotImplementedError
def _get_context_size(self, model: str, credentials: dict) -> int:
"""
Get context size for given embedding model
:param model: model name
:param credentials: model credentials
:return: context size
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return 1000
def _get_max_chunks(self, model: str, credentials: dict) -> int:
"""
Get max chunks for given embedding model
:param model: model name
:param credentials: model credentials
:return: max chunks
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return 1

View File

@ -0,0 +1,23 @@
{
"bos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@ -0,0 +1,33 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"bos_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": true,
"eos_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"errors": "replace",
"model_max_length": 1024,
"pad_token": null,
"tokenizer_class": "GPT2Tokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@ -0,0 +1,33 @@
from os.path import abspath, dirname, join
from threading import Lock
from typing import Any
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
_tokenizer = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text, verbose=False)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer

View File

@ -0,0 +1,165 @@
import hashlib
import subprocess
import uuid
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
class TTSModel(AIModel):
"""
Model class for ttstext model.
"""
model_type: ModelType = ModelType.TTS
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
try:
self._is_ffmpeg_installed()
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
content_text=content_text, voice=voice, tenant_id=tenant_id)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
raise NotImplementedError
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
"""
Get voice for given tts model voices
:param language: tts language
:param model: model name
:param credentials: model credentials
:return: voices lists
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')]
else:
return [{'name': d['name'], 'value': d['mode']} for d in voices]
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
"""
Get voice for given tts model
:param model: model name
:param credentials: model credentials
:return: voice
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
"""
Get audio type for given tts model
:param model: model name
:param credentials: model credentials
:return: voice
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
Get audio type for given tts model
:return: audio type
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
Get audio max workers for given tts model
:return: audio type
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
if delimiters is None:
delimiters = set('。!?;\n')
buf = []
word_count = 0
for char in text:
buf.append(char)
if char in delimiters:
if word_count >= limit:
yield ''.join(buf)
buf = []
word_count = 0
else:
word_count += 1
else:
word_count += 1
if buf:
yield ''.join(buf)
@staticmethod
def _is_ffmpeg_installed():
try:
output = subprocess.check_output("ffmpeg -version", shell=True)
if "ffmpeg version" in output.decode("utf-8"):
return True
else:
raise InvokeBadRequestError("ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
except Exception:
raise InvokeBadRequestError("ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
# Todo: To improve the streaming function
@staticmethod
def _get_file_name(file_content: str) -> str:
hash_object = hashlib.sha256(file_content.encode())
hex_digest = hash_object.hexdigest()
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
return str(unique_uuid)

View File

@ -0,0 +1,3 @@
from model_providers.core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
model_provider_factory = ModelProviderFactory()

View File

@ -0,0 +1,25 @@
- openai
- anthropic
- azure_openai
- google
- cohere
- bedrock
- togetherai
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- moonshot
- jina
- chatglm
- xinference
- openllm
- localai
- openai_api_compatible

View File

@ -0,0 +1,78 @@
<svg width="90" height="20" viewBox="0 0 90 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_8587_60274)">
<mask id="mask0_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M89.375 4.99805H0V14.998H89.375V4.99805Z" fill="white"/>
</mask>
<g mask="url(#mask0_8587_60274)">
<mask id="mask1_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99609H89.375V14.9961H0V4.99609Z" fill="white"/>
</mask>
<g mask="url(#mask1_8587_60274)">
<mask id="mask2_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99414H89.375V14.9941H0V4.99414Z" fill="white"/>
</mask>
<g mask="url(#mask2_8587_60274)">
<mask id="mask3_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask3_8587_60274)">
<path d="M18.1273 11.9244L13.7773 5.15625H11.4297V14.825H13.4321V8.05688L17.7821 14.825H20.1297V5.15625H18.1273V11.9244Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask4_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask4_8587_60274)">
<path d="M21.7969 7.02094H25.0423V14.825H27.1139V7.02094H30.3594V5.15625H21.7969V7.02094Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask5_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask5_8587_60274)">
<path d="M38.6442 9.00994H34.0871V5.15625H32.0156V14.825H34.0871V10.8746H38.6442V14.825H40.7156V5.15625H38.6442V9.00994Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask6_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask6_8587_60274)">
<path d="M45.3376 7.02094H47.893C48.9152 7.02094 49.4539 7.39387 49.4539 8.09831C49.4539 8.80275 48.9152 9.17569 47.893 9.17569H45.3376V7.02094ZM51.5259 8.09831C51.5259 6.27506 50.186 5.15625 47.9897 5.15625H43.2656V14.825H45.3376V11.0404H47.6443L49.7164 14.825H52.0094L49.715 10.7521C50.8666 10.3094 51.5259 9.37721 51.5259 8.09831Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask7_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask7_8587_60274)">
<path d="M57.8732 13.0565C56.2438 13.0565 55.2496 11.8963 55.2496 10.004C55.2496 8.08416 56.2438 6.92394 57.8732 6.92394C59.4887 6.92394 60.4691 8.08416 60.4691 10.004C60.4691 11.8963 59.4887 13.0565 57.8732 13.0565ZM57.8732 4.99023C55.0839 4.99023 53.1094 7.06206 53.1094 10.004C53.1094 12.9184 55.0839 14.9902 57.8732 14.9902C60.6486 14.9902 62.6094 12.9184 62.6094 10.004C62.6094 7.06206 60.6486 4.99023 57.8732 4.99023Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask8_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask8_8587_60274)">
<path d="M69.1794 9.45194H66.6233V7.02094H69.1794C70.2019 7.02094 70.7407 7.43532 70.7407 8.23644C70.7407 9.03756 70.2019 9.45194 69.1794 9.45194ZM69.2762 5.15625H64.5508V14.825H66.6233V11.3166H69.2762C71.473 11.3166 72.8133 10.1564 72.8133 8.23644C72.8133 6.3165 71.473 5.15625 69.2762 5.15625Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask9_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask9_8587_60274)">
<path d="M86.8413 11.5786C86.4823 12.5179 85.7642 13.0565 84.7837 13.0565C83.1542 13.0565 82.16 11.8963 82.16 10.004C82.16 8.08416 83.1542 6.92394 84.7837 6.92394C85.7642 6.92394 86.4823 7.46261 86.8413 8.40183H89.0369C88.4984 6.33002 86.8827 4.99023 84.7837 4.99023C81.9942 4.99023 80.0195 7.06206 80.0195 10.004C80.0195 12.9184 81.9942 14.9902 84.7837 14.9902C86.8965 14.9902 88.5122 13.6366 89.0508 11.5786H86.8413Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask10_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask10_8587_60274)">
<path d="M73.6484 5.15625L77.5033 14.825H79.6172L75.7624 5.15625H73.6484Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask11_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask11_8587_60274)">
<path d="M3.64038 10.9989L4.95938 7.60106L6.27838 10.9989H3.64038ZM3.85422 5.15625L0 14.825H2.15505L2.9433 12.7946H6.97558L7.76371 14.825H9.91875L6.06453 5.15625H3.85422Z" fill="black" fill-opacity="0.92"/>
</g>
</g>
</g>
</g>
</g>
<defs>
<clipPath id="clip0_8587_60274">
<rect width="89.375" height="10" fill="white" transform="translate(0 5)"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 5.3 KiB

View File

@ -0,0 +1,4 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="6" fill="#CA9F7B"/>
<path d="M15.3843 6.43481H12.9687L17.3739 17.5652H19.7896L15.3843 6.43481ZM8.40522 6.43481L4 17.5652H6.4633L7.36417 15.2279H11.9729L12.8737 17.5652H15.337L10.9318 6.43481H8.40522ZM8.16104 13.1607L9.66852 9.24907L11.176 13.1607H8.16104Z" fill="#191918"/>
</svg>

After

Width:  |  Height:  |  Size: 410 B

View File

@ -0,0 +1,31 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class AnthropicProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `claude-instant-1` model for validate,
model_instance.validate_credentials(
model='claude-instant-1.2',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,39 @@
provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropics powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#F0F0EB"
help:
title:
en_US: Get your API Key from Anthropic
zh_Hans: 从 Anthropic 获取 API Key
url:
en_US: https://console.anthropic.com/account/keys
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: anthropic_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: anthropic_api_url
label:
en_US: API URL
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API URL
en_US: Enter your API URL

View File

@ -0,0 +1,6 @@
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@ -0,0 +1,36 @@
model: claude-2.1
label:
en_US: claude-2.1
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,37 @@
model: claude-2
label:
en_US: claude-2
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -0,0 +1,37 @@
model: claude-3-haiku-20240307
label:
en_US: claude-3-haiku-20240307
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '0.25'
output: '1.25'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,37 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,37 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,35 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,36 @@
model: claude-instant-1
label:
en_US: claude-instant-1
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -0,0 +1,506 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from httpx import Timeout
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
client = Anthropic(api_key="")
return client.count_tokens(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping"),
],
model_parameters={
"temperature": 0,
"max_tokens": 20,
},
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.content[0].text
)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
index = chunk.index
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk.index,
message=assistant_prompt_message,
)
)
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
:param credentials:
:return:
"""
credentials_kwargs = {
"api_key": credentials['anthropic_api_key'],
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']:
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{human_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{human_prompt} [IMAGE]"
elif isinstance(message, AssistantPromptMessage):
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{ai_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{ai_prompt} [IMAGE]"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
anthropic.APIConnectionError,
anthropic.APITimeoutError
],
InvokeServerUnavailableError: [
anthropic.InternalServerError
],
InvokeRateLimitError: [
anthropic.RateLimitError
],
InvokeAuthorizationError: [
anthropic.AuthenticationError,
anthropic.PermissionDeniedError
],
InvokeBadRequestError: [
anthropic.BadRequestError,
anthropic.NotFoundError,
anthropic.UnprocessableEntityError,
anthropic.APIError
]
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Some files were not shown because too many files have changed in this diff Show More