使用yaml加载用户配置适配器

This commit is contained in:
glide-the 2024-03-28 20:45:42 +08:00
parent 9818bd2a88
commit 451fef8a31
5 changed files with 80 additions and 20 deletions

View File

@ -0,0 +1,28 @@
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
xinference:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
credential:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''

View File

@ -1,17 +1,55 @@
from chatchat.configs import MODEL_PLATFORMS
from model_providers.core.model_manager import ModelManager
def _to_custom_provide_configuration():
from omegaconf import OmegaConf, DictConfig
def _to_custom_provide_configuration(cfg: DictConfig):
"""
```
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_credentials:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_credentials:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
```
:param model_providers_cfg:
:return:
"""
provider_name_to_provider_records_dict = {}
provider_name_to_provider_model_records_dict = {}
for key, item in cfg.items():
model_credential = item.get('model_credential')
provider_credential = item.get('provider_credential')
# 转换omegaconf对象为基本属性
if model_credential:
model_credential = OmegaConf.to_container(model_credential)
provider_name_to_provider_model_records_dict[key] = model_credential
if provider_credential:
provider_credential = OmegaConf.to_container(provider_credential)
provider_name_to_provider_records_dict[key] = provider_credential
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
cfg=model_providers_cfg)
# 基于配置管理器创建的模型实例
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
'openai': {
'openai_api_key': "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={}
)
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
)

View File

@ -1,6 +1,7 @@
import os
from typing import cast, Generator
from model_providers import provider_manager
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
@ -8,16 +9,7 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
'openai': {
'openai_api_key': "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={}
)
#
# Invoke model
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')

View File

@ -230,7 +230,7 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.get('model_name'),
model=provider_model_record.get('model'),
model_type=ModelType.value_of(provider_model_record.get('model_type')),
credentials=provider_model_credentials
)

View File

@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
pyyaml = "6.0.1"
pydantic = "2.6.4"
redis = "4.5.4"
# config manage
omegaconf = "2.0.6"
# modle_runtime
openai = "1.13.3"
tiktoken = "0.5.2"
@ -188,4 +190,4 @@ markers = [
"scheduled: mark tests to run in scheduled testing",
"compile: mark placeholder test used to compile integration tests without running them"
]
asyncio_mode = "auto"
asyncio_mode = "auto"