Load user configuration adapters from YAML

glide-the 2024-03-28 20:45:42 +08:00
parent 9818bd2a88
commit 451fef8a31
5 changed files with 80 additions and 20 deletions

@@ -0,0 +1,28 @@
openai:
  model_credential:
    - model: 'gpt-3.5-turbo'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''
    - model: 'gpt-4'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''
  provider_credential:
    openai_api_key: 'sk-'
    openai_organization: ''
    openai_api_base: ''
xinference:
  model_credential:
    - model: 'gpt-3.5-turbo'
      model_type: 'llm'
      credential:
        openai_api_key: ''
        openai_organization: ''
        openai_api_base: ''
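
For orientation, here is a minimal sketch (not part of this commit) of how a config like the one above can be consumed with OmegaConf; the relative file name model_providers.yaml is assumed:

```python
# Sketch only: load the provider YAML above and inspect it with OmegaConf.
# The relative path 'model_providers.yaml' is assumed for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.load('model_providers.yaml')

print(cfg.openai.provider_credential.openai_api_key)   # 'sk-'
print([m.model for m in cfg.openai.model_credential])  # ['gpt-3.5-turbo', 'gpt-4']

# Convert the DictConfig into plain dicts/lists before passing it to code
# that expects builtin containers.
plain = OmegaConf.to_container(cfg, resolve=True)
```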

@@ -1,17 +1,55 @@
-from chatchat.configs import MODEL_PLATFORMS
 from model_providers.core.model_manager import ModelManager
+from omegaconf import OmegaConf, DictConfig

-def _to_custom_provide_configuration():
+def _to_custom_provide_configuration(cfg: DictConfig):
+    """
+    ```
+    openai:
+      model_credential:
+        - model: 'gpt-3.5-turbo'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+        - model: 'gpt-4'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+      provider_credential:
+        openai_api_key: ''
+        openai_organization: ''
+        openai_api_base: ''
+    ```
+    :param cfg:
+    :return:
+    """
     provider_name_to_provider_records_dict = {}
     provider_name_to_provider_model_records_dict = {}
+    for key, item in cfg.items():
+        model_credential = item.get('model_credential')
+        provider_credential = item.get('provider_credential')
+        # Convert the OmegaConf nodes into plain Python containers
+        if model_credential:
+            model_credential = OmegaConf.to_container(model_credential)
+            provider_name_to_provider_model_records_dict[key] = model_credential
+        if provider_credential:
+            provider_credential = OmegaConf.to_container(provider_credential)
+            provider_name_to_provider_records_dict[key] = provider_credential
     return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
+
+
+model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
+provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
+    cfg=model_providers_cfg)

 # Model instance created from the configuration manager
 provider_manager = ModelManager(
-    provider_name_to_provider_records_dict={
-        'openai': {
-            'openai_api_key': "sk-4M9LYF",
-        }
-    },
-    provider_name_to_provider_model_records_dict={}
+    provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
+    provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
 )
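
Given the sample YAML, the two dicts returned by _to_custom_provide_configuration would look roughly as follows (an illustrative sketch of the expected shape, not part of the diff):

```python
# Expected shape of the converted configuration (illustration only).
provider_name_to_provider_records_dict = {
    'openai': {'openai_api_key': 'sk-', 'openai_organization': '', 'openai_api_base': ''},
}
provider_name_to_provider_model_records_dict = {
    'openai': [
        {'model': 'gpt-3.5-turbo', 'model_type': 'llm',
         'model_credentials': {'openai_api_key': 'sk-', 'openai_organization': '', 'openai_api_base': ''}},
        {'model': 'gpt-4', 'model_type': 'llm',
         'model_credentials': {'openai_api_key': 'sk-', 'openai_organization': '', 'openai_api_base': ''}},
    ],
    'xinference': [
        {'model': 'gpt-3.5-turbo', 'model_type': 'llm',
         'credential': {'openai_api_key': '', 'openai_organization': '', 'openai_api_base': ''}},
    ],
}
```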

@@ -1,6 +1,7 @@
 import os
 from typing import cast, Generator
+from model_providers import provider_manager
 from model_providers.core.model_manager import ModelManager
 from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
 from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
@@ -8,16 +9,7 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType

 if __name__ == '__main__':
     # Model instance created from the configuration manager
-    provider_manager = ModelManager(
-        provider_name_to_provider_records_dict={
-            'openai': {
-                'openai_api_key': "sk-4M9LYF",
-            }
-        },
-        provider_name_to_provider_model_records_dict={}
-    )
-    #
     # Invoke model
     model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
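
Since the manager is now built from the YAML at import time, downstream code only needs the shared provider_manager. A minimal usage sketch, assuming the providers and models above are declared in model_providers.yaml:

```python
# Sketch: resolve model instances through the shared, YAML-backed manager.
from model_providers import provider_manager
from model_providers.core.model_runtime.entities.model_entities import ModelType

gpt35 = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-3.5-turbo')
gpt4 = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
```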

@@ -230,7 +230,7 @@ class ProviderManager:
             custom_model_configurations.append(
                 CustomModelConfiguration(
-                    model=provider_model_record.get('model_name'),
+                    model=provider_model_record.get('model'),
                     model_type=ModelType.value_of(provider_model_record.get('model_type')),
                     credentials=provider_model_credentials
                 )
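
The key rename matters because the per-model records built from the YAML carry a model field rather than model_name, so the old lookup would return None. A small illustrative check, with the record shape assumed from the config above:

```python
# Illustration: record shape assumed from model_providers.yaml.
provider_model_record = {'model': 'gpt-4', 'model_type': 'llm',
                         'model_credentials': {'openai_api_key': 'sk-'}}

assert provider_model_record.get('model_name') is None  # old key: not present
assert provider_model_record.get('model') == 'gpt-4'    # new key: matches the YAML
```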

@@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
 pyyaml = "6.0.1"
 pydantic = "2.6.4"
 redis = "4.5.4"
+# config manage
+omegaconf = "2.0.6"
 # model_runtime
 openai = "1.13.3"
 tiktoken = "0.5.2"