make format

This commit is contained in:
glide-the 2024-03-31 15:12:20 +08:00
parent b8d748b668
commit a2df71d9ea
11 changed files with 121 additions and 68 deletions

View File

@ -1,9 +1,11 @@
from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
from omegaconf import DictConfig, OmegaConf
from model_providers.bootstrap_web.openai_bootstrap_web import (
RESTFulOpenAIBootstrapBaseWeb,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.model_manager import ModelManager
from omegaconf import OmegaConf, DictConfig
def _to_custom_provide_configuration(cfg: DictConfig):
"""
@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig):
provider_name_to_provider_model_records_dict = {}
for key, item in cfg.items():
model_credential = item.get('model_credential')
provider_credential = item.get('provider_credential')
model_credential = item.get("model_credential")
provider_credential = item.get("provider_credential")
# 转换omegaconf对象为基本属性
if model_credential:
model_credential = OmegaConf.to_container(model_credential)
@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig):
provider_credential = OmegaConf.to_container(provider_credential)
provider_name_to_provider_records_dict[key] = provider_credential
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
return (
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
class BootstrapWebBuilder:
"""
创建一个模型实例创建工具
"""
_model_providers_cfg_path: str
_host: str
_port: int
@ -68,20 +74,26 @@ class BootstrapWebBuilder:
return self
def build(self) -> OpenAIBootstrapBaseWeb:
assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
assert (
self._model_providers_cfg_path is not None
and self._host is not None
and self._port is not None
)
# 读取配置文件
cfg = OmegaConf.load(self._model_providers_cfg_path)
# 转换配置文件
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
cfg)
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ModelManager(provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict)
provider_manager = ModelManager(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
# 创建web服务
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
"host": self._host,
"port": self._port
})
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
cfg={"host": self._host, "port": self._port}
)
restful.provider_manager = provider_manager
return restful

View File

@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import (
EmbeddingsRequest,
EmbeddingsResponse,
FunctionAvailable,
ModelList, ModelCard,
ModelCard,
ModelList,
)
from model_providers.core.model_manager import ModelManager, ModelInstance
from model_providers.core.model_manager import ModelInstance, ModelManager
from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__)
@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
llm_models: list[AIModelEntity] = []
for model_type in ModelType.__members__.values():
try:
provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
provider=provider, model_type=model_type
provider_model_bundle = (
self._provider_manager.provider_manager.get_provider_model_bundle(
provider=provider, model_type=model_type
)
)
llm_models.extend(
provider_model_bundle.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
except Exception as e:
logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
)
logger.error(e)
# models list[AIModelEntity]转换称List[ModelCard]
models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
models_list = [
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
for model in llm_models
]
return ModelList(
data=models_list
)
return ModelList(data=models_list)
async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:

View File

@ -3,7 +3,10 @@ from collections import deque
from fastapi import Request
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.model_manager import ModelManager
@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
@abstractmethod
async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
pass
@abstractmethod
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
pass

View File

@ -22,7 +22,15 @@ class Finish(str, Enum):
class ModelCard(BaseModel):
id: str
object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
object: Literal[
"text-generation",
"embeddings",
"reranking",
"speech2text",
"moderation",
"tts",
"text2img",
] = "llm"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"

View File

@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
from typing import Any, Optional, List
from typing import Any, List, Optional
from pydantic import BaseModel
@ -74,7 +74,6 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
"""
Enum class for fetch from.

View File

@ -253,9 +253,11 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.get('model'),
model_type=ModelType.value_of(provider_model_record.get('model_type')),
credentials=provider_model_credentials
model=provider_model_record.get("model"),
model_type=ModelType.value_of(
provider_model_record.get("model_type")
),
credentials=provider_model_credentials,
)
)

View File

@ -1,7 +1,6 @@
import logging
import time
import os
import time
logger = logging.getLogger(__name__)
@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str):
def get_config_dict(
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
# for windows, the path should be a raw string.
log_file_path = (

View File

@ -1,11 +1,16 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:

View File

@ -1,21 +1,26 @@
from model_providers import BootstrapWebBuilder
import logging
import asyncio
import logging
import pytest
from model_providers import BootstrapWebBuilder
logger = logging.getLogger(__name__)
@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
try:
boot = BootstrapWebBuilder() \
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml") \
.host(host="127.0.0.1") \
.port(port=20000) \
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)

View File

@ -1,11 +1,16 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:

View File

@ -1,11 +1,10 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
import logging
import asyncio
import pytest
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
@ -14,14 +13,17 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml")
cfg = OmegaConf.load(
"/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
# 转换配置文件
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
cfg)
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None:
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.TEXT_EMBEDDING
)
predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
predefined_models = (
provider_model_bundle_emb.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")