make format

This commit is contained in:
glide-the 2024-03-31 15:12:20 +08:00
parent b8d748b668
commit a2df71d9ea
11 changed files with 121 additions and 68 deletions

View File

@ -1,9 +1,11 @@
from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb from omegaconf import DictConfig, OmegaConf
from model_providers.bootstrap_web.openai_bootstrap_web import (
RESTFulOpenAIBootstrapBaseWeb,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.model_manager import ModelManager from model_providers.core.model_manager import ModelManager
from omegaconf import OmegaConf, DictConfig
def _to_custom_provide_configuration(cfg: DictConfig): def _to_custom_provide_configuration(cfg: DictConfig):
""" """
@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig):
provider_name_to_provider_model_records_dict = {} provider_name_to_provider_model_records_dict = {}
for key, item in cfg.items(): for key, item in cfg.items():
model_credential = item.get('model_credential') model_credential = item.get("model_credential")
provider_credential = item.get('provider_credential') provider_credential = item.get("provider_credential")
# 转换omegaconf对象为基本属性 # 转换omegaconf对象为基本属性
if model_credential: if model_credential:
model_credential = OmegaConf.to_container(model_credential) model_credential = OmegaConf.to_container(model_credential)
@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig):
provider_credential = OmegaConf.to_container(provider_credential) provider_credential = OmegaConf.to_container(provider_credential)
provider_name_to_provider_records_dict[key] = provider_credential provider_name_to_provider_records_dict[key] = provider_credential
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict return (
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
class BootstrapWebBuilder: class BootstrapWebBuilder:
""" """
创建一个模型实例创建工具 创建一个模型实例创建工具
""" """
_model_providers_cfg_path: str _model_providers_cfg_path: str
_host: str _host: str
_port: int _port: int
@ -68,20 +74,26 @@ class BootstrapWebBuilder:
return self return self
def build(self) -> OpenAIBootstrapBaseWeb: def build(self) -> OpenAIBootstrapBaseWeb:
assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None assert (
self._model_providers_cfg_path is not None
and self._host is not None
and self._port is not None
)
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load(self._model_providers_cfg_path) cfg = OmegaConf.load(self._model_providers_cfg_path)
# 转换配置文件 # 转换配置文件
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( (
cfg) provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器 # 创建模型管理器
provider_manager = ModelManager(provider_name_to_provider_records_dict, provider_manager = ModelManager(
provider_name_to_provider_model_records_dict) provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
# 创建web服务 # 创建web服务
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={ restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
"host": self._host, cfg={"host": self._host, "port": self._port}
"port": self._port )
})
restful.provider_manager = provider_manager restful.provider_manager = provider_manager
return restful return restful

View File

@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import (
EmbeddingsRequest, EmbeddingsRequest,
EmbeddingsResponse, EmbeddingsResponse,
FunctionAvailable, FunctionAvailable,
ModelList, ModelCard, ModelCard,
ModelList,
) )
from model_providers.core.model_manager import ModelManager, ModelInstance from model_providers.core.model_manager import ModelInstance, ModelManager
from model_providers.core.model_runtime.entities.message_entities import ( from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.utils.generic import dictify, jsonify from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
llm_models: list[AIModelEntity] = [] llm_models: list[AIModelEntity] = []
for model_type in ModelType.__members__.values(): for model_type in ModelType.__members__.values():
try: try:
provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle( provider_model_bundle = (
provider=provider, model_type=model_type self._provider_manager.provider_manager.get_provider_model_bundle(
provider=provider, model_type=model_type
)
)
llm_models.extend(
provider_model_bundle.model_type_instance.predefined_models()
) )
llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
except Exception as e: except Exception as e:
logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}") logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
)
logger.error(e) logger.error(e)
# models list[AIModelEntity]转换称List[ModelCard] # models list[AIModelEntity]转换称List[ModelCard]
models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models] models_list = [
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
for model in llm_models
]
return ModelList( return ModelList(data=models_list)
data=models_list
)
async def create_embeddings( async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
): ):
logger.info( logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return EmbeddingsResponse(**dictify(response)) return EmbeddingsResponse(**dictify(response))
async def create_chat_completion( async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest self, provider: str, request: Request, chat_request: ChatCompletionRequest
): ):
logger.info( logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}" f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run( def run(
cfg: Dict, cfg: Dict,
logging_conf: Optional[dict] = None, logging_conf: Optional[dict] = None,
started_event: mp.Event = None, started_event: mp.Event = None,
): ):
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
try: try:

