mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 14:23:23 +08:00
模型列表适配
This commit is contained in:
parent
032dc8f58d
commit
f005ea3298
@ -57,18 +57,14 @@ optional = true
|
|||||||
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
||||||
# Any dependencies that do not meet that criteria will be removed.
|
# Any dependencies that do not meet that criteria will be removed.
|
||||||
pytest = "^7.3.0"
|
pytest = "^7.3.0"
|
||||||
pytest-cov = "^4.0.0"
|
|
||||||
pytest-dotenv = "^0.5.2"
|
|
||||||
duckdb-engine = "^0.7.0"
|
|
||||||
pytest-watcher = "^0.2.6"
|
|
||||||
freezegun = "^1.2.2"
|
freezegun = "^1.2.2"
|
||||||
responses = "^0.22.0"
|
pytest-mock = "^3.10.0"
|
||||||
pytest-asyncio = "^0.20.3"
|
|
||||||
lark = "^1.1.5"
|
|
||||||
pandas = "^2.0.0"
|
|
||||||
pytest-mock = "^3.10.0"
|
|
||||||
pytest-socket = "^0.6.0"
|
|
||||||
syrupy = "^4.0.2"
|
syrupy = "^4.0.2"
|
||||||
|
pytest-watcher = "^0.3.4"
|
||||||
|
pytest-asyncio = "^0.21.1"
|
||||||
|
grandalf = "^0.8"
|
||||||
|
pytest-profiling = "^1.7.0"
|
||||||
|
responses = "^0.25.0"
|
||||||
model-providers = { path = "../model-providers", develop = true }
|
model-providers = { path = "../model-providers", develop = true }
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,50 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Generator, cast
|
|
||||||
|
|
||||||
from model_providers import provider_manager
|
|
||||||
from model_providers.core.model_manager import ModelManager
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
|
||||||
LLMResultChunk,
|
|
||||||
LLMResultChunkDelta,
|
|
||||||
)
|
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# 基于配置管理器创建的模型实例
|
|
||||||
|
|
||||||
# Invoke model
|
|
||||||
model_instance = provider_manager.get_model_instance(
|
|
||||||
provider="openai", model_type=ModelType.LLM, model="gpt-4"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = model_instance.invoke_llm(
|
|
||||||
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
|
||||||
model_parameters={
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"top_k": 1,
|
|
||||||
"plugin_web_search": True,
|
|
||||||
},
|
|
||||||
stop=["you"],
|
|
||||||
stream=True,
|
|
||||||
user="abc-123",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, Generator)
|
|
||||||
total_message = ""
|
|
||||||
for chunk in response:
|
|
||||||
assert isinstance(chunk, LLMResultChunk)
|
|
||||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
|
||||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
|
||||||
total_message += chunk.delta.message.content
|
|
||||||
assert (
|
|
||||||
len(chunk.delta.message.content) > 0
|
|
||||||
if not chunk.delta.finish_reason
|
|
||||||
else True
|
|
||||||
)
|
|
||||||
print(total_message)
|
|
||||||
assert "参考资料" in total_message
|
|
||||||
@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import (
|
|||||||
EmbeddingsRequest,
|
EmbeddingsRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
FunctionAvailable,
|
FunctionAvailable,
|
||||||
ModelList,
|
ModelList, ModelCard,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_manager import ModelManager, ModelInstance
|
from model_providers.core.model_manager import ModelManager, ModelInstance
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
|
||||||
from model_providers.core.utils.generic import dictify, jsonify
|
from model_providers.core.utils.generic import dictify, jsonify
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
started_event.set()
|
started_event.set()
|
||||||
|
|
||||||
async def list_models(self, provider: str, request: Request):
|
async def list_models(self, provider: str, request: Request):
|
||||||
pass
|
logger.info(f"Received list_models request for provider: {provider}")
|
||||||
|
# 返回ModelType所有的枚举
|
||||||
|
llm_models: list[AIModelEntity] = []
|
||||||
|
for model_type in ModelType.__members__.values():
|
||||||
|
try:
|
||||||
|
provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
|
||||||
|
provider=provider, model_type=model_type
|
||||||
|
)
|
||||||
|
llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
# models list[AIModelEntity]转换称List[ModelCard]
|
||||||
|
|
||||||
|
models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
|
||||||
|
|
||||||
|
return ModelList(
|
||||||
|
data=models_list
|
||||||
|
)
|
||||||
|
|
||||||
async def create_embeddings(
|
async def create_embeddings(
|
||||||
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class Finish(str, Enum):
|
|||||||
|
|
||||||
class ModelCard(BaseModel):
|
class ModelCard(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: Literal["model"] = "model"
|
object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
owned_by: Literal["owner"] = "owner"
|
owned_by: Literal["owner"] = "owner"
|
||||||
|
|
||||||
|
|||||||
@ -245,6 +245,10 @@ class ModelManager:
|
|||||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_manager(self) -> ProviderManager:
|
||||||
|
return self._provider_manager
|
||||||
|
|
||||||
def get_model_instance(
|
def get_model_instance(
|
||||||
self, provider: str, model_type: ModelType, model: str
|
self, provider: str, model_type: ModelType, model: str
|
||||||
) -> ModelInstance:
|
) -> ModelInstance:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -74,6 +74,7 @@ class ModelType(Enum):
|
|||||||
raise ValueError(f"invalid model type {self}")
|
raise ValueError(f"invalid model type {self}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FetchFrom(Enum):
|
class FetchFrom(Enum):
|
||||||
"""
|
"""
|
||||||
Enum class for fetch from.
|
Enum class for fetch from.
|
||||||
|
|||||||
89
model-providers/model_providers/core/utils/utils.py
Normal file
89
model-providers/model_providers/core/utils/utils.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerNameFilter(logging.Filter):
|
||||||
|
def filter(self, record):
|
||||||
|
# return record.name.startswith("loom_core") or record.name in "ERROR" or (
|
||||||
|
# record.name.startswith("uvicorn.error")
|
||||||
|
# and record.getMessage().startswith("Uvicorn running on")
|
||||||
|
# )
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_file(log_path: str, sub_dir: str):
|
||||||
|
"""
|
||||||
|
sub_dir should contain a timestamp.
|
||||||
|
"""
|
||||||
|
log_dir = os.path.join(log_path, sub_dir)
|
||||||
|
# Here should be creating a new directory each time, so `exist_ok=False`
|
||||||
|
os.makedirs(log_dir, exist_ok=False)
|
||||||
|
return os.path.join(log_dir, "loom_core.log")
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_dict(
|
||||||
|
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
|
||||||
|
) -> dict:
|
||||||
|
# for windows, the path should be a raw string.
|
||||||
|
log_file_path = (
|
||||||
|
log_file_path.encode("unicode-escape").decode()
|
||||||
|
if os.name == "nt"
|
||||||
|
else log_file_path
|
||||||
|
)
|
||||||
|
log_level = log_level.upper()
|
||||||
|
config_dict = {
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"formatters": {
|
||||||
|
"formatter": {
|
||||||
|
"format": (
|
||||||
|
"%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"logger_name_filter": {
|
||||||
|
"()": __name__ + ".LoggerNameFilter",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"handlers": {
|
||||||
|
"stream_handler": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"formatter": "formatter",
|
||||||
|
"level": log_level,
|
||||||
|
# "stream": "ext://sys.stdout",
|
||||||
|
# "filters": ["logger_name_filter"],
|
||||||
|
},
|
||||||
|
"file_handler": {
|
||||||
|
"class": "logging.handlers.RotatingFileHandler",
|
||||||
|
"formatter": "formatter",
|
||||||
|
"level": log_level,
|
||||||
|
"filename": log_file_path,
|
||||||
|
"mode": "a",
|
||||||
|
"maxBytes": log_max_bytes,
|
||||||
|
"backupCount": log_backup_count,
|
||||||
|
"encoding": "utf8",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"loggers": {
|
||||||
|
"loom_core": {
|
||||||
|
"handlers": ["stream_handler", "file_handler"],
|
||||||
|
"level": log_level,
|
||||||
|
"propagate": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"level": log_level,
|
||||||
|
"handlers": ["stream_handler", "file_handler"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestamp_ms():
|
||||||
|
t = time.time()
|
||||||
|
return int(round(t * 1000))
|
||||||
@ -26,18 +26,14 @@ boto3 = "1.28.17"
|
|||||||
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
||||||
# Any dependencies that do not meet that criteria will be removed.
|
# Any dependencies that do not meet that criteria will be removed.
|
||||||
pytest = "^7.3.0"
|
pytest = "^7.3.0"
|
||||||
pytest-cov = "^4.0.0"
|
|
||||||
pytest-dotenv = "^0.5.2"
|
|
||||||
duckdb-engine = "^0.7.0"
|
|
||||||
pytest-watcher = "^0.2.6"
|
|
||||||
freezegun = "^1.2.2"
|
freezegun = "^1.2.2"
|
||||||
responses = "^0.22.0"
|
pytest-mock = "^3.10.0"
|
||||||
pytest-asyncio = "^0.20.3"
|
|
||||||
lark = "^1.1.5"
|
|
||||||
pandas = "^2.0.0"
|
|
||||||
pytest-mock = "^3.10.0"
|
|
||||||
pytest-socket = "^0.6.0"
|
|
||||||
syrupy = "^4.0.2"
|
syrupy = "^4.0.2"
|
||||||
|
pytest-watcher = "^0.3.4"
|
||||||
|
pytest-asyncio = "^0.21.1"
|
||||||
|
grandalf = "^0.8"
|
||||||
|
pytest-profiling = "^1.7.0"
|
||||||
|
responses = "^0.25.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
#
|
#
|
||||||
# https://github.com/tophat/syrupy
|
# https://github.com/tophat/syrupy
|
||||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||||
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
|
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
|
||||||
# Registering custom markers.
|
# Registering custom markers.
|
||||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||||
markers = [
|
markers = [
|
||||||
|
|||||||
99
model-providers/tests/server_unit_test/conftest.py
Normal file
99
model-providers/tests/server_unit_test/conftest.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
"""Configuration for unit tests."""
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, Sequence, List
|
||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
from pytest import Config, Function, Parser
|
||||||
|
|
||||||
|
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
"""Add custom command line options to pytest."""
|
||||||
|
parser.addoption(
|
||||||
|
"--only-extended",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||||
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--only-core",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run core tests. Never runs any extended tests.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||||
|
"""Add implementations for handling custom markers.
|
||||||
|
|
||||||
|
At the moment, this adds support for a custom `requires` marker.
|
||||||
|
|
||||||
|
The `requires` marker is used to denote tests that require one or more packages
|
||||||
|
to be installed to run. If the package is not installed, the test is skipped.
|
||||||
|
|
||||||
|
The `requires` marker syntax is:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@pytest.mark.requires("package1", "package2")
|
||||||
|
def test_something():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
# Mapping from the name of a package to whether it is installed or not.
|
||||||
|
# Used to avoid repeated calls to `util.find_spec`
|
||||||
|
required_pkgs_info: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
only_extended = config.getoption("--only-extended") or False
|
||||||
|
only_core = config.getoption("--only-core") or False
|
||||||
|
|
||||||
|
if only_extended and only_core:
|
||||||
|
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
requires_marker = item.get_closest_marker("requires")
|
||||||
|
if requires_marker is not None:
|
||||||
|
if only_core:
|
||||||
|
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Iterate through the list of required packages
|
||||||
|
required_pkgs = requires_marker.args
|
||||||
|
for pkg in required_pkgs:
|
||||||
|
# If we haven't yet checked whether the pkg is installed
|
||||||
|
# let's check it and store the result.
|
||||||
|
if pkg not in required_pkgs_info:
|
||||||
|
try:
|
||||||
|
installed = util.find_spec(pkg) is not None
|
||||||
|
except Exception:
|
||||||
|
installed = False
|
||||||
|
required_pkgs_info[pkg] = installed
|
||||||
|
|
||||||
|
if not required_pkgs_info[pkg]:
|
||||||
|
if only_extended:
|
||||||
|
pytest.fail(
|
||||||
|
f"Package `{pkg}` is not installed but is required for "
|
||||||
|
f"extended tests. Please install the given package and "
|
||||||
|
f"try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If the package is not installed, we immediately break
|
||||||
|
# and mark the test as skipped.
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if only_extended:
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def logging_conf() -> dict:
|
||||||
|
return get_config_dict(
|
||||||
|
"DEBUG",
|
||||||
|
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
|
||||||
|
122,
|
||||||
|
111,
|
||||||
|
)
|
||||||
28
model-providers/tests/server_unit_test/test_init_server.py
Normal file
28
model-providers/tests/server_unit_test/test_init_server.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from model_providers import BootstrapWebBuilder
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("fastapi")
|
||||||
|
def test_init_server(logging_conf: dict) -> None:
|
||||||
|
try:
|
||||||
|
boot = BootstrapWebBuilder() \
|
||||||
|
.model_providers_cfg_path(
|
||||||
|
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
||||||
|
"/model_providers.yaml") \
|
||||||
|
.host(host="127.0.0.1") \
|
||||||
|
.port(port=20000) \
|
||||||
|
.build()
|
||||||
|
boot.set_app_event(started_event=None)
|
||||||
|
boot.serve(logging_conf=logging_conf)
|
||||||
|
|
||||||
|
async def pool_join_thread():
|
||||||
|
await boot.join()
|
||||||
|
|
||||||
|
asyncio.run(pool_join_thread())
|
||||||
|
except SystemExit:
|
||||||
|
logger.info("SystemExit raised, exiting")
|
||||||
|
raise
|
||||||
99
model-providers/tests/unit_test/conftest.py
Normal file
99
model-providers/tests/unit_test/conftest.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
"""Configuration for unit tests."""
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, Sequence, List
|
||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
from pytest import Config, Function, Parser
|
||||||
|
|
||||||
|
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
"""Add custom command line options to pytest."""
|
||||||
|
parser.addoption(
|
||||||
|
"--only-extended",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||||
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--only-core",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run core tests. Never runs any extended tests.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||||
|
"""Add implementations for handling custom markers.
|
||||||
|
|
||||||
|
At the moment, this adds support for a custom `requires` marker.
|
||||||
|
|
||||||
|
The `requires` marker is used to denote tests that require one or more packages
|
||||||
|
to be installed to run. If the package is not installed, the test is skipped.
|
||||||
|
|
||||||
|
The `requires` marker syntax is:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@pytest.mark.requires("package1", "package2")
|
||||||
|
def test_something():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
# Mapping from the name of a package to whether it is installed or not.
|
||||||
|
# Used to avoid repeated calls to `util.find_spec`
|
||||||
|
required_pkgs_info: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
only_extended = config.getoption("--only-extended") or False
|
||||||
|
only_core = config.getoption("--only-core") or False
|
||||||
|
|
||||||
|
if only_extended and only_core:
|
||||||
|
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
requires_marker = item.get_closest_marker("requires")
|
||||||
|
if requires_marker is not None:
|
||||||
|
if only_core:
|
||||||
|
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Iterate through the list of required packages
|
||||||
|
required_pkgs = requires_marker.args
|
||||||
|
for pkg in required_pkgs:
|
||||||
|
# If we haven't yet checked whether the pkg is installed
|
||||||
|
# let's check it and store the result.
|
||||||
|
if pkg not in required_pkgs_info:
|
||||||
|
try:
|
||||||
|
installed = util.find_spec(pkg) is not None
|
||||||
|
except Exception:
|
||||||
|
installed = False
|
||||||
|
required_pkgs_info[pkg] = installed
|
||||||
|
|
||||||
|
if not required_pkgs_info[pkg]:
|
||||||
|
if only_extended:
|
||||||
|
pytest.fail(
|
||||||
|
f"Package `{pkg}` is not installed but is required for "
|
||||||
|
f"extended tests. Please install the given package and "
|
||||||
|
f"try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If the package is not installed, we immediately break
|
||||||
|
# and mark the test as skipped.
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if only_extended:
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def logging_conf() -> dict:
|
||||||
|
return get_config_dict(
|
||||||
|
"DEBUG",
|
||||||
|
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
|
||||||
|
122,
|
||||||
|
111,
|
||||||
|
)
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from model_providers.core.provider_manager import ProviderManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_manager_models(logging_conf: dict) -> None:
|
||||||
|
|
||||||
|
logging.config.dictConfig(logging_conf) # type: ignore
|
||||||
|
# 读取配置文件
|
||||||
|
cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
||||||
|
"/model_providers.yaml")
|
||||||
|
# 转换配置文件
|
||||||
|
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
||||||
|
cfg)
|
||||||
|
# 创建模型管理器
|
||||||
|
provider_manager = ProviderManager(
|
||||||
|
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
||||||
|
provider="openai", model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
|
||||||
|
provider="openai", model_type=ModelType.TEXT_EMBEDDING
|
||||||
|
)
|
||||||
|
predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
|
||||||
|
|
||||||
|
logger.info(f"predefined_models: {predefined_models}")
|
||||||
Loading…
x
Reference in New Issue
Block a user