mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00

make format

parent b8d748b668
commit a2df71d9ea
@@ -1,9 +1,11 @@
-from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
+from omegaconf import DictConfig, OmegaConf
+
+from model_providers.bootstrap_web.openai_bootstrap_web import (
+    RESTFulOpenAIBootstrapBaseWeb,
+)
 from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.model_manager import ModelManager
-
-from omegaconf import OmegaConf, DictConfig
 
 
 def _to_custom_provide_configuration(cfg: DictConfig):
     """
@@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig):
     provider_name_to_provider_model_records_dict = {}
 
     for key, item in cfg.items():
-        model_credential = item.get('model_credential')
-        provider_credential = item.get('provider_credential')
+        model_credential = item.get("model_credential")
+        provider_credential = item.get("provider_credential")
         # Convert the omegaconf objects to plain Python values
         if model_credential:
             model_credential = OmegaConf.to_container(model_credential)
@@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig):
             provider_credential = OmegaConf.to_container(provider_credential)
             provider_name_to_provider_records_dict[key] = provider_credential
 
-    return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
+    return (
+        provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict,
+    )
 
 
 class BootstrapWebBuilder:
     """
    A builder tool for creating model instances
     """
 
     _model_providers_cfg_path: str
     _host: str
     _port: int
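The exact mapping this helper consumes is not shown in the commit. A minimal sketch of the expected shape, inferred only from the model_credential/provider_credential keys above and the model/model_type fields used in the ProviderManager hunk further down; the provider name and all values are illustrative assumptions:

    from omegaconf import OmegaConf

    # Hypothetical config: one provider with a provider-level credential and one
    # model record. Only the key names mentioned above are grounded in this diff.
    cfg = OmegaConf.create(
        {
            "openai": {
                "provider_credential": {"openai_api_key": "sk-demo"},
                "model_credential": [
                    {"model": "gpt-3.5-turbo", "model_type": "llm"},
                ],
            }
        }
    )
    providers, provider_models = _to_custom_provide_configuration(cfg)
    # providers -> {"openai": {"openai_api_key": "sk-demo"}}; the model-records
    # side depends on code not shown in this hunk.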
@@ -68,20 +74,26 @@ class BootstrapWebBuilder:
         return self
 
     def build(self) -> OpenAIBootstrapBaseWeb:
-        assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
+        assert (
+            self._model_providers_cfg_path is not None
+            and self._host is not None
+            and self._port is not None
+        )
         # Load the configuration file
         cfg = OmegaConf.load(self._model_providers_cfg_path)
         # Convert the configuration
-        provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
-            cfg)
+        (
+            provider_name_to_provider_records_dict,
+            provider_name_to_provider_model_records_dict,
+        ) = _to_custom_provide_configuration(cfg)
         # Create the model manager
-        provider_manager = ModelManager(provider_name_to_provider_records_dict,
-                                        provider_name_to_provider_model_records_dict)
+        provider_manager = ModelManager(
+            provider_name_to_provider_records_dict,
+            provider_name_to_provider_model_records_dict,
+        )
         # Create the web service
-        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
-            "host": self._host,
-            "port": self._port
-
-        })
+        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
+            cfg={"host": self._host, "port": self._port}
+        )
         restful.provider_manager = provider_manager
         return restful
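For orientation, the builder above is exercised end to end by test_init_server further down in this same commit; condensed here, with an illustrative YAML path and logging_conf standing for a logging.config.dictConfig-style dict:

    boot = (
        BootstrapWebBuilder()
        .model_providers_cfg_path(model_providers_cfg_path="model_providers.yaml")
        .host(host="127.0.0.1")
        .port(port=20000)
        .build()
    )
    boot.set_app_event(started_event=None)
    boot.serve(logging_conf=logging_conf)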
@@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     FunctionAvailable,
-    ModelList, ModelCard,
+    ModelCard,
+    ModelList,
 )
-from model_providers.core.model_manager import ModelManager, ModelInstance
+from model_providers.core.model_manager import ModelInstance, ModelManager
 from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
-from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelType,
+)
 from model_providers.core.utils.generic import dictify, jsonify
 
 logger = logging.getLogger(__name__)
@@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         llm_models: list[AIModelEntity] = []
         for model_type in ModelType.__members__.values():
             try:
-                provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
-                    provider=provider, model_type=model_type
+                provider_model_bundle = (
+                    self._provider_manager.provider_manager.get_provider_model_bundle(
+                        provider=provider, model_type=model_type
+                    )
                 )
-                llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
+                llm_models.extend(
+                    provider_model_bundle.model_type_instance.predefined_models()
+                )
             except Exception as e:
-                logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
+                logger.error(
+                    f"Error while fetching models for provider: {provider}, model_type: {model_type}"
+                )
                 logger.error(e)
 
         # Convert list[AIModelEntity] to List[ModelCard]
 
-        models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
+        models_list = [
+            ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
+            for model in llm_models
+        ]
 
-        return ModelList(
-            data=models_list
-        )
+        return ModelList(data=models_list)
 
     async def create_embeddings(
-            self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return EmbeddingsResponse(**dictify(response))
 
     async def create_chat_completion(
-            self, provider: str, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
 
 
 def run(
-        cfg: Dict,
-        logging_conf: Optional[dict] = None,
-        started_event: mp.Event = None,
+    cfg: Dict,
+    logging_conf: Optional[dict] = None,
+    started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
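A hedged sketch of launching run() from a parent process, based only on the signature above: the cfg keys mirror the host/port dict handed to from_config() earlier, get_config_dict() follows the signature shown in the utils hunk below, and all concrete values are illustrative:

    import multiprocessing as mp

    from model_providers.core.utils.utils import get_config_dict

    if __name__ == "__main__":
        started = mp.Event()
        worker = mp.Process(
            target=run,
            kwargs={
                "cfg": {"host": "127.0.0.1", "port": 20000},  # assumed keys
                "logging_conf": get_config_dict("INFO", "logs/app.log", 7, 1024 * 1024),
                "started_event": started,
            },
        )
        worker.start()
        started.wait()  # set_app_event above suggests the server sets this once up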
@@ -3,7 +3,10 @@ from collections import deque
 
 from fastapi import Request
 
-from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
+from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionRequest,
+    EmbeddingsRequest,
+)
 from model_providers.core.model_manager import ModelManager
 
 
@@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
 
     @abstractmethod
     async def create_embeddings(
-            self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         pass
 
     @abstractmethod
     async def create_chat_completion(
-            self, provider: str, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         pass
@@ -22,7 +22,15 @@ class Finish(str, Enum):
 
 class ModelCard(BaseModel):
     id: str
-    object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
+    object: Literal[
+        "text-generation",
+        "embeddings",
+        "reranking",
+        "speech2text",
+        "moderation",
+        "tts",
+        "text2img",
+    ] = "llm"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: Literal["owner"] = "owner"
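A minimal usage sketch, pydantic v1 style to match the .dict() calls elsewhere in this diff; values are illustrative. Note the "llm" default is not itself one of the Literal choices; pydantic v1 does not validate defaults, and list_models above always passes object explicitly, so this only bites if a caller relies on the default:

    card = ModelCard(id="text-embedding-ada-002", object="embeddings")
    card.dict()
    # -> {"id": "text-embedding-ada-002", "object": "embeddings",
    #     "created": 1700000000, "owned_by": "owner"}  # created = epoch seconds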
@@ -1,6 +1,6 @@
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Optional, List
+from typing import Any, List, Optional
 
 from pydantic import BaseModel
@@ -74,7 +74,6 @@ class ModelType(Enum):
         raise ValueError(f"invalid model type {self}")
 
 
-
 class FetchFrom(Enum):
     """
     Enum class for fetch from.
@@ -253,9 +253,11 @@ class ProviderManager:
 
             custom_model_configurations.append(
                 CustomModelConfiguration(
-                    model=provider_model_record.get('model'),
-                    model_type=ModelType.value_of(provider_model_record.get('model_type')),
-                    credentials=provider_model_credentials
+                    model=provider_model_record.get("model"),
+                    model_type=ModelType.value_of(
+                        provider_model_record.get("model_type")
+                    ),
+                    credentials=provider_model_credentials,
                 )
             )
 
@@ -1,7 +1,6 @@
 import logging
-
-import time
 import os
+import time
 
 logger = logging.getLogger(__name__)
 
@@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str):
 
 
 def get_config_dict(
-        log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
+    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
 ) -> dict:
     # for windows, the path should be a raw string.
     log_file_path = (
@@ -1,11 +1,16 @@
 """Configuration for unit tests."""
-from importlib import util
-from typing import Dict, Sequence, List
+import logging
+from importlib import util
+from typing import Dict, List, Sequence
 
 import pytest
 from pytest import Config, Function, Parser
 
-from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
 
 
 def pytest_addoption(parser: Parser) -> None:
@@ -1,21 +1,26 @@
-from model_providers import BootstrapWebBuilder
-import logging
 import asyncio
+import logging
 
 import pytest
 
+from model_providers import BootstrapWebBuilder
+
 logger = logging.getLogger(__name__)
 
 
 @pytest.mark.requires("fastapi")
 def test_init_server(logging_conf: dict) -> None:
     try:
-        boot = BootstrapWebBuilder() \
-            .model_providers_cfg_path(
-                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
-                                         "/model_providers.yaml") \
-            .host(host="127.0.0.1") \
-            .port(port=20000) \
-            .build()
+        boot = (
+            BootstrapWebBuilder()
+            .model_providers_cfg_path(
+                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
+                "/model_providers.yaml"
+            )
+            .host(host="127.0.0.1")
+            .port(port=20000)
+            .build()
+        )
         boot.set_app_event(started_event=None)
         boot.serve(logging_conf=logging_conf)
@@ -1,11 +1,16 @@
 """Configuration for unit tests."""
-from importlib import util
-from typing import Dict, Sequence, List
+import logging
+from importlib import util
+from typing import Dict, List, Sequence
 
 import pytest
 from pytest import Config, Function, Parser
 
-from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
 
 
 def pytest_addoption(parser: Parser) -> None:
@@ -1,11 +1,10 @@
-from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
-import logging
 import asyncio
+import logging
 
 import pytest
-
 from omegaconf import OmegaConf
 
+from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
 from model_providers.core.model_manager import ModelManager
 from model_providers.core.model_runtime.entities.model_entities import ModelType
 from model_providers.core.provider_manager import ProviderManager
@@ -14,14 +13,17 @@ logger = logging.getLogger(__name__)
 
 
 def test_provider_manager_models(logging_conf: dict) -> None:
-
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Load the configuration file
-    cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
-                         "/model_providers.yaml")
+    cfg = OmegaConf.load(
+        "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
+        "/model_providers.yaml"
+    )
     # Convert the configuration
-    provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
-        cfg)
+    (
+        provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict,
+    ) = _to_custom_provide_configuration(cfg)
     # Create the model manager
     provider_manager = ProviderManager(
         provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
@@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None:
     provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
         provider="openai", model_type=ModelType.TEXT_EMBEDDING
     )
-    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
+    predefined_models = (
+        provider_model_bundle_emb.model_type_instance.predefined_models()
+    )
 
     logger.info(f"predefined_models: {predefined_models}")