diff --git a/model-providers/model_providers.yaml b/model-providers/model_providers.yaml new file mode 100644 index 00000000..eb96fba7 --- /dev/null +++ b/model-providers/model_providers.yaml @@ -0,0 +1,28 @@ +openai: + model_credential: + - model: 'gpt-3.5-turbo' + model_type: 'llm' + model_credentials: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + - model: 'gpt-4' + model_type: 'llm' + model_credentials: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + + provider_credential: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + +xinference: + model_credential: + - model: 'gpt-3.5-turbo' + model_type: 'llm' + credential: + openai_api_key: '' + openai_organization: '' + openai_api_base: '' diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py index 486e453c..04e8fe2b 100644 --- a/model-providers/model_providers/__init__.py +++ b/model-providers/model_providers/__init__.py @@ -1,17 +1,55 @@ -from chatchat.configs import MODEL_PLATFORMS from model_providers.core.model_manager import ModelManager -def _to_custom_provide_configuration(): +from omegaconf import OmegaConf, DictConfig + + +def _to_custom_provide_configuration(cfg: DictConfig): + """ + ``` + openai: + model_credential: + - model: 'gpt-3.5-turbo' + model_credentials: + openai_api_key: '' + openai_organization: '' + openai_api_base: '' + - model: 'gpt-4' + model_credentials: + openai_api_key: '' + openai_organization: '' + openai_api_base: '' + + provider_credential: + openai_api_key: '' + openai_organization: '' + openai_api_base: '' + + ``` + :param model_providers_cfg: + :return: + """ provider_name_to_provider_records_dict = {} provider_name_to_provider_model_records_dict = {} + + for key, item in cfg.items(): + model_credential = item.get('model_credential') + provider_credential = item.get('provider_credential') + # 转换omegaconf对象为基本属性 + if model_credential: + model_credential = OmegaConf.to_container(model_credential) + provider_name_to_provider_model_records_dict[key] = model_credential + if provider_credential: + provider_credential = OmegaConf.to_container(provider_credential) + provider_name_to_provider_records_dict[key] = provider_credential + return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict + +model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml") +provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( + cfg=model_providers_cfg) # 基于配置管理器创建的模型实例 provider_manager = ModelManager( - provider_name_to_provider_records_dict={ - 'openai': { - 'openai_api_key': "sk-4M9LYF", - } - }, - provider_name_to_provider_model_records_dict={} -) \ No newline at end of file + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict +) diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py index ef5fab6e..b58dc035 100644 --- a/model-providers/model_providers/__main__.py +++ b/model-providers/model_providers/__main__.py @@ -1,6 +1,7 @@ import os from typing import cast, Generator +from model_providers import provider_manager from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage @@ -8,16 +9,7 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType if __name__ == '__main__': # 基于配置管理器创建的模型实例 - provider_manager = ModelManager( - provider_name_to_provider_records_dict={ - 'openai': { - 'openai_api_key': "sk-4M9LYF", - } - }, - provider_name_to_provider_model_records_dict={} - ) - # # Invoke model model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4') diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index a2016963..f38d976d 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -230,7 +230,7 @@ class ProviderManager: custom_model_configurations.append( CustomModelConfiguration( - model=provider_model_record.get('model_name'), + model=provider_model_record.get('model'), model_type=ModelType.value_of(provider_model_record.get('model_type')), credentials=provider_model_credentials ) diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml index 2ab88d8e..87359676 100644 --- a/model-providers/pyproject.toml +++ b/model-providers/pyproject.toml @@ -14,6 +14,8 @@ sse-starlette = "^1.8.2" pyyaml = "6.0.1" pydantic = "2.6.4" redis = "4.5.4" +# config manage +omegaconf = "2.0.6" # modle_runtime openai = "1.13.3" tiktoken = "0.5.2" @@ -188,4 +190,4 @@ markers = [ "scheduled: mark tests to run in scheduled testing", "compile: mark placeholder test used to compile integration tests without running them" ] -asyncio_mode = "auto" \ No newline at end of file +asyncio_mode = "auto"