View File

@ -3,7 +3,10 @@ from collections import deque
from fastapi import Request from fastapi import Request
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.model_manager import ModelManager from model_providers.core.model_manager import ModelManager
@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
@abstractmethod @abstractmethod
async def create_embeddings( async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
): ):
pass pass
@abstractmethod @abstractmethod
async def create_chat_completion( async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest self, provider: str, request: Request, chat_request: ChatCompletionRequest
): ):
pass pass

View File

@ -22,7 +22,15 @@ class Finish(str, Enum):
class ModelCard(BaseModel): class ModelCard(BaseModel):
id: str id: str
object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm" object: Literal[
"text-generation",
"embeddings",
"reranking",
"speech2text",
"moderation",
"tts",
"text2img",
] = "llm"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner" owned_by: Literal["owner"] = "owner"

View File

@ -1,6 +1,6 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
from typing import Any, Optional, List from typing import Any, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -74,7 +74,6 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}") raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum): class FetchFrom(Enum):
""" """
Enum class for fetch from. Enum class for fetch from.

View File

@ -253,9 +253,11 @@ class ProviderManager:
custom_model_configurations.append( custom_model_configurations.append(
CustomModelConfiguration( CustomModelConfiguration(
model=provider_model_record.get('model'), model=provider_model_record.get("model"),
model_type=ModelType.value_of(provider_model_record.get('model_type')), model_type=ModelType.value_of(
credentials=provider_model_credentials provider_model_record.get("model_type")
),
credentials=provider_model_credentials,
) )
) )

View File

@ -1,7 +1,6 @@
import logging import logging
import time
import os import os
import time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str):
def get_config_dict( def get_config_dict(
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict: ) -> dict:
# for windows, the path should be a raw string. # for windows, the path should be a raw string.
log_file_path = ( log_file_path = (

View File

@ -1,11 +1,16 @@
"""Configuration for unit tests.""" """Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest import pytest
from pytest import Config, Function, Parser from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:

View File

@ -1,21 +1,26 @@
from model_providers import BootstrapWebBuilder
import logging
import asyncio import asyncio
import logging
import pytest import pytest
from model_providers import BootstrapWebBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.requires("fastapi") @pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None: def test_init_server(logging_conf: dict) -> None:
try: try:
boot = BootstrapWebBuilder() \ boot = (
BootstrapWebBuilder()
.model_providers_cfg_path( .model_providers_cfg_path(
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml") \ "/model_providers.yaml"
.host(host="127.0.0.1") \ )
.port(port=20000) \ .host(host="127.0.0.1")
.port(port=20000)
.build() .build()
)
boot.set_app_event(started_event=None) boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf) boot.serve(logging_conf=logging_conf)

View File

@ -1,11 +1,16 @@
"""Configuration for unit tests.""" """Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest import pytest
from pytest import Config, Function, Parser from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:

View File

@ -1,11 +1,10 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
import logging
import asyncio
import pytest
from model_providers.core.model_manager import ModelManager from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager from model_providers.core.provider_manager import ProviderManager
@ -14,14 +13,17 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict) -> None: def test_provider_manager_models(logging_conf: dict) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" cfg = OmegaConf.load(
"/model_providers.yaml") "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
# 转换配置文件 # 转换配置文件
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( (
cfg) provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器 # 创建模型管理器
provider_manager = ProviderManager( provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None:
provider_model_bundle_emb = provider_manager.get_provider_model_bundle( provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.TEXT_EMBEDDING provider="openai", model_type=ModelType.TEXT_EMBEDDING
) )
predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models() predefined_models = (
provider_model_bundle_emb.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}") logger.info(f"predefined_models: {predefined_models}")