Model list adaptation

glide-the 2024-03-31 15:08:30 +08:00
parent 032dc8f58d
commit f005ea3298
12 changed files with 396 additions and 76 deletions

View File

@@ -57,18 +57,14 @@ optional = true
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
-pytest-dotenv = "^0.5.2"
-duckdb-engine = "^0.7.0"
-pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"
-responses = "^0.22.0"
-pytest-asyncio = "^0.20.3"
-lark = "^1.1.5"
-pandas = "^2.0.0"
-pytest-mock = "^3.10.0"
-pytest-socket = "^0.6.0"
+pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+grandalf = "^0.8"
+pytest-profiling = "^1.7.0"
+responses = "^0.25.0"
 model-providers = { path = "../model-providers", develop = true }

View File

@@ -1,50 +0,0 @@
import os
from typing import Generator, cast

from model_providers import provider_manager
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType

if __name__ == "__main__":
    # Model instance created from the configuration manager
    # Invoke model
    model_instance = provider_manager.get_model_instance(
        provider="openai", model_type=ModelType.LLM, model="gpt-4"
    )
    response = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)

    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        assert (
            len(chunk.delta.message.content) > 0
            if not chunk.delta.finish_reason
            else True
        )
    print(total_message)
    assert "参考资料" in total_message

View File

@@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     FunctionAvailable,
-    ModelList,
+    ModelList, ModelCard,
 )
 from model_providers.core.model_manager import ModelManager, ModelInstance
 from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
-from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
 from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)

@@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         started_event.set()

     async def list_models(self, provider: str, request: Request):
-        pass
+        logger.info(f"Received list_models request for provider: {provider}")
+        # Iterate over all ModelType enum members
+        llm_models: list[AIModelEntity] = []
+        for model_type in ModelType.__members__.values():
+            try:
+                provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
+                    provider=provider, model_type=model_type
+                )
+                llm_models.extend(
+                    provider_model_bundle.model_type_instance.predefined_models()
+                )
+            except Exception as e:
+                logger.error(
+                    f"Error while fetching models for provider: {provider}, "
+                    f"model_type: {model_type}"
+                )
+                logger.error(e)
+        # Convert the list[AIModelEntity] into a List[ModelCard]
+        models_list = [
+            ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
+            for model in llm_models
+        ]
+        return ModelList(data=models_list)

     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
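
For context, a minimal client-side sketch of the new endpoint. The route prefix (`/{provider}/v1/models`) and the serialized shape of `ModelList` are assumptions here, since neither the router registration nor the response serialization appears in this hunk:

# Hypothetical client call; the URL layout is an assumption, not shown in this diff.
import httpx

resp = httpx.get("http://127.0.0.1:20000/openai/v1/models")
resp.raise_for_status()
for card in resp.json()["data"]:
    # Each entry is a ModelCard: `id` is the model name, `object` is the
    # origin model type string (e.g. "text-generation", "embeddings").
    print(card["id"], card["object"])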

View File

@@ -22,7 +22,7 @@ class Finish(str, Enum):
 class ModelCard(BaseModel):
     id: str
-    object: Literal["model"] = "model"
+    object: Literal["text-generation", "embeddings", "reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: Literal["owner"] = "owner"
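
A small sketch of the widened schema in use. The values are illustrative, and note that the new default `"llm"` is not itself one of the `Literal` choices, so explicit values must come from the list above:

# Illustrative only: build a ModelList from ModelCards with the new object values.
from model_providers.core.bootstrap.openai_protocol import ModelCard, ModelList

cards = [
    ModelCard(id="gpt-4", object="text-generation"),
    ModelCard(id="text-embedding-ada-002", object="embeddings"),
]
# .json() assumes the pydantic v1 API used elsewhere in this codebase.
print(ModelList(data=cards).json())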

View File

@@ -245,6 +245,10 @@ class ModelManager:
             provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
         )

+    @property
+    def provider_manager(self) -> ProviderManager:
+        return self._provider_manager
+
     def get_model_instance(
         self, provider: str, model_type: ModelType, model: str
     ) -> ModelInstance:
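
This property is what the `list_models` handler above reaches through (`self._provider_manager.provider_manager`). A minimal sketch, assuming a `ModelManager` has already been constructed as in the test at the end of this commit:

from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType


def predefined_llms(model_manager: ModelManager) -> list:
    # Reach the underlying ProviderManager through the new read-only property.
    bundle = model_manager.provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.LLM
    )
    return bundle.model_type_instance.predefined_models()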

View File

@@ -1,6 +1,6 @@
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Optional, List

 from pydantic import BaseModel

@@ -74,6 +74,7 @@
         raise ValueError(f"invalid model type {self}")

 class FetchFrom(Enum):
     """
     Enum class for fetch from.

View File

@@ -0,0 +1,89 @@
import logging
import time
import os

logger = logging.getLogger(__name__)


class LoggerNameFilter(logging.Filter):
    def filter(self, record):
        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
        #     record.name.startswith("uvicorn.error")
        #     and record.getMessage().startswith("Uvicorn running on")
        # )
        return True


def get_log_file(log_path: str, sub_dir: str):
    """
    sub_dir should contain a timestamp.
    """
    log_dir = os.path.join(log_path, sub_dir)
    # A new directory should be created on each call, hence `exist_ok=False`
    os.makedirs(log_dir, exist_ok=False)
    return os.path.join(log_dir, "loom_core.log")


def get_config_dict(
    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
    # On Windows, the path should be a raw string.
    log_file_path = (
        log_file_path.encode("unicode-escape").decode()
        if os.name == "nt"
        else log_file_path
    )
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {
                "format": (
                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
                )
            },
        },
        "filters": {
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict


def get_timestamp_ms():
    t = time.time()
    return int(round(t * 1000))
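
A minimal sketch of how these helpers compose, mirroring the `logging_conf` fixture added in the test packages below; the backup count and byte limit here are illustrative values:

import logging
import logging.config

from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)

# Build a dictConfig and install it; get_log_file creates a fresh
# timestamped directory, so each run logs to its own file.
logging_conf = get_config_dict(
    log_level="DEBUG",
    log_file_path=get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
    log_backup_count=3,
    log_max_bytes=10 * 1024 * 1024,
)
logging.config.dictConfig(logging_conf)
logging.getLogger("loom_core").info("logging configured")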

View File

@@ -26,18 +26,14 @@ boto3 = "1.28.17"
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
-pytest-dotenv = "^0.5.2"
-duckdb-engine = "^0.7.0"
-pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"
-responses = "^0.22.0"
-pytest-asyncio = "^0.20.3"
-lark = "^1.1.5"
-pandas = "^2.0.0"
-pytest-mock = "^3.10.0"
-pytest-socket = "^0.6.0"
+pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+grandalf = "^0.8"
+pytest-profiling = "^1.7.0"
+responses = "^0.25.0"

@@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 # --snapshot-warn-unused  Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
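
The `addopts` change adds `-s` (shorthand for `--capture=no`), which disables pytest's output capture so output from the new logging configuration streams to the console during test runs.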

View File

@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@@ -0,0 +1,28 @@
import asyncio
import logging

import pytest

from model_providers import BootstrapWebBuilder

logger = logging.getLogger(__name__)


@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
    try:
        boot = (
            BootstrapWebBuilder()
            .model_providers_cfg_path(
                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                "/model_providers.yaml"
            )
            .host(host="127.0.0.1")
            .port(port=20000)
            .build()
        )
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await boot.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise

View File

@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@@ -0,0 +1,39 @@
import asyncio
import logging
import logging.config

import pytest
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_provider_manager_models(logging_conf: dict) -> None:
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the configuration file
    cfg = OmegaConf.load(
        "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
        "/model_providers.yaml"
    )
    # Convert the configuration into provider records
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Create the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
    logger.info(f"predefined_models: {predefined_models}")