mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00
make format
This commit is contained in:
parent
b8d748b668
commit
a2df71d9ea
@ -1,9 +1,11 @@
|
|||||||
from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from model_providers.bootstrap_web.openai_bootstrap_web import (
|
||||||
|
RESTFulOpenAIBootstrapBaseWeb,
|
||||||
|
)
|
||||||
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
||||||
from model_providers.core.model_manager import ModelManager
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
|
||||||
from omegaconf import OmegaConf, DictConfig
|
|
||||||
|
|
||||||
|
|
||||||
def _to_custom_provide_configuration(cfg: DictConfig):
|
def _to_custom_provide_configuration(cfg: DictConfig):
|
||||||
"""
|
"""
|
||||||
@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig):
|
|||||||
provider_name_to_provider_model_records_dict = {}
|
provider_name_to_provider_model_records_dict = {}
|
||||||
|
|
||||||
for key, item in cfg.items():
|
for key, item in cfg.items():
|
||||||
model_credential = item.get('model_credential')
|
model_credential = item.get("model_credential")
|
||||||
provider_credential = item.get('provider_credential')
|
provider_credential = item.get("provider_credential")
|
||||||
# 转换omegaconf对象为基本属性
|
# 转换omegaconf对象为基本属性
|
||||||
if model_credential:
|
if model_credential:
|
||||||
model_credential = OmegaConf.to_container(model_credential)
|
model_credential = OmegaConf.to_container(model_credential)
|
||||||
@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig):
|
|||||||
provider_credential = OmegaConf.to_container(provider_credential)
|
provider_credential = OmegaConf.to_container(provider_credential)
|
||||||
provider_name_to_provider_records_dict[key] = provider_credential
|
provider_name_to_provider_records_dict[key] = provider_credential
|
||||||
|
|
||||||
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
|
return (
|
||||||
|
provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BootstrapWebBuilder:
|
class BootstrapWebBuilder:
|
||||||
"""
|
"""
|
||||||
创建一个模型实例创建工具
|
创建一个模型实例创建工具
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_model_providers_cfg_path: str
|
_model_providers_cfg_path: str
|
||||||
_host: str
|
_host: str
|
||||||
_port: int
|
_port: int
|
||||||
@ -68,20 +74,26 @@ class BootstrapWebBuilder:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> OpenAIBootstrapBaseWeb:
|
def build(self) -> OpenAIBootstrapBaseWeb:
|
||||||
assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
|
assert (
|
||||||
|
self._model_providers_cfg_path is not None
|
||||||
|
and self._host is not None
|
||||||
|
and self._port is not None
|
||||||
|
)
|
||||||
# 读取配置文件
|
# 读取配置文件
|
||||||
cfg = OmegaConf.load(self._model_providers_cfg_path)
|
cfg = OmegaConf.load(self._model_providers_cfg_path)
|
||||||
# 转换配置文件
|
# 转换配置文件
|
||||||
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
(
|
||||||
cfg)
|
provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict,
|
||||||
|
) = _to_custom_provide_configuration(cfg)
|
||||||
# 创建模型管理器
|
# 创建模型管理器
|
||||||
provider_manager = ModelManager(provider_name_to_provider_records_dict,
|
provider_manager = ModelManager(
|
||||||
provider_name_to_provider_model_records_dict)
|
provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict,
|
||||||
|
)
|
||||||
# 创建web服务
|
# 创建web服务
|
||||||
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
|
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
|
||||||
"host": self._host,
|
cfg={"host": self._host, "port": self._port}
|
||||||
"port": self._port
|
)
|
||||||
|
|
||||||
})
|
|
||||||
restful.provider_manager = provider_manager
|
restful.provider_manager = provider_manager
|
||||||
return restful
|
return restful
|
||||||
|
|||||||
@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import (
|
|||||||
EmbeddingsRequest,
|
EmbeddingsRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
FunctionAvailable,
|
FunctionAvailable,
|
||||||
ModelList, ModelCard,
|
ModelCard,
|
||||||
|
ModelList,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_manager import ModelManager, ModelInstance
|
from model_providers.core.model_manager import ModelInstance, ModelManager
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
from model_providers.core.utils.generic import dictify, jsonify
|
from model_providers.core.utils.generic import dictify, jsonify
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
llm_models: list[AIModelEntity] = []
|
llm_models: list[AIModelEntity] = []
|
||||||
for model_type in ModelType.__members__.values():
|
for model_type in ModelType.__members__.values():
|
||||||
try:
|
try:
|
||||||
provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle(
|
provider_model_bundle = (
|
||||||
provider=provider, model_type=model_type
|
self._provider_manager.provider_manager.get_provider_model_bundle(
|
||||||
|
provider=provider, model_type=model_type
|
||||||
|
)
|
||||||
|
)
|
||||||
|
llm_models.extend(
|
||||||
|
provider_model_bundle.model_type_instance.predefined_models()
|
||||||
)
|
)
|
||||||
llm_models.extend(provider_model_bundle.model_type_instance.predefined_models())
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}")
|
logger.error(
|
||||||
|
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
|
||||||
|
)
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
|
||||||
# models list[AIModelEntity]转换称List[ModelCard]
|
# models list[AIModelEntity]转换称List[ModelCard]
|
||||||
|
|
||||||
models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models]
|
models_list = [
|
||||||
|
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
|
||||||
|
for model in llm_models
|
||||||
|
]
|
||||||
|
|
||||||
return ModelList(
|
return ModelList(data=models_list)
|
||||||
data=models_list
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_embeddings(
|
async def create_embeddings(
|
||||||
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
||||||
@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
return EmbeddingsResponse(**dictify(response))
|
return EmbeddingsResponse(**dictify(response))
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
|
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
|
||||||
@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
cfg: Dict,
|
cfg: Dict,
|
||||||
logging_conf: Optional[dict] = None,
|
logging_conf: Optional[dict] = None,
|
||||||
started_event: mp.Event = None,
|
started_event: mp.Event = None,
|
||||||
):
|
):
|
||||||
logging.config.dictConfig(logging_conf) # type: ignore
|
logging.config.dictConfig(logging_conf) # type: ignore
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -3,7 +3,10 @@ from collections import deque
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
|
from model_providers.core.bootstrap.openai_protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
EmbeddingsRequest,
|
||||||
|
)
|
||||||
from model_providers.core.model_manager import ModelManager
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
|
||||||
|
|
||||||
@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_embeddings(
|
async def create_embeddings(
|
||||||
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -22,7 +22,15 @@ class Finish(str, Enum):
|
|||||||
|
|
||||||
class ModelCard(BaseModel):
|
class ModelCard(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm"
|
object: Literal[
|
||||||
|
"text-generation",
|
||||||
|
"embeddings",
|
||||||
|
"reranking",
|
||||||
|
"speech2text",
|
||||||
|
"moderation",
|
||||||
|
"tts",
|
||||||
|
"text2img",
|
||||||
|
] = "llm"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
owned_by: Literal["owner"] = "owner"
|
owned_by: Literal["owner"] = "owner"
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -74,7 +74,6 @@ class ModelType(Enum):
|
|||||||
raise ValueError(f"invalid model type {self}")
|
raise ValueError(f"invalid model type {self}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FetchFrom(Enum):
|
class FetchFrom(Enum):
|
||||||
"""
|
"""
|
||||||
Enum class for fetch from.
|
Enum class for fetch from.
|
||||||
|
|||||||
@ -253,9 +253,11 @@ class ProviderManager:
|
|||||||
|
|
||||||
custom_model_configurations.append(
|
custom_model_configurations.append(
|
||||||
CustomModelConfiguration(
|
CustomModelConfiguration(
|
||||||
model=provider_model_record.get('model'),
|
model=provider_model_record.get("model"),
|
||||||
model_type=ModelType.value_of(provider_model_record.get('model_type')),
|
model_type=ModelType.value_of(
|
||||||
credentials=provider_model_credentials
|
provider_model_record.get("model_type")
|
||||||
|
),
|
||||||
|
credentials=provider_model_credentials,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str):
|
|||||||
|
|
||||||
|
|
||||||
def get_config_dict(
|
def get_config_dict(
|
||||||
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
|
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# for windows, the path should be a raw string.
|
# for windows, the path should be a raw string.
|
||||||
log_file_path = (
|
log_file_path = (
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
"""Configuration for unit tests."""
|
"""Configuration for unit tests."""
|
||||||
from importlib import util
|
|
||||||
from typing import Dict, Sequence, List
|
|
||||||
import logging
|
import logging
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, List, Sequence
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import Config, Function, Parser
|
from pytest import Config, Function, Parser
|
||||||
|
|
||||||
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
|
from model_providers.core.utils.utils import (
|
||||||
|
get_config_dict,
|
||||||
|
get_log_file,
|
||||||
|
get_timestamp_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: Parser) -> None:
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
|||||||
@ -1,21 +1,26 @@
|
|||||||
from model_providers import BootstrapWebBuilder
|
|
||||||
import logging
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from model_providers import BootstrapWebBuilder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("fastapi")
|
@pytest.mark.requires("fastapi")
|
||||||
def test_init_server(logging_conf: dict) -> None:
|
def test_init_server(logging_conf: dict) -> None:
|
||||||
try:
|
try:
|
||||||
boot = BootstrapWebBuilder() \
|
boot = (
|
||||||
|
BootstrapWebBuilder()
|
||||||
.model_providers_cfg_path(
|
.model_providers_cfg_path(
|
||||||
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
||||||
"/model_providers.yaml") \
|
"/model_providers.yaml"
|
||||||
.host(host="127.0.0.1") \
|
)
|
||||||
.port(port=20000) \
|
.host(host="127.0.0.1")
|
||||||
|
.port(port=20000)
|
||||||
.build()
|
.build()
|
||||||
|
)
|
||||||
boot.set_app_event(started_event=None)
|
boot.set_app_event(started_event=None)
|
||||||
boot.serve(logging_conf=logging_conf)
|
boot.serve(logging_conf=logging_conf)
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
"""Configuration for unit tests."""
|
"""Configuration for unit tests."""
|
||||||
from importlib import util
|
|
||||||
from typing import Dict, Sequence, List
|
|
||||||
import logging
|
import logging
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, List, Sequence
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import Config, Function, Parser
|
from pytest import Config, Function, Parser
|
||||||
|
|
||||||
from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
|
from model_providers.core.utils.utils import (
|
||||||
|
get_config_dict,
|
||||||
|
get_log_file,
|
||||||
|
get_timestamp_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: Parser) -> None:
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from model_providers.core.model_manager import ModelManager
|
from model_providers.core.model_manager import ModelManager
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.provider_manager import ProviderManager
|
from model_providers.core.provider_manager import ProviderManager
|
||||||
@ -14,14 +13,17 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def test_provider_manager_models(logging_conf: dict) -> None:
|
def test_provider_manager_models(logging_conf: dict) -> None:
|
||||||
|
|
||||||
logging.config.dictConfig(logging_conf) # type: ignore
|
logging.config.dictConfig(logging_conf) # type: ignore
|
||||||
# 读取配置文件
|
# 读取配置文件
|
||||||
cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
cfg = OmegaConf.load(
|
||||||
"/model_providers.yaml")
|
"/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
|
||||||
|
"/model_providers.yaml"
|
||||||
|
)
|
||||||
# 转换配置文件
|
# 转换配置文件
|
||||||
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
(
|
||||||
cfg)
|
provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict,
|
||||||
|
) = _to_custom_provide_configuration(cfg)
|
||||||
# 创建模型管理器
|
# 创建模型管理器
|
||||||
provider_manager = ProviderManager(
|
provider_manager = ProviderManager(
|
||||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||||
@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None:
|
|||||||
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
|
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
|
||||||
provider="openai", model_type=ModelType.TEXT_EMBEDDING
|
provider="openai", model_type=ModelType.TEXT_EMBEDDING
|
||||||
)
|
)
|
||||||
predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
|
predefined_models = (
|
||||||
|
provider_model_bundle_emb.model_type_instance.predefined_models()
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"predefined_models: {predefined_models}")
|
logger.info(f"predefined_models: {predefined_models}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user