diff --git a/chatchat-server/pyproject.toml b/chatchat-server/pyproject.toml index 82cd1fa9..fa7cc8fe 100644 --- a/chatchat-server/pyproject.toml +++ b/chatchat-server/pyproject.toml @@ -57,18 +57,14 @@ optional = true # dependencies used for running tests (e.g., pytest, freezegun, response). # Any dependencies that do not meet that criteria will be removed. pytest = "^7.3.0" -pytest-cov = "^4.0.0" -pytest-dotenv = "^0.5.2" -duckdb-engine = "^0.7.0" -pytest-watcher = "^0.2.6" freezegun = "^1.2.2" -responses = "^0.22.0" -pytest-asyncio = "^0.20.3" -lark = "^1.1.5" -pandas = "^2.0.0" -pytest-mock = "^3.10.0" -pytest-socket = "^0.6.0" +pytest-mock = "^3.10.0" syrupy = "^4.0.2" +pytest-watcher = "^0.3.4" +pytest-asyncio = "^0.21.1" +grandalf = "^0.8" +pytest-profiling = "^1.7.0" +responses = "^0.25.0" model-providers = { path = "../model-providers", develop = true } diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py deleted file mode 100644 index decebcc4..00000000 --- a/model-providers/model_providers/__main__.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -from typing import Generator, cast - -from model_providers import provider_manager -from model_providers.core.model_manager import ModelManager -from model_providers.core.model_runtime.entities.llm_entities import ( - LLMResultChunk, - LLMResultChunkDelta, -) -from model_providers.core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - UserPromptMessage, -) -from model_providers.core.model_runtime.entities.model_entities import ModelType - -if __name__ == '__main__': - # 基于配置管理器创建的模型实例 - - # Invoke model - model_instance = provider_manager.get_model_instance( - provider="openai", model_type=ModelType.LLM, model="gpt-4" - ) - - response = model_instance.invoke_llm( - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={ - "temperature": 0.7, - "top_p": 1.0, - "top_k": 1, - "plugin_web_search": True, - }, - stop=["you"], - stream=True, - user="abc-123", - ) - - assert isinstance(response, Generator) - total_message = "" - for chunk in response: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - total_message += chunk.delta.message.content - assert ( - len(chunk.delta.message.content) > 0 - if not chunk.delta.finish_reason - else True - ) - print(total_message) - assert "参考资料" in total_message diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index b1032861..e47ed678 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import ( EmbeddingsRequest, EmbeddingsResponse, FunctionAvailable, - ModelList, + ModelList, ModelCard, ) from model_providers.core.model_manager import ModelManager, ModelInstance from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) @@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): 
        started_event.set()

    async def list_models(self, provider: str, request: Request):
-        pass
+        logger.info(f"Received list_models request for provider: {provider}")
+        # Collect the predefined models of every ModelType for this provider
+        all_models: list[AIModelEntity] = []
+        for model_type in ModelType.__members__.values():
+            try:
+                provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
+                    provider=provider, model_type=model_type
+                )
+                all_models.extend(provider_model_bundle.model_type_instance.predefined_models())
+            except Exception as e:
+                logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
+                logger.error(e)
+
+        # Convert the collected List[AIModelEntity] into List[ModelCard]
+
+        models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in all_models]
+
+        return ModelList(
+            data=models_list
+        )

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py
index 30d57c81..558dc95b 100644
--- a/model-providers/model_providers/core/bootstrap/openai_protocol.py
+++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py
@@ -22,7 +22,7 @@ class Finish(str, Enum):

 class ModelCard(BaseModel):
     id: str
-    object: Literal["model"] = "model"
+    object: Literal["text-generation", "embeddings", "reranking", "speech2text", "moderation", "tts", "text2img"] = "text-generation"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: Literal["owner"] = "owner"

diff --git a/model-providers/model_providers/core/model_manager.py b/model-providers/model_providers/core/model_manager.py
index 5e2c69bd..af896423 100644
--- a/model-providers/model_providers/core/model_manager.py
+++ b/model-providers/model_providers/core/model_manager.py
@@ -245,6 +245,10 @@ class ModelManager:
             provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
         )

+    @property
+    def provider_manager(self) -> ProviderManager:
+        return self._provider_manager
+
     def get_model_instance(
         self, provider: str, model_type: ModelType, model: str
     ) -> ModelInstance:
diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py
index 307d459a..ae5938a4 100644
--- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py
@@ -1,6 +1,6 @@
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Optional, List

 from pydantic import BaseModel

@@ -74,6 +74,7 @@ class ModelType(Enum):

         raise ValueError(f"invalid model type {self}")

+
 class FetchFrom(Enum):
     """
     Enum class for fetch from.
diff --git a/model-providers/model_providers/core/utils/utils.py b/model-providers/model_providers/core/utils/utils.py
new file mode 100644
index 00000000..2bc3fde8
--- /dev/null
+++ b/model-providers/model_providers/core/utils/utils.py
@@ -0,0 +1,89 @@
+import logging
+
+import os
+import time
+
+logger = logging.getLogger(__name__)
+
+
+class LoggerNameFilter(logging.Filter):
+    def filter(self, record):
+        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
+        #     record.name.startswith("uvicorn.error")
+        #     and record.getMessage().startswith("Uvicorn running on")
+        # )
+        return True
+
+
+def get_log_file(log_path: str, sub_dir: str):
+    """
+    sub_dir should contain a timestamp.
+    """
+    log_dir = os.path.join(log_path, sub_dir)
+    # A new directory should be created on each run, hence `exist_ok=False`
+    os.makedirs(log_dir, exist_ok=False)
+    return os.path.join(log_dir, "loom_core.log")
+
+
+def get_config_dict(
+    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
+) -> dict:
+    # On Windows, escape backslashes so the path survives the logging config.
+    log_file_path = (
+        log_file_path.encode("unicode-escape").decode()
+        if os.name == "nt"
+        else log_file_path
+    )
+    log_level = log_level.upper()
+    config_dict = {
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "formatter": {
+                "format": (
+                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
+                )
+            },
+        },
+        "filters": {
+            "logger_name_filter": {
+                "()": __name__ + ".LoggerNameFilter",
+            },
+        },
+        "handlers": {
+            "stream_handler": {
+                "class": "logging.StreamHandler",
+                "formatter": "formatter",
+                "level": log_level,
+                # "stream": "ext://sys.stdout",
+                # "filters": ["logger_name_filter"],
+            },
+            "file_handler": {
+                "class": "logging.handlers.RotatingFileHandler",
+                "formatter": "formatter",
+                "level": log_level,
+                "filename": log_file_path,
+                "mode": "a",
+                "maxBytes": log_max_bytes,
+                "backupCount": log_backup_count,
+                "encoding": "utf8",
+            },
+        },
+        "loggers": {
+            "loom_core": {
+                "handlers": ["stream_handler", "file_handler"],
+                "level": log_level,
+                "propagate": False,
+            }
+        },
+        "root": {
+            "level": log_level,
+            "handlers": ["stream_handler", "file_handler"],
+        },
+    }
+    return config_dict
+
+
+def get_timestamp_ms():
+    t = time.time()
+    return int(round(t * 1000))
diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml
index 44fc1529..859006de 100644
--- a/model-providers/pyproject.toml
+++ b/model-providers/pyproject.toml
@@ -26,18 +26,14 @@ boto3 = "1.28.17"
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
-pytest-dotenv = "^0.5.2"
-duckdb-engine = "^0.7.0"
-pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"
-responses = "^0.22.0"
-pytest-asyncio = "^0.20.3"
-lark = "^1.1.5"
-pandas = "^2.0.0"
-pytest-mock = "^3.10.0"
-pytest-socket = "^0.6.0"
+pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+grandalf = "^0.8"
+pytest-profiling = "^1.7.0"
+responses = "^0.25.0"

@@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 #  --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"

 # Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/server_unit_test/conftest.py new file mode 100644 index 00000000..dc1af083 --- /dev/null +++ b/model-providers/tests/server_unit_test/conftest.py @@ -0,0 +1,99 @@ +"""Configuration for unit tests.""" +from importlib import util +from typing import Dict, Sequence, List +import logging +import pytest +from pytest import Config, Function, Parser + +from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. 
+ item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + +@pytest.fixture +def logging_conf() -> dict: + return get_config_dict( + "DEBUG", + get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), + 122, + 111, + ) diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py new file mode 100644 index 00000000..d4145c77 --- /dev/null +++ b/model-providers/tests/server_unit_test/test_init_server.py @@ -0,0 +1,28 @@ +from model_providers import BootstrapWebBuilder +import logging +import asyncio + +import pytest +logger = logging.getLogger(__name__) + + +@pytest.mark.requires("fastapi") +def test_init_server(logging_conf: dict) -> None: + try: + boot = BootstrapWebBuilder() \ + .model_providers_cfg_path( + model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml") \ + .host(host="127.0.0.1") \ + .port(port=20000) \ + .build() + boot.set_app_event(started_event=None) + boot.serve(logging_conf=logging_conf) + + async def pool_join_thread(): + await boot.join() + + asyncio.run(pool_join_thread()) + except SystemExit: + logger.info("SystemExit raised, exiting") + raise diff --git a/model-providers/tests/unit_test/conftest.py b/model-providers/tests/unit_test/conftest.py new file mode 100644 index 00000000..dc1af083 --- /dev/null +++ b/model-providers/tests/unit_test/conftest.py @@ -0,0 +1,99 @@ +"""Configuration for unit tests.""" +from importlib import util +from typing import Dict, Sequence, List +import logging +import pytest +from pytest import Config, Function, Parser + +from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. 
+    # Used to avoid repeated calls to `util.find_spec`
+    required_pkgs_info: Dict[str, bool] = {}
+
+    only_extended = config.getoption("--only-extended") or False
+    only_core = config.getoption("--only-core") or False
+
+    if only_extended and only_core:
+        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
+
+    for item in items:
+        requires_marker = item.get_closest_marker("requires")
+        if requires_marker is not None:
+            if only_core:
+                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
+                continue
+
+            # Iterate through the list of required packages
+            required_pkgs = requires_marker.args
+            for pkg in required_pkgs:
+                # If we haven't yet checked whether the pkg is installed
+                # let's check it and store the result.
+                if pkg not in required_pkgs_info:
+                    try:
+                        installed = util.find_spec(pkg) is not None
+                    except Exception:
+                        installed = False
+                    required_pkgs_info[pkg] = installed
+
+                if not required_pkgs_info[pkg]:
+                    if only_extended:
+                        pytest.fail(
+                            f"Package `{pkg}` is not installed but is required for "
+                            f"extended tests. Please install the given package and "
+                            f"try again.",
+                        )
+
+                    else:
+                        # If the package is not installed, we immediately break
+                        # and mark the test as skipped.
+                        item.add_marker(
+                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
+                        )
+                        break
+        else:
+            if only_extended:
+                item.add_marker(
+                    pytest.mark.skip(reason="Skipping not an extended test.")
+                )
+
+
+@pytest.fixture
+def logging_conf() -> dict:
+    return get_config_dict(
+        "DEBUG",
+        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
+        122,
+        111,
+    )
diff --git a/model-providers/tests/unit_test/test_provider_manager_models.py b/model-providers/tests/unit_test/test_provider_manager_models.py
new file mode 100644
index 00000000..c7afd8dc
--- /dev/null
+++ b/model-providers/tests/unit_test/test_provider_manager_models.py
@@ -0,0 +1,37 @@
+from omegaconf import OmegaConf
+
+from model_providers import _to_custom_provide_configuration
+import logging.config
+
+from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.provider_manager import ProviderManager
+
+logger = logging.getLogger(__name__)
+
+
+def test_provider_manager_models(logging_conf: dict) -> None:
+
+    logging.config.dictConfig(logging_conf)
+    # Load the provider configuration file
+    cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
+                         "/model_providers.yaml")
+    # Convert the configuration into provider and model records
+    provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
+        cfg)
+    # Create the provider manager
+    provider_manager = ProviderManager(
+        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
+    )
+
+    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
+        provider="openai", model_type=ModelType.LLM
+    )
+    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
+        provider="openai", model_type=ModelType.TEXT_EMBEDDING
+    )
+    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
+
+    assert provider_model_bundle_llm is not None
+    assert predefined_models is not None
+    logger.info(f"predefined_models: {predefined_models}")
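
Reviewer note: a quick way to exercise the new `list_models` handler end to end is shown below. This is a minimal sketch, not part of the patch; it assumes the bootstrap server from `test_init_server.py` is running on 127.0.0.1:20000 and that the handler is mounted OpenAI-style at `/{provider}/v1/models` — the route registration is not visible in this diff, so the URL is an assumption.

    import requests

    # Hypothetical URL: provider prefix and path are assumptions, only the
    # handler body appears in this diff.
    resp = requests.get("http://127.0.0.1:20000/openai/v1/models", timeout=10)
    resp.raise_for_status()
    for card in resp.json()["data"]:
        # Each entry is a ModelCard; after this change `object` carries the
        # origin model type (e.g. "text-generation", "embeddings") rather
        # than the fixed literal "model".
        print(card["id"], card["object"])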