mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00

Model list adaptation

This commit is contained in:
parent 032dc8f58d
commit f005ea3298
@@ -57,18 +57,14 @@ optional = true
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
model-providers = { path = "../model-providers", develop = true }
@@ -1,50 +0,0 @@
import os
from typing import Generator, cast

from model_providers import provider_manager
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType

if __name__ == '__main__':
    # Model instance created via the configuration manager

    # Invoke model
    model_instance = provider_manager.get_model_instance(
        provider="openai", model_type=ModelType.LLM, model="gpt-4"
    )

    response = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        assert (
            len(chunk.delta.message.content) > 0
            if not chunk.delta.finish_reason
            else True
        )
    print(total_message)
    assert "参考资料" in total_message
@@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     FunctionAvailable,
-    ModelList,
+    ModelList, ModelCard,
 )
 from model_providers.core.model_manager import ModelManager, ModelInstance
 from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
-from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
 from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)
@@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         started_event.set()

     async def list_models(self, provider: str, request: Request):
-        pass
+        logger.info(f"Received list_models request for provider: {provider}")
+        # Collect the predefined models for every ModelType enum member
+        llm_models: list[AIModelEntity] = []
+        for model_type in ModelType.__members__.values():
+            try:
+                provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
+                    provider=provider, model_type=model_type
+                )
+                llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
+            except Exception as e:
+                logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
+                logger.error(e)
+
+        # Convert the List[AIModelEntity] into a List[ModelCard]
+
+        models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
+
+        return ModelList(
+            data=models_list
+        )

     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
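A minimal sketch of the payload this handler assembles, using the ModelCard/ModelList types from this diff. The model ids below are illustrative only, and the serialized shape is an assumption based on the OpenAI /models convention:

# Hypothetical usage; ids are examples, not taken from this commit.
from model_providers.core.bootstrap.openai_protocol import ModelCard, ModelList

cards = [
    ModelCard(id="gpt-4", object="text-generation"),              # an LLM entry
    ModelCard(id="text-embedding-ada-002", object="embeddings"),  # an embedding entry
]
model_list = ModelList(data=cards)
print([card.dict() for card in model_list.data])
# e.g. [{'id': 'gpt-4', 'object': 'text-generation', 'created': ..., 'owned_by': 'owner'}, ...]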
@@ -22,7 +22,7 @@ class Finish(str, Enum):

 class ModelCard(BaseModel):
     id: str
-    object: Literal["model"] = "model"
+    object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: Literal["owner"] = "owner"

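With this change the card's `object` field carries the origin model type instead of the fixed string "model". A quick sketch of constructing one (the id is illustrative; the `object` value is taken from the Literal above):

import time

from model_providers.core.bootstrap.openai_protocol import ModelCard

card = ModelCard(id="bge-large-zh", object="embeddings")  # hypothetical id
assert card.owned_by == "owner"           # default from the field definition
assert card.created <= int(time.time())   # set by the default_factory at construction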
@@ -245,6 +245,10 @@ class ModelManager:
             provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
         )

+    @property
+    def provider_manager(self) -> ProviderManager:
+        return self._provider_manager
+
     def get_model_instance(
         self, provider: str, model_type: ModelType, model: str
     ) -> ModelInstance:
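The new property exposes the underlying ProviderManager, so callers holding a ModelManager can reach provider model bundles directly. A minimal sketch, assuming ModelManager is constructed from the same record dicts as ProviderManager (as the surrounding context and the tests below suggest; `provider_records` and `model_records` are placeholders):

from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType

manager = ModelManager(
    provider_name_to_provider_records_dict=provider_records,        # placeholder
    provider_name_to_provider_model_records_dict=model_records,     # placeholder
)
bundle = manager.provider_manager.get_provider_model_bundle(
    provider="openai", model_type=ModelType.LLM
)
print(bundle.model_type_instance.predefined_models())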
@@ -1,6 +1,6 @@
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Optional, List

 from pydantic import BaseModel

@@ -74,6 +74,7 @@ class ModelType(Enum):
        raise ValueError(f"invalid model type {self}")



class FetchFrom(Enum):
    """
    Enum class for fetch from.
89  model-providers/model_providers/core/utils/utils.py  Normal file
@@ -0,0 +1,89 @@
import logging

import time
import os

logger = logging.getLogger(__name__)


class LoggerNameFilter(logging.Filter):
    def filter(self, record):
        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
        #     record.name.startswith("uvicorn.error")
        #     and record.getMessage().startswith("Uvicorn running on")
        # )
        return True


def get_log_file(log_path: str, sub_dir: str):
    """
    sub_dir should contain a timestamp.
    """
    log_dir = os.path.join(log_path, sub_dir)
    # Here should be creating a new directory each time, so `exist_ok=False`
    os.makedirs(log_dir, exist_ok=False)
    return os.path.join(log_dir, "loom_core.log")


def get_config_dict(
    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
    # for windows, the path should be a raw string.
    log_file_path = (
        log_file_path.encode("unicode-escape").decode()
        if os.name == "nt"
        else log_file_path
    )
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {
                "format": (
                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
                )
            },
        },
        "filters": {
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict


def get_timestamp_ms():
    t = time.time()
    return int(round(t * 1000))
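These helpers combine exactly as the test fixtures later in this diff use them: build a dictConfig dictionary and hand it to the standard logging machinery. A runnable sketch (the backup-count and max-bytes values mirror the fixture below):

import logging.config

from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms

conf = get_config_dict(
    "DEBUG",
    get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
    122,  # log_backup_count
    111,  # log_max_bytes
)
logging.config.dictConfig(conf)
logging.getLogger("loom_core").debug("logging configured")  # routed to stream + rotating file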
@@ -26,18 +26,14 @@ boto3 = "1.28.17"
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
@@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
99  model-providers/tests/server_unit_test/conftest.py  Normal file
@@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )
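Taken together, a test module under this conftest can gate itself on optional packages and reuse the shared logging fixture. A minimal hypothetical sketch, mirroring test_init_server.py below:

import logging.config

import pytest


@pytest.mark.requires("fastapi")  # skipped automatically when fastapi is not importable
def test_with_logging(logging_conf: dict) -> None:
    logging.config.dictConfig(logging_conf)  # apply the DEBUG config built by the fixture
    logging.getLogger(__name__).info("running under the shared logging config")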
28  model-providers/tests/server_unit_test/test_init_server.py  Normal file
@@ -0,0 +1,28 @@
from model_providers import BootstrapWebBuilder
import logging
import asyncio

import pytest

logger = logging.getLogger(__name__)


@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
    try:
        boot = BootstrapWebBuilder() \
            .model_providers_cfg_path(
                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                                         "/model_providers.yaml") \
            .host(host="127.0.0.1") \
            .port(port=20000) \
            .build()
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await boot.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
99  model-providers/tests/unit_test/conftest.py  Normal file
@@ -0,0 +1,99 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence, List
import logging
import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )
@@ -0,0 +1,39 @@
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
import logging
import asyncio

import pytest

from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_provider_manager_models(logging_conf: dict) -> None:

    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the configuration file
    cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                         "/model_providers.yaml")
    # Convert the configuration
    provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
        cfg)
    # Create the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()

    logger.info(f"predefined_models: {predefined_models}")