Merge pull request #3579 from chatchat-space/dev_model_providers

Completed BootstrapWebBuilder loading of user configuration, adapted standard payloads to use RESTFulOpenAIBootstrapBaseWeb for request handling, and provided an xinference plugin example
glide-the 2024-04-01 20:10:26 +08:00 committed by GitHub
commit 2526fa9062
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1469 additions and 227 deletions

View File

@ -57,18 +57,14 @@ optional = true
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
model-providers = { path = "../model-providers", develop = true }

View File

@ -0,0 +1,29 @@
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'

View File

@ -1,23 +1,99 @@
from chatchat.configs import MODEL_PLATFORMS
from omegaconf import DictConfig, OmegaConf
from model_providers.bootstrap_web.openai_bootstrap_web import (
RESTFulOpenAIBootstrapBaseWeb,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.model_manager import ModelManager
def _to_custom_provide_configuration():
def _to_custom_provide_configuration(cfg: DictConfig):
"""
```
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_credentials:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_credentials:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: ''
openai_organization: ''
openai_api_base: ''
```
:param cfg: OmegaConf mapping loaded from model_providers.yaml
:return:
"""
provider_name_to_provider_records_dict = {}
provider_name_to_provider_model_records_dict = {}
for key, item in cfg.items():
model_credential = item.get("model_credential")
provider_credential = item.get("provider_credential")
# Convert the OmegaConf objects to plain Python containers
if model_credential:
model_credential = OmegaConf.to_container(model_credential)
provider_name_to_provider_model_records_dict[key] = model_credential
if provider_credential:
provider_credential = OmegaConf.to_container(provider_credential)
provider_name_to_provider_records_dict[key] = provider_credential
return (
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
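For the YAML shown at the top of this PR, the converter splits each provider section into the two records the ModelManager expects. A minimal sketch, assuming model_providers.yaml sits in the working directory:

from omegaconf import OmegaConf

cfg = OmegaConf.load("model_providers.yaml")
providers, models = _to_custom_provide_configuration(cfg)
# providers -> {"openai": {"openai_api_key": "sk-", ...}}  (from provider_credential)
# models    -> {"openai": [...], "xinference": [...]}      (from model_credential)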
# A model instance created from the configuration manager
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
"openai": {
"openai_api_key": "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={},
)
class BootstrapWebBuilder:
"""
Builder that assembles a model web-service instance.
"""
_model_providers_cfg_path: str
_host: str
_port: int
def model_providers_cfg_path(self, model_providers_cfg_path: str):
self._model_providers_cfg_path = model_providers_cfg_path
return self
def host(self, host: str):
self._host = host
return self
def port(self, port: int):
self._port = port
return self
def build(self) -> OpenAIBootstrapBaseWeb:
assert (
self._model_providers_cfg_path is not None
and self._host is not None
and self._port is not None
)
# Load the configuration file
cfg = OmegaConf.load(self._model_providers_cfg_path)
# Convert the configuration
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# Create the model manager
provider_manager = ModelManager(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
)
# Create the web service
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
cfg={"host": self._host, "port": self._port}
)
restful.provider_manager = provider_manager
return restful
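For reference, a hypothetical invocation of the builder (host, port, and config path are placeholders; the unit test later in this PR does the same):

boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path(model_providers_cfg_path="model_providers.yaml")
    .host(host="127.0.0.1")
    .port(port=20000)
    .build()
)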

View File

@ -1,58 +0,0 @@
import os
from typing import Generator, cast
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == "__main__":
# A model instance created from the configuration manager
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
"openai": {
"openai_api_key": "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={},
)
# Invoke the model
model_instance = provider_manager.get_model_instance(
provider="openai", model_type=ModelType.LLM, model="gpt-4"
)
response = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=["you"],
stream=True,
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
total_message += chunk.delta.message.content
assert (
len(chunk.delta.message.content) > 0
if not chunk.delta.finish_reason
else True
)
print(total_message)
assert "参考资料" in total_message

View File

@ -0,0 +1,27 @@
import typing
from subprocess import Popen
from typing import Optional
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
)
from model_providers.core.utils.generic import jsonify
if typing.TYPE_CHECKING:
from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
def create_stream_chunk(
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
) -> str:
choice = ChatCompletionStreamResponseChoice(
index=index, delta=delta, finish_reason=finish_reason
)
chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
return jsonify(chunk)
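A sketch of producing one stream chunk with the helper above (values are illustrative; the returned JSON string is what EventSourceResponse emits as a single data: line):

from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage, Role

data = create_stream_chunk(
    request_id="req-1",
    model="gpt-3.5-turbo",
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
)
# `data` is a JSON string for one SSE event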

View File

@ -0,0 +1,183 @@
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
)
from model_providers.core.entities.provider_entities import (
ProviderQuotaType,
ProviderType,
QuotaConfiguration,
)
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
class CustomConfigurationStatus(Enum):
"""
Enum class for custom configuration status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
def __init__(self, **data) -> None:
super().__init__(**data)
#
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderListResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderResponse] = []
class ModelResponse(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
models: list[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderModelTypeResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderWithModelsResponse] = []
class SimpleProviderEntityResponse(SimpleProviderEntity):
"""
Simple provider entity response.
"""
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.dict())

View File

@ -5,60 +5,272 @@ import multiprocessing as mp
import os
import pprint
import threading
from typing import Any, Dict, Optional
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import tiktoken
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
ProviderListResponse,
ProviderModelTypeResponse,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
EmbeddingsRequest,
EmbeddingsResponse,
Finish,
FunctionAvailable,
ModelCard,
ModelList,
Role,
UsageInfo,
)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__)
MessageLike = Union[ChatMessage, PromptMessage]
async def create_stream_chat_completion(
model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
):
try:
response = model_type_instance.invoke(
model=chat_request.model,
credentials={
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={**chat_request.to_model_parameters_dict()},
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
MessageLikeRepresentation = Union[
MessageLike,
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
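For example, a plain user message round-trips as expected (a minimal check):

d = _convert_prompt_message_to_dict(UserPromptMessage(content="hi"))
assert d == {"role": "user", "content": "hi"}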
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if isinstance(template, str):
content = template
elif isinstance(template, list):
content = []
for tmpl in template:
if isinstance(tmpl, str) or (isinstance(tmpl, dict) and "text" in tmpl):
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501
content.append(TextPromptMessageContent(data=text))
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(dict, tmpl)["image_url"]
if isinstance(img_template, str):
img_template_obj = ImagePromptMessageContent(data=img_template)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
url = img_template["url"]
else:
url = None
img_template_obj = ImagePromptMessageContent(data=url)
else:
raise ValueError(f"Invalid image template: {img_template!r}")
content.append(img_template_obj)
else:
raise ValueError(f"Unsupported template item: {tmpl!r}")
else:
raise ValueError(f"Unsupported template type: {type(template)}")
if message_type in ("human", "user"):
_message = UserPromptMessage(content=content)
elif message_type in ("ai", "assistant"):
_message = AssistantPromptMessage(content=content)
elif message_type == "system":
_message = SystemPromptMessage(content=content)
elif message_type in ("function", "tool"):
_message = ToolPromptMessage(content=content)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human'/'user',"
f" 'ai'/'assistant', 'system', 'function', or 'tool'."
)
return response
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e))
return _message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> PromptMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, ChatMessage):
_message = _create_template_from_message_type(
message.role.to_origin_role(), message.content
)
elif isinstance(message, PromptMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
raise ValueError(f"Expected message type string, got {message_type_str}")
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
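A quick illustration of the accepted shorthands (both lines follow directly from the branches above):

assert isinstance(_convert_to_message("hello"), UserPromptMessage)
assert isinstance(_convert_to_message(("system", "be brief")), SystemPromptMessage)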
async def _stream_openai_chat_completion(
response: Generator,
) -> AsyncGenerator[str, None]:
request_id, model = None, None
for chunk in response:
if not isinstance(chunk, LLMResultChunk):
yield "[ERROR]"
return
if model is None:
model = chunk.model
if request_id is None:
request_id = "request_id"
yield create_stream_chunk(
request_id,
model,
ChatCompletionMessage(role=Role.ASSISTANT, content=""),
)
new_token = chunk.delta.message.content
if new_token:
delta = ChatCompletionMessage(
role=Role.value_of(chunk.delta.message.role.to_origin_role()),
content=new_token,
tool_calls=chunk.delta.message.tool_calls,
)
yield create_stream_chunk(
request_id=request_id,
model=model,
delta=delta,
index=chunk.delta.index,
finish_reason=chunk.delta.finish_reason,
)
yield create_stream_chunk(
request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
choice = ChatCompletionResponseChoice(
index=0,
message=ChatCompletionMessage(
**_convert_prompt_message_to_dict(message=response.message)
),
finish_reason=Finish.STOP,
)
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ChatCompletionResponse(
id="request_id",
model=response.model,
choices=[choice],
usage=usage,
)
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
@ -94,21 +306,33 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
self._router.add_api_route(
"/v1/models",
"/workspaces/current/model-providers",
self.workspaces_model_providers,
response_model=ProviderListResponse,
methods=["GET"],
)
self._router.add_api_route(
"/workspaces/current/models/model-types/{model_type}",
self.workspaces_model_types,
response_model=ProviderModelTypeResponse,
methods=["GET"],
)
self._router.add_api_route(
"/{provider}/v1/models",
self.list_models,
response_model=ModelList,
methods=["GET"],
)
self._router.add_api_route(
"/v1/embeddings",
"/{provider}/v1/embeddings",
self.create_embeddings,
response_model=EmbeddingsResponse,
status_code=status.HTTP_200_OK,
methods=["POST"],
)
self._router.add_api_route(
"/v1/chat/completions",
"/{provider}/v1/chat/completions",
self.create_chat_completion,
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
@ -137,84 +361,111 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
if started_event is not None:
started_event.set()
async def list_models(self, request: Request):
pass
async def workspaces_model_providers(self, request: Request):
provider_list = self.get_provider_list(model_type=request.query_params.get("model_type"))
return ProviderListResponse(data=provider_list)
async def workspaces_model_types(self, model_type: str, request: Request):
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
return ProviderModelTypeResponse(data=models_by_model_type)
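The two management endpoints registered above can then be queried like this (base URL and port are illustrative):

import requests

base = "http://127.0.0.1:20000"
providers = requests.get(f"{base}/workspaces/current/model-providers").json()
llm_models = requests.get(f"{base}/workspaces/current/models/model-types/llm").json()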
async def list_models(self, provider: str, request: Request):
logger.info(f"Received list_models request for provider: {provider}")
# Iterate over all ModelType enum members
llm_models: list[AIModelEntity] = []
for model_type in ModelType.__members__.values():
try:
provider_model_bundle = (
self._provider_manager.provider_manager.get_provider_model_bundle(
provider=provider, model_type=model_type
)
)
llm_models.extend(
provider_model_bundle.model_type_instance.predefined_models()
)
except Exception as e:
logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
)
logger.error(e)
# Convert the list[AIModelEntity] into a List[ModelCard]
models_list = [
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
for model in llm_models
]
return ModelList(data=models_list)
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
else:
authorization = os.environ["API_KEY"]
client = ZhipuAI(api_key=authorization)
# Check whether embeddings_request.input is a list of token arrays
input = ""
if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input
try:
encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens):
text = encoding.decode(token)
input += text
else:
input = embeddings_request.input
response = client.embeddings.create(
model=embeddings_request.model,
input=input,
)
response = None
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
else:
authorization = os.environ["API_KEY"]
model_provider_factory.get_providers(provider_name="openai")
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
if chat_request.stream:
generator = create_stream_chat_completion(model_type_instance, chat_request)
return EventSourceResponse(generator, media_type="text/event-stream")
else:
response = model_type_instance.invoke(
model="gpt-4",
credentials={
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=["you"],
stream=False,
model_instance = self._provider_manager.get_model_instance(
provider=provider, model_type=ModelType.LLM, model=chat_request.model
)
prompt_messages = [
_convert_to_message(message) for message in chat_request.messages
]
tools = []
if chat_request.tools:
tools = [
PromptMessageTool(
name=f.function.name,
description=f.function.description,
parameters=f.function.parameters,
)
for f in chat_request.tools
]
if chat_request.functions:
tools.extend(
[
PromptMessageTool(
name=f.name, description=f.description, parameters=f.parameters
)
for f in chat_request.functions
]
)
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={**chat_request.to_model_parameters_dict()},
tools=tools,
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
)
chat_response = ChatCompletionResponse(**dictify(response))
if chat_request.stream:
return EventSourceResponse(
_stream_openai_chat_completion(response),
media_type="text/event-stream",
)
else:
return await _openai_chat_completion(response)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
return chat_response
except InvokeError as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
def run(
@ -224,10 +475,6 @@ def run(
):
logging.config.dictConfig(logging_conf) # type: ignore
try:
import signal
# Skip keyboard interrupts; rely on xoscar's signal handling
signal.signal(signal.SIGINT, lambda *_: None)
api = RESTFulOpenAIBootstrapBaseWeb.from_config(
cfg=cfg.get("run_openai_api", {})
)

View File

@ -1,11 +1,28 @@
from abc import abstractmethod
from collections import deque
from typing import List, Optional
from fastapi import Request
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
class Bootstrap:
"""最大的任务队列"""
_MAX_ONGOING_TASKS: int = 1
@ -13,9 +30,150 @@ class Bootstrap:
"""任务队列"""
_QUEUE: deque = deque()
_provider_manager: ModelManager
def __init__(self):
self._version = "v0.0.1"
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# Merge the keys of the two credential dicts
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# Merge the keys of the two credential dicts
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
@classmethod
@abstractmethod
def from_config(cls, cfg=None):
@ -43,17 +201,17 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
super().__init__()
@abstractmethod
async def list_models(self, request: Request):
async def list_models(self, provider: str, request: Request):
pass
@abstractmethod
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
pass
@abstractmethod
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
pass

View File

@ -13,16 +13,74 @@ class Role(str, Enum):
FUNCTION = "function"
TOOL = "tool"
@classmethod
def value_of(cls, origin_role: str) -> "Role":
if origin_role == "user":
return cls.USER
elif origin_role == "assistant":
return cls.ASSISTANT
elif origin_role == "system":
return cls.SYSTEM
elif origin_role == "function":
return cls.FUNCTION
elif origin_role == "tool":
return cls.TOOL
else:
raise ValueError(f"invalid origin role {origin_role}")
def to_origin_role(self) -> str:
if self == self.USER:
return "user"
elif self == self.ASSISTANT:
return "assistant"
elif self == self.SYSTEM:
return "system"
elif self == self.FUNCTION:
return "function"
elif self == self.TOOL:
return "tool"
else:
raise ValueError(f"invalid role {self}")
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
TOOL = "tool_calls"
@classmethod
def value_of(cls, origin_finish: str) -> "Finish":
if origin_finish == "stop":
return cls.STOP
elif origin_finish == "length":
return cls.LENGTH
elif origin_finish == "tool_calls":
return cls.TOOL
else:
raise ValueError(f"invalid origin finish {origin_finish}")
def to_origin_finish(self) -> str:
if self == self.STOP:
return "stop"
elif self == self.LENGTH:
return "length"
elif self == self.TOOL:
return "tool_calls"
else:
raise ValueError(f"invalid finish {self}")
class ModelCard(BaseModel):
id: str
object: Literal["model"] = "model"
object: Literal[
"text-generation",
"embeddings",
"reranking",
"speech2text",
"moderation",
"tts",
"text2img",
] = "llm"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
@ -82,17 +140,17 @@ class ChatCompletionRequest(BaseModel):
tools: Optional[List[FunctionAvailable]] = None
functions: Optional[List[FunctionDefinition]] = None
function_call: Optional[FunctionCallDefinition] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = 0.75
top_p: Optional[float] = 0.75
top_k: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[list[str]] = (None,)
max_tokens: Optional[int] = 256
stop: Optional[list[str]] = None
stream: Optional[bool] = False
def to_model_parameters_dict(self, *args, **kwargs):
# Call the parent dict() and exclude tool/message fields
return super().dict(
exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
)
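So for a default-constructed request, the remaining fields become the model parameters, roughly (a sketch; the exact key set depends on the full model definition):

params = chat_request.to_model_parameters_dict()
# e.g. {"model": "...", "temperature": 0.75, "top_p": 0.75, "top_k": None,
#       "n": 1, "max_tokens": 256, "stop": None, "stream": False, ...}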

View File

@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
else:
return None
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (
self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0
)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Get custom credentials.

View File

@ -6,12 +6,82 @@ from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
@staticmethod
def value_of(value):
for member in ProviderType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.

View File

@ -245,6 +245,10 @@ class ModelManager:
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
@property
def provider_manager(self) -> ProviderManager:
return self._provider_manager
def get_model_instance(
self, provider: str, model_type: ModelType, model: str
) -> ModelInstance:

View File

@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
return mode
raise ValueError(f"invalid prompt message type value {value}")
def to_origin_role(self) -> str:
"""
Get origin role from prompt message role.
:return: origin role
"""
if self == self.SYSTEM:
return "system"
elif self == self.USER:
return "user"
elif self == self.ASSISTANT:
return "assistant"
elif self == self.TOOL:
return "tool"
else:
raise ValueError(f"invalid role {self}")
class PromptMessageTool(BaseModel):
"""

View File

@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
from typing import Any, Optional
from typing import Any, List, Optional
from pydantic import BaseModel

View File

@ -1,7 +1,7 @@
import importlib
import logging
import os
from typing import Optional
from typing import Optional, Union
from pydantic import BaseModel
@ -49,7 +49,9 @@ class ModelProviderFactory:
if init_cache:
self.get_providers()
def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
def get_providers(
self, provider_name: Union[str, set] = ""
) -> list[ProviderEntity]:
"""
Get all providers
:return: list of providers
@ -60,20 +62,36 @@ class ModelProviderFactory:
# traverse all model_provider_extensions
providers = []
for name, model_provider_extension in model_provider_extensions.items():
if provider_name in (name, ""):
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
if isinstance(provider_name, str):
if provider_name in (name, ""):
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
providers.append(provider_schema)
providers.append(provider_schema)
elif isinstance(provider_name, set):
if name in provider_name:
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
providers.append(provider_schema)
# return providers
return providers
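Both call styles are now supported (a sketch; the module-level factory singleton is assumed):

all_providers = model_provider_factory.get_providers()
subset = model_provider_factory.get_providers(provider_name={"openai", "xinference"})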

View File

@ -0,0 +1,43 @@
model: chatglm3-6b
label:
zh_Hans: chatglm3-6b
en_US: chatglm3-6b
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.002'
unit: '0.001'
currency: USD
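This predefined model is resolved through the manager once the xinference section of model_providers.yaml supplies server_url and model_uid (a sketch; provider_manager is a ModelManager built from that config):

from model_providers.core.model_runtime.entities.model_entities import ModelType

model_instance = provider_manager.get_model_instance(
    provider="xinference", model_type=ModelType.LLM, model="chatglm3-6b"
)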

View File

@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from typing import Optional, Union
from sqlalchemy.exc import IntegrityError
@ -45,7 +45,7 @@ class ProviderManager:
provider_name_to_provider_model_records_dict
)
def get_configurations(self, provider: str) -> ProviderConfigurations:
def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
"""
Get model provider configurations.
@ -155,7 +155,7 @@ class ProviderManager:
default_model = {}
# Get provider configurations
provider_configurations = self.get_configurations()
provider_configurations = self.get_configurations(provider="openai")
# get available models from provider_configurations
available_models = provider_configurations.get_models(
@ -212,7 +212,7 @@ class ProviderManager:
:return:
"""
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_credential_secret_variables = self._extract_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema
else []
@ -229,7 +229,7 @@ class ProviderManager:
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
model_credential_variables = self._extract_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
@ -242,7 +242,7 @@ class ProviderManager:
continue
provider_model_credentials = {}
for variable in model_credential_secret_variables:
for variable in model_credential_variables:
if variable in provider_model_record.get("model_credentials"):
try:
provider_model_credentials[
@ -253,7 +253,7 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.get("model_name"),
model=provider_model_record.get("model"),
model_type=ModelType.value_of(
provider_model_record.get("model_type")
),
@ -265,18 +265,17 @@ class ProviderManager:
provider=custom_provider_configuration, models=custom_model_configurations
)
def _extract_secret_variables(
def _extract_variables(
self, credential_form_schemas: list[CredentialFormSchema]
) -> list[str]:
"""
Extract secret input form variables.
Extract input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
return input_form_variables

View File

@ -0,0 +1,88 @@
import logging
import os
import time
logger = logging.getLogger(__name__)
class LoggerNameFilter(logging.Filter):
def filter(self, record):
# return record.name.startswith("loom_core") or record.name in "ERROR" or (
# record.name.startswith("uvicorn.error")
# and record.getMessage().startswith("Uvicorn running on")
# )
return True
def get_log_file(log_path: str, sub_dir: str):
"""
sub_dir should contain a timestamp.
"""
log_dir = os.path.join(log_path, sub_dir)
# A new directory should be created on each run, hence `exist_ok=False`
os.makedirs(log_dir, exist_ok=False)
return os.path.join(log_dir, "loom_core.log")
def get_config_dict(
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
# On Windows, the path should be escaped to a raw-string form.
log_file_path = (
log_file_path.encode("unicode-escape").decode()
if os.name == "nt"
else log_file_path
)
log_level = log_level.upper()
config_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"formatter": {
"format": (
"%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
)
},
},
"filters": {
"logger_name_filter": {
"()": __name__ + ".LoggerNameFilter",
},
},
"handlers": {
"stream_handler": {
"class": "logging.StreamHandler",
"formatter": "formatter",
"level": log_level,
# "stream": "ext://sys.stdout",
# "filters": ["logger_name_filter"],
},
"file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "formatter",
"level": log_level,
"filename": log_file_path,
"mode": "a",
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"loggers": {
"loom_core": {
"handlers": ["stream_handler", "file_handler"],
"level": log_level,
"propagate": False,
}
},
"root": {
"level": log_level,
"handlers": ["stream_handler", "file_handler"],
},
}
return config_dict
def get_timestamp_ms():
t = time.time()
return int(round(t * 1000))

View File

@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
pyyaml = "6.0.1"
pydantic = "2.6.4"
redis = "4.5.4"
# config management
omegaconf = "2.0.6"
# model_runtime
openai = "1.13.3"
tiktoken = "0.5.2"
@ -24,18 +26,14 @@ boto3 = "1.28.17"
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pandas = "^2.0.0"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
@ -46,12 +44,6 @@ optional = true
ruff = "^0.1.5"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.dev]
optional = true
@ -186,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

View File

@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@ -0,0 +1,33 @@
import asyncio
import logging
import pytest
from model_providers import BootstrapWebBuilder
logger = logging.getLogger(__name__)
@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise
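Once the server above is running, the provider-scoped OpenAI routes answer under a provider prefix (a sketch; payload values are illustrative):

import requests

resp = requests.post(
    "http://127.0.0.1:20000/openai/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
print(resp.json())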

View File

@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@ -0,0 +1,43 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# Load the configuration file
cfg = OmegaConf.load(
"/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
# Convert the configuration
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# Create the model manager
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.LLM
)
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.TEXT_EMBEDDING
)
predefined_models = (
provider_model_bundle_emb.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")