Model list adaptation (模型列表适配)

glide-the 2024-03-31 15:08:30 +08:00
parent 032dc8f58d
commit f005ea3298
12 changed files with 396 additions and 76 deletions

View File

@@ -57,18 +57,14 @@ optional = true
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
model-providers = { path = "../model-providers", develop = true }

View File

@@ -1,50 +0,0 @@
import os
from typing import Generator, cast

from model_providers import provider_manager
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType

if __name__ == '__main__':
    # A model instance created from the configuration manager
    # Invoke model
    model_instance = provider_manager.get_model_instance(
        provider="openai", model_type=ModelType.LLM, model="gpt-4"
    )
    response = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],  # "What's the weather like in Beijing today?"
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        assert (
            len(chunk.delta.message.content) > 0
            if not chunk.delta.finish_reason
            else True
        )
    print(total_message)
    assert "参考资料" in total_message  # "参考资料" means "references"

View File

@@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import (
    EmbeddingsRequest,
    EmbeddingsResponse,
    FunctionAvailable,
    ModelList,
    ModelList, ModelCard,
)
from model_providers.core.model_manager import ModelManager, ModelInstance
from model_providers.core.model_runtime.entities.message_entities import (
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
from model_providers.core.utils.generic import dictify, jsonify

logger = logging.getLogger(__name__)

@@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        started_event.set()

    async def list_models(self, provider: str, request: Request):
        pass
        logger.info(f"Received list_models request for provider: {provider}")
        # Collect the predefined models for every ModelType enum member
        llm_models: list[AIModelEntity] = []
        for model_type in ModelType.__members__.values():
            try:
                provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
                    provider=provider, model_type=model_type
                )
                llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
            except Exception as e:
                logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
                logger.error(e)
        # Convert the list[AIModelEntity] into a List[ModelCard]
        models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
        return ModelList(
            data=models_list
        )

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest

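For context, a minimal self-contained sketch of the payload shape `list_models` now returns. The `ModelCard`/`ModelList` stand-ins below mirror `openai_protocol`; the `object = "list"` wrapper field on `ModelList` is an assumption, since this diff only shows `data`:

import time
from typing import List, Literal

from pydantic import BaseModel, Field


class ModelCard(BaseModel):
    # Mirrors openai_protocol.ModelCard; `object` carries the origin model type.
    id: str
    object: str = "text-generation"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"


class ModelList(BaseModel):
    object: str = "list"  # assumed wrapper field; not visible in this diff
    data: List[ModelCard] = []


# One card per predefined model, as assembled by list_models:
cards = [
    ModelCard(id="gpt-4"),
    ModelCard(id="text-embedding-ada-002", object="embeddings"),
]
print(ModelList(data=cards).json())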
View File

@@ -22,7 +22,7 @@ class Finish(str, Enum):

class ModelCard(BaseModel):
    id: str
    object: Literal["model"] = "model"
    object: Literal["text-generation", "embeddings", "reranking", "speech2text", "moderation", "tts", "text2img"] = "text-generation"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"

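Since `object` is a `Literal`, pydantic rejects any value outside the enumerated origin types, and the default itself must also be a member of the set (hence "text-generation" above). A quick standalone check:

from typing import Literal

from pydantic import BaseModel, ValidationError


class Card(BaseModel):
    # Same constraint as the updated ModelCard.object field.
    object: Literal["text-generation", "embeddings", "reranking",
                    "speech2text", "moderation", "tts", "text2img"] = "text-generation"


Card(object="embeddings")  # accepted
try:
    Card(object="model")   # the old default is no longer a legal value
except ValidationError as e:
    print(e)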
View File

@@ -245,6 +245,10 @@ class ModelManager:
            provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
        )

    @property
    def provider_manager(self) -> ProviderManager:
        return self._provider_manager

    def get_model_instance(
        self, provider: str, model_type: ModelType, model: str
    ) -> ModelInstance:

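This read-only property is what the `list_models` handler above relies on to reach `get_provider_model_bundle` without touching the private attribute. A hypothetical call site (assumes `model_manager` is an already-configured `ModelManager`; construction is not shown here):

from model_providers.core.model_runtime.entities.model_entities import ModelType

# model_manager: a configured ModelManager instance
bundle = model_manager.provider_manager.get_provider_model_bundle(
    provider="openai", model_type=ModelType.LLM
)
print(bundle.model_type_instance.predefined_models())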
View File

@@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, List

from pydantic import BaseModel

@@ -74,6 +74,7 @@ class ModelType(Enum):
        raise ValueError(f"invalid model type {self}")

class FetchFrom(Enum):
    """
    Enum class for fetch from.

View File

@@ -0,0 +1,89 @@
import logging
import time
import os

logger = logging.getLogger(__name__)


class LoggerNameFilter(logging.Filter):
    def filter(self, record):
        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
        #     record.name.startswith("uvicorn.error")
        #     and record.getMessage().startswith("Uvicorn running on")
        # )
        return True


def get_log_file(log_path: str, sub_dir: str):
    """
    sub_dir should contain a timestamp.
    """
    log_dir = os.path.join(log_path, sub_dir)
    # A new directory should be created on each run, so `exist_ok=False`
    os.makedirs(log_dir, exist_ok=False)
    return os.path.join(log_dir, "loom_core.log")


def get_config_dict(
    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
    # for windows, the path should be a raw string.
    log_file_path = (
        log_file_path.encode("unicode-escape").decode()
        if os.name == "nt"
        else log_file_path
    )
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {
                "format": (
                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
                )
            },
        },
        "filters": {
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict


def get_timestamp_ms():
    t = time.time()
    return int(round(t * 1000))

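A short sketch of how these helpers compose, mirroring the `logging_conf` fixture in the test configuration below (the rotation sizes here are arbitrary):

import logging
import logging.config

from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms

# get_log_file creates a fresh timestamped directory, so each run gets its own file.
conf = get_config_dict(
    log_level="DEBUG",
    log_file_path=get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
    log_backup_count=3,
    log_max_bytes=1024 * 1024,
)
logging.config.dictConfig(conf)
logging.getLogger("loom_core").info("logging configured")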
View File

@@ -26,18 +26,14 @@ boto3 = "1.28.17"
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
@@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

View File

@@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging

import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )
                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )

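As the docstring above describes, individual tests opt in via the marker (it must also be registered under `markers` in pyproject.toml, since the suite runs with `--strict-markers`). A minimal example with an illustrative package name:

import pytest


@pytest.mark.requires("fastapi")
def test_needs_fastapi() -> None:
    # Skipped automatically when fastapi is not importable;
    # under --only-extended a missing package fails the run instead.
    import fastapi  # noqa: F401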
View File

@@ -0,0 +1,28 @@
from model_providers import BootstrapWebBuilder
import logging
import asyncio

import pytest

logger = logging.getLogger(__name__)


@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
    try:
        boot = BootstrapWebBuilder() \
            .model_providers_cfg_path(
                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                                         "/model_providers.yaml") \
            .host(host="127.0.0.1") \
            .port(port=20000) \
            .build()
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await boot.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise

View File

@@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging

import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )
                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )

View File

@@ -0,0 +1,39 @@
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
import logging
import logging.config  # needed for logging.config.dictConfig below
import asyncio

import pytest

from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_provider_manager_models(logging_conf: dict) -> None:
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the configuration file
    cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                         "/model_providers.yaml")
    # Convert the configuration into provider records
    provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
        cfg)
    # Create the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
    logger.info(f"predefined_models: {predefined_models}")