mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
使用yaml加载用户配置适配器
This commit is contained in:
parent
9818bd2a88
commit
451fef8a31
28
model-providers/model_providers.yaml
Normal file
28
model-providers/model_providers.yaml
Normal file
@ -0,0 +1,28 @@
|
||||
openai:
|
||||
model_credential:
|
||||
- model: 'gpt-3.5-turbo'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
openai_api_key: 'sk-'
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
- model: 'gpt-4'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
openai_api_key: 'sk-'
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
|
||||
provider_credential:
|
||||
openai_api_key: 'sk-'
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'gpt-3.5-turbo'
|
||||
model_type: 'llm'
|
||||
credential:
|
||||
openai_api_key: ''
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
@ -1,17 +1,55 @@
|
||||
from chatchat.configs import MODEL_PLATFORMS
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
|
||||
def _to_custom_provide_configuration():
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
|
||||
|
||||
def _to_custom_provide_configuration(cfg: DictConfig):
|
||||
"""
|
||||
```
|
||||
openai:
|
||||
model_credential:
|
||||
- model: 'gpt-3.5-turbo'
|
||||
model_credentials:
|
||||
openai_api_key: ''
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
- model: 'gpt-4'
|
||||
model_credentials:
|
||||
openai_api_key: ''
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
|
||||
provider_credential:
|
||||
openai_api_key: ''
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
|
||||
```
|
||||
:param model_providers_cfg:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_records_dict = {}
|
||||
provider_name_to_provider_model_records_dict = {}
|
||||
|
||||
for key, item in cfg.items():
|
||||
model_credential = item.get('model_credential')
|
||||
provider_credential = item.get('provider_credential')
|
||||
# 转换omegaconf对象为基本属性
|
||||
if model_credential:
|
||||
model_credential = OmegaConf.to_container(model_credential)
|
||||
provider_name_to_provider_model_records_dict[key] = model_credential
|
||||
if provider_credential:
|
||||
provider_credential = OmegaConf.to_container(provider_credential)
|
||||
provider_name_to_provider_records_dict[key] = provider_credential
|
||||
|
||||
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
|
||||
|
||||
|
||||
model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
|
||||
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
||||
cfg=model_providers_cfg)
|
||||
# 基于配置管理器创建的模型实例
|
||||
provider_manager = ModelManager(
|
||||
provider_name_to_provider_records_dict={
|
||||
'openai': {
|
||||
'openai_api_key': "sk-4M9LYF",
|
||||
}
|
||||
},
|
||||
provider_name_to_provider_model_records_dict={}
|
||||
)
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
from typing import cast, Generator
|
||||
|
||||
from model_providers import provider_manager
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
|
||||
@ -8,16 +9,7 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 基于配置管理器创建的模型实例
|
||||
provider_manager = ModelManager(
|
||||
provider_name_to_provider_records_dict={
|
||||
'openai': {
|
||||
'openai_api_key': "sk-4M9LYF",
|
||||
}
|
||||
},
|
||||
provider_name_to_provider_model_records_dict={}
|
||||
)
|
||||
|
||||
#
|
||||
# Invoke model
|
||||
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
|
||||
|
||||
|
||||
@ -230,7 +230,7 @@ class ProviderManager:
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.get('model_name'),
|
||||
model=provider_model_record.get('model'),
|
||||
model_type=ModelType.value_of(provider_model_record.get('model_type')),
|
||||
credentials=provider_model_credentials
|
||||
)
|
||||
|
||||
@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
|
||||
pyyaml = "6.0.1"
|
||||
pydantic = "2.6.4"
|
||||
redis = "4.5.4"
|
||||
# config manage
|
||||
omegaconf = "2.0.6"
|
||||
# modle_runtime
|
||||
openai = "1.13.3"
|
||||
tiktoken = "0.5.2"
|
||||
@ -188,4 +190,4 @@ markers = [
|
||||
"scheduled: mark tests to run in scheduled testing",
|
||||
"compile: mark placeholder test used to compile integration tests without running them"
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_mode = "auto"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user