Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-19 21:37:20 +08:00)
Merge pull request #3579 from chatchat-space/dev_model_providers
Implemented BootstrapWebBuilder for loading user configuration, adapted the standard request/response schema so business logic runs through RESTFulOpenAIBootstrapBaseWeb, and added an xinference plugin example.
This commit is contained in: commit 2526fa9062
@@ -57,18 +57,14 @@ optional = true
# dependencies used for running tests (e.g., pytest, freezegun, responses).
# Any dependencies that do not meet these criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
model-providers = { path = "../model-providers", develop = true }
29  model-providers/model_providers.yaml  Normal file
@@ -0,0 +1,29 @@
openai:
  model_credential:
    - model: 'gpt-3.5-turbo'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''
    - model: 'gpt-4'
      model_type: 'llm'
      model_credentials:
        openai_api_key: 'sk-'
        openai_organization: ''
        openai_api_base: ''

  provider_credential:
    openai_api_key: 'sk-'
    openai_organization: ''
    openai_api_base: ''

xinference:
  model_credential:
    - model: 'chatglm3-6b'
      model_type: 'llm'
      model_credentials:
        server_url: 'http://127.0.0.1:9997/'
        model_uid: 'chatglm3-6b'
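For orientation, a minimal sketch of reading this file the way the builder below does — `OmegaConf.load` plus `OmegaConf.to_container` to get plain dicts (the path is an assumption about where the file sits relative to the working directory):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("model-providers/model_providers.yaml")  # assumed path
# DictConfig nodes convert to plain dicts/lists via to_container
xinference_models = OmegaConf.to_container(cfg["xinference"]["model_credential"])
print(xinference_models[0]["model_credentials"]["server_url"])  # http://127.0.0.1:9997/
```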
@@ -1,23 +1,99 @@
-from chatchat.configs import MODEL_PLATFORMS
from omegaconf import DictConfig, OmegaConf

from model_providers.bootstrap_web.openai_bootstrap_web import (
    RESTFulOpenAIBootstrapBaseWeb,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.model_manager import ModelManager


-def _to_custom_provide_configuration():
def _to_custom_provide_configuration(cfg: DictConfig):
    """
    ```
    openai:
      model_credential:
        - model: 'gpt-3.5-turbo'
          model_credentials:
            openai_api_key: ''
            openai_organization: ''
            openai_api_base: ''
        - model: 'gpt-4'
          model_credentials:
            openai_api_key: ''
            openai_organization: ''
            openai_api_base: ''

      provider_credential:
        openai_api_key: ''
        openai_organization: ''
        openai_api_base: ''

    ```
    :param cfg: the loaded model providers configuration
    :return:
    """
    provider_name_to_provider_records_dict = {}
    provider_name_to_provider_model_records_dict = {}

    for key, item in cfg.items():
        model_credential = item.get("model_credential")
        provider_credential = item.get("provider_credential")
        # Convert OmegaConf nodes into plain Python containers
        if model_credential:
            model_credential = OmegaConf.to_container(model_credential)
            provider_name_to_provider_model_records_dict[key] = model_credential
        if provider_credential:
            provider_credential = OmegaConf.to_container(provider_credential)
            provider_name_to_provider_records_dict[key] = provider_credential

    return (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    )


-# Model manager created directly from configuration records
-provider_manager = ModelManager(
-    provider_name_to_provider_records_dict={
-        "openai": {
-            "openai_api_key": "sk-4M9LYF",
-        }
-    },
-    provider_name_to_provider_model_records_dict={},
-)
class BootstrapWebBuilder:
    """
    Builder for creating model web-service instances.
    """

    _model_providers_cfg_path: str
    _host: str
    _port: int

    def model_providers_cfg_path(self, model_providers_cfg_path: str):
        self._model_providers_cfg_path = model_providers_cfg_path
        return self

    def host(self, host: str):
        self._host = host
        return self

    def port(self, port: int):
        self._port = port
        return self

    def build(self) -> OpenAIBootstrapBaseWeb:
        assert (
            self._model_providers_cfg_path is not None
            and self._host is not None
            and self._port is not None
        )
        # Load the configuration file
        cfg = OmegaConf.load(self._model_providers_cfg_path)
        # Convert the configuration
        (
            provider_name_to_provider_records_dict,
            provider_name_to_provider_model_records_dict,
        ) = _to_custom_provide_configuration(cfg)
        # Create the model manager
        provider_manager = ModelManager(
            provider_name_to_provider_records_dict,
            provider_name_to_provider_model_records_dict,
        )
        # Create the web service
        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
            cfg={"host": self._host, "port": self._port}
        )
        restful.provider_manager = provider_manager
        return restful
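A hedged usage sketch of the builder above; the host/port values are illustrative, and anything beyond `build()` (actually serving requests) is outside this diff:

```python
# Illustrative only: wire the YAML shown earlier into a web service instance.
boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path("model-providers/model_providers.yaml")
    .host("127.0.0.1")
    .port(20000)  # illustrative port
    .build()
)
# boot is a RESTFulOpenAIBootstrapBaseWeb with its provider_manager attached
```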
@@ -1,58 +0,0 @@
import os
from typing import Generator, cast

from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType

if __name__ == "__main__":
    # Model manager created directly from configuration records
    provider_manager = ModelManager(
        provider_name_to_provider_records_dict={
            "openai": {
                "openai_api_key": "sk-4M9LYF",
            }
        },
        provider_name_to_provider_model_records_dict={},
    )

    # Invoke model
    model_instance = provider_manager.get_model_instance(
        provider="openai", model_type=ModelType.LLM, model="gpt-4"
    )

    response = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        assert (
            len(chunk.delta.message.content) > 0
            if not chunk.delta.finish_reason
            else True
        )
    print(total_message)
    assert "参考资料" in total_message
27  model-providers/model_providers/bootstrap_web/common.py  Normal file
@@ -0,0 +1,27 @@
import typing
from subprocess import Popen
from typing import Optional

from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
)
from model_providers.core.utils.generic import jsonify

if typing.TYPE_CHECKING:
    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage


def create_stream_chunk(
    request_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional[Finish] = None,
) -> str:
    choice = ChatCompletionStreamResponseChoice(
        index=index, delta=delta, finish_reason=finish_reason
    )
    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
    return jsonify(chunk)
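A small sketch of what one call to this helper produces; the id and content are illustrative, and the exact JSON shape depends on `ChatCompletionStreamResponse`:

```python
from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage, Role

# One JSON-serialized chunk, suitable for a single SSE `data:` line
payload = create_stream_chunk(
    request_id="chatcmpl-123",  # illustrative id
    model="gpt-3.5-turbo",
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
)
```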
@@ -0,0 +1,183 @@
from enum import Enum
from typing import List, Literal, Optional

from pydantic import BaseModel

from model_providers.core.entities.model_entities import (
    ModelStatus,
    ModelWithProviderEntity,
)
from model_providers.core.entities.provider_entities import (
    ProviderQuotaType,
    ProviderType,
    QuotaConfiguration,
)
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
    ModelType,
    ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    ModelCredentialSchema,
    ProviderCredentialSchema,
    ProviderHelpEntity,
    SimpleProviderEntity,
)


class CustomConfigurationStatus(Enum):
    """
    Enum class for custom configuration status.
    """

    ACTIVE = "active"
    NO_CONFIGURE = "no-configure"


class CustomConfigurationResponse(BaseModel):
    """
    Model class for provider custom configuration response.
    """

    status: CustomConfigurationStatus


class SystemConfigurationResponse(BaseModel):
    """
    Model class for provider system configuration response.
    """

    enabled: bool
    current_quota_type: Optional[ProviderQuotaType] = None
    quota_configurations: list[QuotaConfiguration] = []


class ProviderResponse(BaseModel):
    """
    Model class for provider response.
    """

    provider: str
    label: I18nObject
    description: Optional[I18nObject] = None
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    background: Optional[str] = None
    help: Optional[ProviderHelpEntity] = None
    supported_model_types: list[ModelType]
    configurate_methods: list[ConfigurateMethod]
    provider_credential_schema: Optional[ProviderCredentialSchema] = None
    model_credential_schema: Optional[ModelCredentialSchema] = None
    preferred_provider_type: ProviderType
    custom_configuration: CustomConfigurationResponse
    system_configuration: SystemConfigurationResponse

    def __init__(self, **data) -> None:
        super().__init__(**data)
        #
        # url_prefix = (current_app.config.get("CONSOLE_API_URL")
        #               + f"/console/api/workspaces/current/model-providers/{self.provider}")
        # if self.icon_small is not None:
        #     self.icon_small = I18nObject(
        #         en_US=f"{url_prefix}/icon_small/en_US",
        #         zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
        #     )
        #
        # if self.icon_large is not None:
        #     self.icon_large = I18nObject(
        #         en_US=f"{url_prefix}/icon_large/en_US",
        #         zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
        #     )


class ProviderListResponse(BaseModel):
    object: Literal["list"] = "list"
    data: List[ProviderResponse] = []


class ModelResponse(ProviderModel):
    """
    Model class for model response.
    """

    status: ModelStatus


class ProviderWithModelsResponse(BaseModel):
    """
    Model class for provider with models response.
    """

    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    status: CustomConfigurationStatus
    models: list[ModelResponse]

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # (same commented-out icon URL rewriting as in ProviderResponse.__init__)


class ProviderModelTypeResponse(BaseModel):
    object: Literal["list"] = "list"
    data: List[ProviderWithModelsResponse] = []


class SimpleProviderEntityResponse(SimpleProviderEntity):
    """
    Simple provider entity response.
    """

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # (same commented-out icon URL rewriting as in ProviderResponse.__init__)


class DefaultModelResponse(BaseModel):
    """
    Default model entity.
    """

    model: str
    model_type: ModelType
    provider: SimpleProviderEntityResponse


class ModelWithProviderEntityResponse(ModelWithProviderEntity):
    """
    Model with provider entity.
    """

    provider: SimpleProviderEntityResponse

    def __init__(self, model: ModelWithProviderEntity) -> None:
        super().__init__(**model.dict())
@@ -5,60 +5,272 @@ import multiprocessing as mp
import os
import pprint
import threading
-from typing import Any, Dict, Optional
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

import tiktoken
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server

from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
    ProviderListResponse,
    ProviderModelTypeResponse,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    EmbeddingsRequest,
    EmbeddingsResponse,
    Finish,
    FunctionAvailable,
    ModelCard,
    ModelList,
    Role,
    UsageInfo,
)
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
-from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.model_providers import model_provider_factory
-from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
-    LargeLanguageModel,
-)
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
from model_providers.core.utils.generic import dictify, jsonify

logger = logging.getLogger(__name__)

MessageLike = Union[ChatMessage, PromptMessage]

MessageLikeRepresentation = Union[
    MessageLike,
    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
    str,
]


-async def create_stream_chat_completion(
-    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
-):
-    try:
-        response = model_type_instance.invoke(
-            model=chat_request.model,
-            credentials={
-                "openai_api_key": "sk-",
-                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-            },
-            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-            model_parameters={**chat_request.to_model_parameters_dict()},
-            stop=chat_request.stop,
-            stream=chat_request.stream,
-            user="abc-123",
-        )
-        return response
-    except Exception as e:
-        logger.exception(e)
-        raise HTTPException(status_code=500, detail=str(e))


def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
    """
    Convert PromptMessage to dict for OpenAI Compatibility API
    """
    if isinstance(message, UserPromptMessage):
        message = cast(UserPromptMessage, message)
        if isinstance(message.content, str):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError("User message content must be str")
    elif isinstance(message, AssistantPromptMessage):
        message = cast(AssistantPromptMessage, message)
        message_dict = {"role": "assistant", "content": message.content}
        if message.tool_calls and len(message.tool_calls) > 0:
            message_dict["function_call"] = {
                "name": message.tool_calls[0].function.name,
                "arguments": message.tool_calls[0].function.arguments,
            }
    elif isinstance(message, SystemPromptMessage):
        message = cast(SystemPromptMessage, message)
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, ToolPromptMessage):
        # check if last message is user message
        message = cast(ToolPromptMessage, message)
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Unknown message type {type(message)}")

    return message_dict


def _create_template_from_message_type(
    message_type: str, template: Union[str, list]
) -> PromptMessage:
    """Create a message prompt template from a message type and template string.

    Args:
        message_type: str the type of the message template (e.g., "human", "ai", etc.)
        template: str the template string.

    Returns:
        a message prompt template of the appropriate type.
    """
    if isinstance(template, str):
        content = template
    elif isinstance(template, list):
        content = []
        for tmpl in template:
            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
                if isinstance(tmpl, str):
                    text: str = tmpl
                else:
                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment] # noqa: E501
                content.append(TextPromptMessageContent(data=text))
            elif isinstance(tmpl, dict) and "image_url" in tmpl:
                img_template = cast(dict, tmpl)["image_url"]
                if isinstance(img_template, str):
                    img_template_obj = ImagePromptMessageContent(data=img_template)
                elif isinstance(img_template, dict):
                    img_template = dict(img_template)
                    if "url" in img_template:
                        url = img_template["url"]
                    else:
                        url = None
                    img_template_obj = ImagePromptMessageContent(data=url)
                else:
                    raise ValueError()
                content.append(img_template_obj)
            else:
                raise ValueError()
    else:
        raise ValueError()

    if message_type in ("human", "user"):
        _message = UserPromptMessage(content=content)
    elif message_type in ("ai", "assistant"):
        _message = AssistantPromptMessage(content=content)
    elif message_type == "system":
        _message = SystemPromptMessage(content=content)
    elif message_type in ("function", "tool"):
        _message = ToolPromptMessage(content=content)
    else:
        raise ValueError(
            f"Unexpected message type: {message_type}. Use one of 'human',"
            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
        )
    return _message


def _convert_to_message(
    message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
    """Instantiate a message from a variety of message formats.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - 2-tuple of (message class, template)
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats

    Returns:
        an instance of a message or a message template
    """
    if isinstance(message, ChatMessage):
        _message = _create_template_from_message_type(
            message.role.to_origin_role(), message.content
        )
    elif isinstance(message, PromptMessage):
        _message = message
    elif isinstance(message, str):
        _message = _create_template_from_message_type("human", message)
    elif isinstance(message, tuple):
        if len(message) != 2:
            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
        message_type_str, template = message
        if isinstance(message_type_str, str):
            _message = _create_template_from_message_type(message_type_str, template)
        else:
            raise ValueError(f"Expected message type string, got {message_type_str}")
    else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")

    return _message


async def _stream_openai_chat_completion(
    response: Generator,
) -> AsyncGenerator[str, None]:
    request_id, model = None, None
    for chunk in response:
        if not isinstance(chunk, LLMResultChunk):
            yield "[ERROR]"
            return

        if model is None:
            model = chunk.model
        if request_id is None:
            request_id = "request_id"
            yield create_stream_chunk(
                request_id,
                model,
                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
            )

        new_token = chunk.delta.message.content

        if new_token:
            delta = ChatCompletionMessage(
                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
                content=new_token,
                tool_calls=chunk.delta.message.tool_calls,
            )
            yield create_stream_chunk(
                request_id=request_id,
                model=model,
                delta=delta,
                index=chunk.delta.index,
                finish_reason=chunk.delta.finish_reason,
            )

    yield create_stream_chunk(
        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"


async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatCompletionMessage(
            **_convert_prompt_message_to_dict(message=response.message)
        ),
        finish_reason=Finish.STOP,
    )
    usage = UsageInfo(
        prompt_tokens=response.usage.prompt_tokens,
        completion_tokens=response.usage.completion_tokens,
        total_tokens=response.usage.total_tokens,
    )
    return ChatCompletionResponse(
        id="request_id",
        model=response.model,
        choices=[choice],
        usage=usage,
    )


class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
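To make the conversion helpers above concrete, a few illustrative calls (all normalize to `PromptMessage` instances):

```python
msgs = [
    _convert_to_message("Hello"),                          # -> UserPromptMessage
    _convert_to_message(("system", "You are terse.")),     # -> SystemPromptMessage
    _convert_to_message(UserPromptMessage(content="Hi")),  # passed through unchanged
]
```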
@@ -94,21 +306,33 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        )

        self._router.add_api_route(
-            "/v1/models",
            "/workspaces/current/model-providers",
            self.workspaces_model_providers,
            response_model=ProviderListResponse,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/workspaces/current/models/model-types/{model_type}",
            self.workspaces_model_types,
            response_model=ProviderModelTypeResponse,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/{provider}/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )

        self._router.add_api_route(
-            "/v1/embeddings",
            "/{provider}/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
-            "/v1/chat/completions",
            "/{provider}/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
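Because every inference route now carries a `{provider}` prefix, an OpenAI SDK client can target a single provider by pointing `base_url` at that prefix. A hedged sketch (host, port, and key are illustrative; the service is assumed to be running):

```python
from openai import OpenAI  # openai = "1.13.3" per the pyproject below

client = OpenAI(
    base_url="http://127.0.0.1:20000/xinference/v1",  # illustrative host/port
    api_key="EMPTY",  # placeholder; this diff does not show auth handling
)
resp = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)
```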
@@ -137,84 +361,111 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        if started_event is not None:
            started_event.set()

-    async def list_models(self, request: Request):
-        pass
    async def workspaces_model_providers(self, request: Request):
        provider_list = self.get_provider_list(model_type=request.get("model_type"))
        return ProviderListResponse(data=provider_list)

    async def workspaces_model_types(self, model_type: str, request: Request):
        models_by_model_type = self.get_models_by_model_type(model_type=model_type)
        return ProviderModelTypeResponse(data=models_by_model_type)

    async def list_models(self, provider: str, request: Request):
        logger.info(f"Received list_models request for provider: {provider}")
        # Enumerate every ModelType member and collect the provider's models
        llm_models: list[AIModelEntity] = []
        for model_type in ModelType.__members__.values():
            try:
                provider_model_bundle = (
                    self._provider_manager.provider_manager.get_provider_model_bundle(
                        provider=provider, model_type=model_type
                    )
                )
                llm_models.extend(
                    provider_model_bundle.model_type_instance.predefined_models()
                )
            except Exception as e:
                logger.error(
                    f"Error while fetching models for provider: {provider}, model_type: {model_type}"
                )
                logger.error(e)

        # Convert the list[AIModelEntity] into a List[ModelCard]
        models_list = [
            ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
            for model in llm_models
        ]

        return ModelList(data=models_list)

    async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        logger.info(
            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
        )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        client = ZhipuAI(api_key=authorization)
-        # Check whether embeddings_request.input is a list
-        input = None
-        if isinstance(embeddings_request.input, list):
-            tokens = embeddings_request.input
-            try:
-                encoding = tiktoken.encoding_for_model(embeddings_request.model)
-            except KeyError:
-                logger.warning("Warning: model not found. Using cl100k_base encoding.")
-                model = "cl100k_base"
-                encoding = tiktoken.get_encoding(model)
-            for i, token in enumerate(tokens):
-                text = encoding.decode(token)
-                input += text
-        else:
-            input = embeddings_request.input
-
-        response = client.embeddings.create(
-            model=embeddings_request.model,
-            input=input,
-        )
        response = None
        return EmbeddingsResponse(**dictify(response))

    async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        logger.info(
            f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
        )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        model_provider_factory.get_providers(provider_name="openai")
-        provider_instance = model_provider_factory.get_provider_instance("openai")
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-        if chat_request.stream:
-            generator = create_stream_chat_completion(model_type_instance, chat_request)
-            return EventSourceResponse(generator, media_type="text/event-stream")
-        else:
-            response = model_type_instance.invoke(
-                model="gpt-4",
-                credentials={
-                    "openai_api_key": "sk-",
-                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-                },
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={
-                    "temperature": 0.7,
-                    "top_p": 1.0,
-                    "top_k": 1,
-                    "plugin_web_search": True,
-                },
-                stop=["you"],
-                stream=False,
-            )
-            chat_response = ChatCompletionResponse(**dictify(response))
-            return chat_response
        model_instance = self._provider_manager.get_model_instance(
            provider=provider, model_type=ModelType.LLM, model=chat_request.model
        )
        prompt_messages = [
            _convert_to_message(message) for message in chat_request.messages
        ]

        tools = []
        if chat_request.tools:
            tools = [
                PromptMessageTool(
                    name=f.function.name,
                    description=f.function.description,
                    parameters=f.function.parameters,
                )
                for f in chat_request.tools
            ]
        if chat_request.functions:
            tools.extend(
                [
                    PromptMessageTool(
                        name=f.name, description=f.description, parameters=f.parameters
                    )
                    for f in chat_request.functions
                ]
            )

        try:
            response = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters={**chat_request.to_model_parameters_dict()},
                tools=tools,
                stop=chat_request.stop,
                stream=chat_request.stream,
                user="abc-123",
            )

            if chat_request.stream:
                return EventSourceResponse(
                    _stream_openai_chat_completion(response),
                    media_type="text/event-stream",
                )
            else:
                return await _openai_chat_completion(response)
        except ValueError as e:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
        except InvokeError as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
            )


def run(
@@ -224,10 +475,6 @@ def run(
):
    logging.config.dictConfig(logging_conf)  # type: ignore
-    try:
-        import signal
-
-        # Skip keyboard interrupts; rely on xoscar's signal handling
-        signal.signal(signal.SIGINT, lambda *_: None)
    api = RESTFulOpenAIBootstrapBaseWeb.from_config(
        cfg=cfg.get("run_openai_api", {})
    )
@@ -1,11 +1,28 @@
from abc import abstractmethod
from collections import deque
from typing import List, Optional

from fastapi import Request

from model_providers.bootstrap_web.entities.model_provider_entities import (
    CustomConfigurationResponse,
    CustomConfigurationStatus,
    ModelResponse,
    ProviderResponse,
    ProviderWithModelsResponse,
    SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionRequest,
    EmbeddingsRequest,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType


class Bootstrap:

    """Maximum number of ongoing tasks"""

    _MAX_ONGOING_TASKS: int = 1
@@ -13,9 +30,150 @@ class Bootstrap:
    """Task queue"""
    _QUEUE: deque = deque()

    _provider_manager: ModelManager

    def __init__(self):
        self._version = "v0.0.1"

    @property
    def provider_manager(self) -> ModelManager:
        return self._provider_manager

    @provider_manager.setter
    def provider_manager(self, provider_manager: ModelManager):
        self._provider_manager = provider_manager

    def get_provider_list(
        self, model_type: Optional[str] = None
    ) -> List[ProviderResponse]:
        """
        get provider list.

        :param model_type: model type
        :return:
        """
        # Merge the keys of the two credential dicts
        provider = set(
            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
        )
        provider.update(
            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
        provider_configurations = (
            self.provider_manager.provider_manager.get_configurations(provider=provider)
        )

        provider_responses = []
        for provider_configuration in provider_configurations.values():
            if model_type:
                model_type_entity = ModelType.value_of(model_type)
                if (
                    model_type_entity
                    not in provider_configuration.provider.supported_model_types
                ):
                    continue

            provider_response = ProviderResponse(
                provider=provider_configuration.provider.provider,
                label=provider_configuration.provider.label,
                description=provider_configuration.provider.description,
                icon_small=provider_configuration.provider.icon_small,
                icon_large=provider_configuration.provider.icon_large,
                background=provider_configuration.provider.background,
                help=provider_configuration.provider.help,
                supported_model_types=provider_configuration.provider.supported_model_types,
                configurate_methods=provider_configuration.provider.configurate_methods,
                provider_credential_schema=provider_configuration.provider.provider_credential_schema,
                model_credential_schema=provider_configuration.provider.model_credential_schema,
                preferred_provider_type=ProviderType.value_of("custom"),
                custom_configuration=CustomConfigurationResponse(
                    status=CustomConfigurationStatus.ACTIVE
                    if provider_configuration.is_custom_configuration_available()
                    else CustomConfigurationStatus.NO_CONFIGURE
                ),
                system_configuration=SystemConfigurationResponse(enabled=False),
            )

            provider_responses.append(provider_response)

        return provider_responses

    def get_models_by_model_type(
        self, model_type: str
    ) -> List[ProviderWithModelsResponse]:
        """
        get models by model type.

        :param model_type: model type
        :return:
        """
        # Merge the keys of the two credential dicts
        provider = set(
            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
        )
        provider.update(
            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
        provider_configurations = (
            self.provider_manager.provider_manager.get_configurations(provider=provider)
        )

        # Get provider available models
        models = provider_configurations.get_models(
            model_type=ModelType.value_of(model_type)
        )

        # Group models by provider
        provider_models = {}
        for model in models:
            if model.provider.provider not in provider_models:
                provider_models[model.provider.provider] = []

            if model.deprecated:
                continue

            provider_models[model.provider.provider].append(model)

        # convert to ProviderWithModelsResponse list
        providers_with_models: list[ProviderWithModelsResponse] = []
        for provider, models in provider_models.items():
            if not models:
                continue

            first_model = models[0]

            has_active_models = any(
                [model.status == ModelStatus.ACTIVE for model in models]
            )

            providers_with_models.append(
                ProviderWithModelsResponse(
                    provider=provider,
                    label=first_model.provider.label,
                    icon_small=first_model.provider.icon_small,
                    icon_large=first_model.provider.icon_large,
                    status=CustomConfigurationStatus.ACTIVE
                    if has_active_models
                    else CustomConfigurationStatus.NO_CONFIGURE,
                    models=[
                        ModelResponse(
                            model=model.model,
                            label=model.label,
                            model_type=model.model_type,
                            features=model.features,
                            fetch_from=model.fetch_from,
                            model_properties=model.model_properties,
                            status=model.status,
                        )
                        for model in models
                    ],
                )
            )

        return providers_with_models

    @classmethod
    @abstractmethod
    def from_config(cls, cfg=None):
@@ -43,17 +201,17 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
        super().__init__()

    @abstractmethod
-    async def list_models(self, request: Request):
    async def list_models(self, provider: str, request: Request):
        pass

    @abstractmethod
    async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        pass

    @abstractmethod
    async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        pass
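A hedged sketch of calling the two helpers above on a built bootstrap instance (`boot` as in the builder example earlier):

```python
providers = boot.get_provider_list(model_type="llm")
for p in providers:
    print(p.provider, p.custom_configuration.status)

grouped = boot.get_models_by_model_type("llm")
for entry in grouped:
    print(entry.provider, [m.model for m in entry.models])
```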
@@ -13,16 +13,74 @@ class Role(str, Enum):
    FUNCTION = "function"
    TOOL = "tool"

    @classmethod
    def value_of(cls, origin_role: str) -> "Role":
        if origin_role == "user":
            return cls.USER
        elif origin_role == "assistant":
            return cls.ASSISTANT
        elif origin_role == "system":
            return cls.SYSTEM
        elif origin_role == "function":
            return cls.FUNCTION
        elif origin_role == "tool":
            return cls.TOOL
        else:
            raise ValueError(f"invalid origin role {origin_role}")

    def to_origin_role(self) -> str:
        if self == self.USER:
            return "user"
        elif self == self.ASSISTANT:
            return "assistant"
        elif self == self.SYSTEM:
            return "system"
        elif self == self.FUNCTION:
            return "function"
        elif self == self.TOOL:
            return "tool"
        else:
            raise ValueError(f"invalid role {self}")


class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"

    @classmethod
    def value_of(cls, origin_finish: str) -> "Finish":
        if origin_finish == "stop":
            return cls.STOP
        elif origin_finish == "length":
            return cls.LENGTH
        elif origin_finish == "tool_calls":
            return cls.TOOL
        else:
            raise ValueError(f"invalid origin finish {origin_finish}")

    def to_origin_finish(self) -> str:
        if self == self.STOP:
            return "stop"
        elif self == self.LENGTH:
            return "length"
        elif self == self.TOOL:
            return "tool_calls"
        else:
            raise ValueError(f"invalid finish {self}")


class ModelCard(BaseModel):
    id: str
-    object: Literal["model"] = "model"
    object: Literal[
        "text-generation",
        "embeddings",
        "reranking",
        "speech2text",
        "moderation",
        "tts",
        "text2img",
    ] = "llm"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"

@@ -82,17 +140,17 @@ class ChatCompletionRequest(BaseModel):
    tools: Optional[List[FunctionAvailable]] = None
    functions: Optional[List[FunctionDefinition]] = None
    function_call: Optional[FunctionCallDefinition] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
    temperature: Optional[float] = 0.75
    top_p: Optional[float] = 0.75
    top_k: Optional[float] = None
    n: int = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[list[str]] = (None,)
    max_tokens: Optional[int] = 256
    stop: Optional[list[str]] = None
    stream: Optional[bool] = False

    def to_model_parameters_dict(self, *args, **kwargs):
        # Call the parent class's dict() and exclude the non-parameter fields
        helper.dump_model

        return super().dict(
            exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
        )
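For reference, what `to_model_parameters_dict` yields — assuming a populated request instance `req` (construction is not shown in this diff):

```python
params = req.to_model_parameters_dict()
# Excludes tools, messages, functions, and function_call; with the new defaults a
# bare request contributes temperature=0.75, top_p=0.75, max_tokens=256, stream=False.
```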
@@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
        else:
            return None

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.
        :return:
        """
        return (
            self.custom_configuration.provider is not None
            or len(self.custom_configuration.models) > 0
        )

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.
@@ -6,12 +6,82 @@ from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType


class ProviderType(Enum):
    CUSTOM = "custom"
    SYSTEM = "system"

    @staticmethod
    def value_of(value):
        for member in ProviderType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ProviderQuotaType(Enum):
    PAID = "paid"
    """hosted paid quota"""

    FREE = "free"
    """third-party free quota"""

    TRIAL = "trial"
    """hosted trial quota"""

    @staticmethod
    def value_of(value):
        for member in ProviderQuotaType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class QuotaUnit(Enum):
    TIMES = "times"
    TOKENS = "tokens"
    CREDITS = "credits"


class SystemConfigurationStatus(Enum):
    """
    Enum class for system configuration status.
    """

    ACTIVE = "active"
    QUOTA_EXCEEDED = "quota-exceeded"
    UNSUPPORTED = "unsupported"


class RestrictModel(BaseModel):
    model: str
    base_model_name: Optional[str] = None
    model_type: ModelType


class QuotaConfiguration(BaseModel):
    """
    Model class for provider quota configuration.
    """

    quota_type: ProviderQuotaType
    quota_unit: QuotaUnit
    quota_limit: int
    quota_used: int
    is_valid: bool
    restrict_models: list[RestrictModel] = []


class SystemConfiguration(BaseModel):
    """
    Model class for provider system configuration.
    """

    enabled: bool
    current_quota_type: Optional[ProviderQuotaType] = None
    quota_configurations: list[QuotaConfiguration] = []
    credentials: Optional[dict] = None


class CustomProviderConfiguration(BaseModel):
    """
    Model class for provider custom configuration.
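A tiny sketch of the `value_of` lookups above:

```python
assert ProviderType.value_of("custom") is ProviderType.CUSTOM
assert ProviderQuotaType.value_of("trial") is ProviderQuotaType.TRIAL
```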
@@ -245,6 +245,10 @@ class ModelManager:
            provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
        )

    @property
    def provider_manager(self) -> ProviderManager:
        return self._provider_manager

    def get_model_instance(
        self, provider: str, model_type: ModelType, model: str
    ) -> ModelInstance:
@@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
            return mode
        raise ValueError(f"invalid prompt message type value {value}")

    def to_origin_role(self) -> str:
        """
        Get origin role from prompt message role.

        :return: origin role
        """
        if self == self.SYSTEM:
            return "system"
        elif self == self.USER:
            return "user"
        elif self == self.ASSISTANT:
            return "assistant"
        elif self == self.TOOL:
            return "tool"
        else:
            raise ValueError(f"invalid role {self}")


class PromptMessageTool(BaseModel):
    """
@@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
-from typing import Any, Optional
from typing import Any, List, Optional

from pydantic import BaseModel
@@ -1,7 +1,7 @@
import importlib
import logging
import os
-from typing import Optional
from typing import Optional, Union

from pydantic import BaseModel
@@ -49,7 +49,9 @@ class ModelProviderFactory:
        if init_cache:
            self.get_providers()

-    def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
    def get_providers(
        self, provider_name: Union[str, set] = ""
    ) -> list[ProviderEntity]:
        """
        Get all providers
        :return: list of providers
@@ -60,20 +62,36 @@ class ModelProviderFactory:
        # traverse all model_provider_extensions
        providers = []
        for name, model_provider_extension in model_provider_extensions.items():
-            if provider_name in (name, ""):
-                # get model_provider instance
-                model_provider_instance = model_provider_extension.provider_instance
-
-                # get provider schema
-                provider_schema = model_provider_instance.get_provider_schema()
-
-                for model_type in provider_schema.supported_model_types:
-                    # get predefined models for given model type
-                    models = model_provider_instance.models(model_type)
-                    if models:
-                        provider_schema.models.extend(models)
-
-                providers.append(provider_schema)
            if isinstance(provider_name, str):
                if provider_name in (name, ""):
                    # get model_provider instance
                    model_provider_instance = model_provider_extension.provider_instance

                    # get provider schema
                    provider_schema = model_provider_instance.get_provider_schema()

                    for model_type in provider_schema.supported_model_types:
                        # get predefined models for given model type
                        models = model_provider_instance.models(model_type)
                        if models:
                            provider_schema.models.extend(models)

                    providers.append(provider_schema)
            elif isinstance(provider_name, set):
                if name in provider_name:
                    # get model_provider instance
                    model_provider_instance = model_provider_extension.provider_instance

                    # get provider schema
                    provider_schema = model_provider_instance.get_provider_schema()

                    for model_type in provider_schema.supported_model_types:
                        # get predefined models for given model type
                        models = model_provider_instance.models(model_type)
                        if models:
                            provider_schema.models.extend(models)

                    providers.append(provider_schema)

        # return providers
        return providers
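The two branches above differ only in the membership test; an editorial sketch (not part of the commit) of a condensed equivalent:

```python
for name, model_provider_extension in model_provider_extensions.items():
    matches = (
        name in provider_name
        if isinstance(provider_name, set)
        else provider_name in (name, "")
    )
    if not matches:
        continue
    model_provider_instance = model_provider_extension.provider_instance
    provider_schema = model_provider_instance.get_provider_schema()
    for model_type in provider_schema.supported_model_types:
        models = model_provider_instance.models(model_type)
        if models:
            provider_schema.models.extend(models)
    providers.append(provider_schema)
```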
@@ -0,0 +1 @@
- chatglm3-6b
@@ -0,0 +1,43 @@
model: chatglm3-6b
label:
  zh_Hans: chatglm3-6b
  en_US: chatglm3-6b
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.001'
  output: '0.002'
  unit: '0.001'
  currency: USD
@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
-from typing import Optional
from typing import Optional, Union

from sqlalchemy.exc import IntegrityError
@@ -45,7 +45,7 @@ class ProviderManager:
            provider_name_to_provider_model_records_dict
        )

-    def get_configurations(self, provider: str) -> ProviderConfigurations:
    def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
        """
        Get model provider configurations.
@@ -155,7 +155,7 @@ class ProviderManager:

        default_model = {}
        # Get provider configurations
-        provider_configurations = self.get_configurations()
        provider_configurations = self.get_configurations(provider="openai")

        # get available models from provider_configurations
        available_models = provider_configurations.get_models(
@@ -212,7 +212,7 @@ class ProviderManager:
        :return:
        """
        # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
        provider_credential_secret_variables = self._extract_variables(
            provider_entity.provider_credential_schema.credential_form_schemas
            if provider_entity.provider_credential_schema
            else []
@@ -229,7 +229,7 @@ class ProviderManager:
        )

        # Get provider model credential secret variables
-        model_credential_secret_variables = self._extract_secret_variables(
        model_credential_variables = self._extract_variables(
            provider_entity.model_credential_schema.credential_form_schemas
            if provider_entity.model_credential_schema
            else []
@@ -242,7 +242,7 @@ class ProviderManager:
            continue

        provider_model_credentials = {}
-        for variable in model_credential_secret_variables:
        for variable in model_credential_variables:
            if variable in provider_model_record.get("model_credentials"):
                try:
                    provider_model_credentials[
@@ -253,7 +253,7 @@ class ProviderManager:

        custom_model_configurations.append(
            CustomModelConfiguration(
-                model=provider_model_record.get("model_name"),
                model=provider_model_record.get("model"),
                model_type=ModelType.value_of(
                    provider_model_record.get("model_type")
                ),
@@ -265,18 +265,17 @@ class ProviderManager:
            provider=custom_provider_configuration, models=custom_model_configurations
        )

-    def _extract_secret_variables(
    def _extract_variables(
        self, credential_form_schemas: list[CredentialFormSchema]
    ) -> list[str]:
        """
-        Extract secret input form variables.
        Extract input form variables.

        :param credential_form_schemas:
        :return:
        """
-        secret_input_form_variables = []
        input_form_variables = []
        for credential_form_schema in credential_form_schemas:
            if credential_form_schema.type == FormType.SECRET_INPUT:
-                secret_input_form_variables.append(credential_form_schema.variable)
                input_form_variables.append(credential_form_schema.variable)

-        return secret_input_form_variables
        return input_form_variables
88  model-providers/model_providers/core/utils/utils.py  Normal file
@@ -0,0 +1,88 @@
import logging
import os
import time

logger = logging.getLogger(__name__)


class LoggerNameFilter(logging.Filter):
    def filter(self, record):
        # return record.name.startswith("loom_core") or record.name in "ERROR" or (
        #     record.name.startswith("uvicorn.error")
        #     and record.getMessage().startswith("Uvicorn running on")
        # )
        return True


def get_log_file(log_path: str, sub_dir: str):
    """
    sub_dir should contain a timestamp.
    """
    log_dir = os.path.join(log_path, sub_dir)
    # This should create a new directory each time, hence `exist_ok=False`
    os.makedirs(log_dir, exist_ok=False)
    return os.path.join(log_dir, "loom_core.log")


def get_config_dict(
    log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
    # For Windows, the path should be a raw string.
    log_file_path = (
        log_file_path.encode("unicode-escape").decode()
        if os.name == "nt"
        else log_file_path
    )
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {
                "format": (
                    "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
                )
            },
        },
        "filters": {
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict


def get_timestamp_ms():
    t = time.time()
    return int(round(t * 1000))
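A hedged sketch of wiring this into the standard logging machinery (paths and sizes illustrative), mirroring the `logging.config.dictConfig(logging_conf)` call in `run()` above:

```python
import logging.config

logging_conf = get_config_dict(
    log_level="info",
    log_file_path=get_log_file(log_path="logs", sub_dir=f"run_{get_timestamp_ms()}"),
    log_backup_count=3,              # illustrative
    log_max_bytes=10 * 1024 * 1024,  # illustrative: 10 MiB
)
logging.config.dictConfig(logging_conf)
logging.getLogger("loom_core").info("logging configured")
```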
@@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
pyyaml = "6.0.1"
pydantic = "2.6.4"
redis = "4.5.4"
# config management
omegaconf = "2.0.6"
# model_runtime
openai = "1.13.3"
tiktoken = "0.5.2"
@@ -24,18 +26,14 @@ boto3 = "1.28.17"
# dependencies used for running tests (e.g., pytest, freezegun, responses).
# Any dependencies that do not meet these criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
@ -46,12 +44,6 @@ optional = true
ruff = "^0.1.5"


[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"


[tool.poetry.group.dev]
optional = true
@ -186,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
104
model-providers/tests/server_unit_test/conftest.py
Normal file
@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence

import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )
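The `requires` marker and the two command-line flags combine as in this sketch; the module and the `xinference` package name are hypothetical examples, not part of this diff:

```
import pytest


@pytest.mark.requires("xinference")
def test_import_xinference() -> None:
    # Skipped automatically when `xinference` is absent; fails under
    # --only-extended instead of skipping.
    import xinference  # noqa: F401


if __name__ == "__main__":
    # Equivalent to running `pytest --only-core` against this file.
    raise SystemExit(pytest.main(["--only-core", "-vv", __file__]))
```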
33
model-providers/tests/server_unit_test/test_init_server.py
Normal file
@ -0,0 +1,33 @@
import asyncio
import logging

import pytest

from model_providers import BootstrapWebBuilder

logger = logging.getLogger(__name__)


@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
    try:
        boot = (
            BootstrapWebBuilder()
            .model_providers_cfg_path(
                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
                "/model_providers.yaml"
            )
            .host(host="127.0.0.1")
            .port(port=20000)
            .build()
        )
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await boot.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
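Once `serve()` is up, the server speaks the OpenAI wire format via `RESTFulOpenAIBootstrapBaseWeb`. A client-side sketch against the host/port configured above; the `/v1` route prefix, the `api_key` value, and the model name are assumptions not shown in this diff:

```
from openai import OpenAI

# base_url prefix ("/v1") and api_key are assumed; the model name should
# match an entry configured in model_providers.yaml.
client = OpenAI(base_url="http://127.0.0.1:20000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)
```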
104
model-providers/tests/unit_test/conftest.py
Normal file
@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence

import pytest
from pytest import Config, Function, Parser

from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )


@pytest.fixture
def logging_conf() -> dict:
    return get_config_dict(
        "DEBUG",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        122,
        111,
    )
@ -0,0 +1,43 @@
import asyncio
import logging
import logging.config  # needed for logging.config.dictConfig below

import pytest
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_provider_manager_models(logging_conf: dict) -> None:
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Read the configuration file
    cfg = OmegaConf.load(
        "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
        "/model_providers.yaml"
    )
    # Convert the configuration
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Create the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="openai", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = (
        provider_model_bundle_emb.model_type_instance.predefined_models()
    )

    logger.info(f"predefined_models: {predefined_models}")
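Because the YAML path above is machine-specific, the same conversion step can be exercised against an inline OmegaConf config; a minimal sketch, with placeholder credential values:

```
from omegaconf import OmegaConf

from model_providers import _to_custom_provide_configuration

# Inline equivalent of a minimal model_providers.yaml provider entry.
cfg = OmegaConf.create(
    {
        "openai": {
            "provider_credential": {
                "openai_api_key": "sk-",
                "openai_organization": "",
                "openai_api_base": "",
            }
        }
    }
)
records, model_records = _to_custom_provide_configuration(cfg)
print(records, model_records)
```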