Merge pull request #3579 from chatchat-space/dev_model_providers

Implemented BootstrapWebBuilder to load user configuration, adapted standard OpenAI-style messages through RESTFulOpenAIBootstrapBaseWeb to serve requests, and added an xinference plugin example
This commit is contained in:
glide-the 2024-04-01 20:10:26 +08:00 committed by GitHub
commit 2526fa9062
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1469 additions and 227 deletions
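
For context, the entry point introduced by this PR is exercised end to end like this (a minimal sketch based on the test added in this PR; the config path, host, and port are illustrative):

import asyncio

from model_providers import BootstrapWebBuilder

# Build and serve an OpenAI-compatible web app from a user config file
boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path(model_providers_cfg_path="model_providers.yaml")
    .host(host="127.0.0.1")
    .port(port=20000)
    .build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)  # a dict as produced by get_config_dict(); see utils.py below

async def pool_join_thread():
    await boot.join()

asyncio.run(pool_join_thread())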

View File

@@ -57,18 +57,14 @@ optional = true
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
-pytest-dotenv = "^0.5.2"
-duckdb-engine = "^0.7.0"
-pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"
-responses = "^0.22.0"
-pytest-asyncio = "^0.20.3"
-lark = "^1.1.5"
-pandas = "^2.0.0"
-pytest-mock = "^3.10.0"
-pytest-socket = "^0.6.0"
+pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+grandalf = "^0.8"
+pytest-profiling = "^1.7.0"
+responses = "^0.25.0"
 model-providers = { path = "../model-providers", develop = true }

View File

@@ -0,0 +1,29 @@
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
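
For orientation, the new _to_custom_provide_configuration helper (see its diff below) splits a file like this into two dicts keyed by provider name. A minimal sketch of what that yields for the config above (the path is illustrative):

from omegaconf import OmegaConf

from model_providers import _to_custom_provide_configuration

# Load the YAML above and split it into provider-level and model-level records
cfg = OmegaConf.load("model_providers.yaml")  # illustrative path
provider_records, provider_model_records = _to_custom_provide_configuration(cfg)
# provider_records        -> {"openai": {"openai_api_key": "sk-", ...}}
# provider_model_records  -> {"openai": [<gpt-3.5-turbo entry>, <gpt-4 entry>],
#                             "xinference": [{"model": "chatglm3-6b", ...}]}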

View File

@@ -1,23 +1,99 @@
-from chatchat.configs import MODEL_PLATFORMS
+from omegaconf import DictConfig, OmegaConf
+
+from model_providers.bootstrap_web.openai_bootstrap_web import (
+    RESTFulOpenAIBootstrapBaseWeb,
+)
+from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.model_manager import ModelManager


-def _to_custom_provide_configuration():
+def _to_custom_provide_configuration(cfg: DictConfig):
+    """
+    ```
+    openai:
+      model_credential:
+        - model: 'gpt-3.5-turbo'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+        - model: 'gpt-4'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+      provider_credential:
+        openai_api_key: ''
+        openai_organization: ''
+        openai_api_base: ''
+    ```
+    :param model_providers_cfg:
+    :return:
+    """
     provider_name_to_provider_records_dict = {}
     provider_name_to_provider_model_records_dict = {}
+    for key, item in cfg.items():
+        model_credential = item.get("model_credential")
+        provider_credential = item.get("provider_credential")
+        # Convert OmegaConf objects into plain Python containers
+        if model_credential:
+            model_credential = OmegaConf.to_container(model_credential)
+            provider_name_to_provider_model_records_dict[key] = model_credential
+        if provider_credential:
+            provider_credential = OmegaConf.to_container(provider_credential)
+            provider_name_to_provider_records_dict[key] = provider_credential
     return (
         provider_name_to_provider_records_dict,
         provider_name_to_provider_model_records_dict,
     )


-# Model instance created from the configuration manager
-provider_manager = ModelManager(
-    provider_name_to_provider_records_dict={
-        "openai": {
-            "openai_api_key": "sk-4M9LYF",
-        }
-    },
-    provider_name_to_provider_model_records_dict={},
-)
+class BootstrapWebBuilder:
+    """
+    Builder that assembles a model-serving web instance
+    """
+
+    _model_providers_cfg_path: str
+    _host: str
+    _port: int
+
+    def model_providers_cfg_path(self, model_providers_cfg_path: str):
+        self._model_providers_cfg_path = model_providers_cfg_path
+        return self
+
+    def host(self, host: str):
+        self._host = host
+        return self
+
+    def port(self, port: int):
+        self._port = port
+        return self
+
+    def build(self) -> OpenAIBootstrapBaseWeb:
+        assert (
+            self._model_providers_cfg_path is not None
+            and self._host is not None
+            and self._port is not None
+        )
+        # Load the configuration file
+        cfg = OmegaConf.load(self._model_providers_cfg_path)
+        # Convert the configuration
+        (
+            provider_name_to_provider_records_dict,
+            provider_name_to_provider_model_records_dict,
+        ) = _to_custom_provide_configuration(cfg)
+        # Create the model manager
+        provider_manager = ModelManager(
+            provider_name_to_provider_records_dict,
+            provider_name_to_provider_model_records_dict,
+        )
+        # Create the web service
+        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(
+            cfg={"host": self._host, "port": self._port}
+        )
+        restful.provider_manager = provider_manager
+        return restful

View File

@@ -1,58 +0,0 @@
import os
from typing import Generator, cast
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == "__main__":
# Model instance created from the configuration manager
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
"openai": {
"openai_api_key": "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={},
)
#
# Invoke model
model_instance = provider_manager.get_model_instance(
provider="openai", model_type=ModelType.LLM, model="gpt-4"
)
response = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=["you"],
stream=True,
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
total_message += chunk.delta.message.content
assert (
len(chunk.delta.message.content) > 0
if not chunk.delta.finish_reason
else True
)
print(total_message)
assert "参考资料" in total_message

View File

@@ -0,0 +1,27 @@
import typing
from subprocess import Popen
from typing import Optional
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
)
from model_providers.core.utils.generic import jsonify
if typing.TYPE_CHECKING:
from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
def create_stream_chunk(
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
) -> str:
choice = ChatCompletionStreamResponseChoice(
index=index, delta=delta, finish_reason=finish_reason
)
chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
return jsonify(chunk)
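
A brief usage sketch of this helper as the streaming path below calls it (the request id and model name are illustrative):

from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionMessage,
    Finish,
    Role,
)

# One SSE data payload carrying an assistant token
sse_data = create_stream_chunk(
    request_id="request_id",
    model="chatglm3-6b",
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
)
# Terminal chunk signalling the end of the stream
sse_done = create_stream_chunk(
    "request_id", "chatglm3-6b", ChatCompletionMessage(), finish_reason=Finish.STOP
)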

View File

@@ -0,0 +1,183 @@
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
)
from model_providers.core.entities.provider_entities import (
ProviderQuotaType,
ProviderType,
QuotaConfiguration,
)
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
class CustomConfigurationStatus(Enum):
"""
Enum class for custom configuration status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
def __init__(self, **data) -> None:
super().__init__(**data)
#
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderListResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderResponse] = []
class ModelResponse(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
models: list[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderModelTypeResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderWithModelsResponse] = []
class SimpleProviderEntityResponse(SimpleProviderEntity):
"""
Simple provider entity response.
"""
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.dict())

View File

@@ -5,60 +5,272 @@ import multiprocessing as mp
 import os
 import pprint
 import threading
-from typing import Any, Dict, Optional
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

-import tiktoken
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette import EventSourceResponse
 from uvicorn import Config, Server

+from model_providers.bootstrap_web.common import create_stream_chunk
+from model_providers.bootstrap_web.entities.model_provider_entities import (
+    ProviderListResponse,
+    ProviderModelTypeResponse,
+)
 from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseChoice,
     ChatCompletionStreamResponse,
+    ChatMessage,
     EmbeddingsRequest,
     EmbeddingsResponse,
+    Finish,
     FunctionAvailable,
+    ModelCard,
     ModelList,
+    Role,
+    UsageInfo,
+)
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
 )
 from model_providers.core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
-from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.model_providers import model_provider_factory
-from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
-    LargeLanguageModel,
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelType,
 )
+from model_providers.core.model_runtime.errors.invoke import InvokeError
 from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)

+MessageLike = Union[ChatMessage, PromptMessage]
+
+MessageLikeRepresentation = Union[
+    MessageLike,
+    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
+    str,
+]

-async def create_stream_chat_completion(
-    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
-):
-    try:
-        response = model_type_instance.invoke(
-            model=chat_request.model,
-            credentials={
-                "openai_api_key": "sk-",
-                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-            },
-            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-            model_parameters={**chat_request.to_model_parameters_dict()},
-            stop=chat_request.stop,
-            stream=chat_request.stream,
-            user="abc-123",
-        )
-        return response
-    except Exception as e:
-        logger.exception(e)
-        raise HTTPException(status_code=500, detail=str(e))
+def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    """
+    Convert PromptMessage to dict for OpenAI Compatibility API
+    """
+    if isinstance(message, UserPromptMessage):
+        message = cast(UserPromptMessage, message)
+        if isinstance(message.content, str):
+            message_dict = {"role": "user", "content": message.content}
+        else:
+            raise ValueError("User message content must be str")
+    elif isinstance(message, AssistantPromptMessage):
+        message = cast(AssistantPromptMessage, message)
+        message_dict = {"role": "assistant", "content": message.content}
+        if message.tool_calls and len(message.tool_calls) > 0:
+            message_dict["function_call"] = {
+                "name": message.tool_calls[0].function.name,
+                "arguments": message.tool_calls[0].function.arguments,
+            }
+    elif isinstance(message, SystemPromptMessage):
+        message = cast(SystemPromptMessage, message)
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolPromptMessage):
+        # check if last message is user message
+        message = cast(ToolPromptMessage, message)
+        message_dict = {"role": "function", "content": message.content}
+    else:
+        raise ValueError(f"Unknown message type {type(message)}")
+    return message_dict
+
+
+def _create_template_from_message_type(
+    message_type: str, template: Union[str, list]
+) -> PromptMessage:
+    """Create a message prompt template from a message type and template string.
+
+    Args:
+        message_type: str the type of the message template (e.g., "human", "ai", etc.)
+        template: str the template string.
+
+    Returns:
+        a message prompt template of the appropriate type.
+    """
+    if isinstance(template, str):
+        content = template
+    elif isinstance(template, list):
+        content = []
+        for tmpl in template:
+            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
+                if isinstance(tmpl, str):
+                    text: str = tmpl
+                else:
+                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment]  # noqa: E501
+                content.append(TextPromptMessageContent(data=text))
+            elif isinstance(tmpl, dict) and "image_url" in tmpl:
+                img_template = cast(dict, tmpl)["image_url"]
+                if isinstance(img_template, str):
+                    img_template_obj = ImagePromptMessageContent(data=img_template)
+                elif isinstance(img_template, dict):
+                    img_template = dict(img_template)
+                    if "url" in img_template:
+                        url = img_template["url"]
+                    else:
+                        url = None
+                    img_template_obj = ImagePromptMessageContent(data=url)
+                else:
+                    raise ValueError()
+                content.append(img_template_obj)
+            else:
+                raise ValueError()
+    else:
+        raise ValueError()
+    if message_type in ("human", "user"):
+        _message = UserPromptMessage(content=content)
+    elif message_type in ("ai", "assistant"):
+        _message = AssistantPromptMessage(content=content)
+    elif message_type == "system":
+        _message = SystemPromptMessage(content=content)
+    elif message_type in ("function", "tool"):
+        _message = ToolPromptMessage(content=content)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
+        )
+    return _message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> Union[PromptMessage]:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - 2-tuple of (message class, template)
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message or a message template
+    """
+    if isinstance(message, ChatMessage):
+        _message = _create_template_from_message_type(
+            message.role.to_origin_role(), message.content
+        )
+    elif isinstance(message, PromptMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_template_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        if isinstance(message_type_str, str):
+            _message = _create_template_from_message_type(message_type_str, template)
+        else:
+            raise ValueError(f"Expected message type string, got {message_type_str}")
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+    return _message
+
+
+async def _stream_openai_chat_completion(
+    response: Generator,
+) -> AsyncGenerator[str, None]:
+    request_id, model = None, None
+    for chunk in response:
+        if not isinstance(chunk, LLMResultChunk):
+            yield "[ERROR]"
+            return
+        if model is None:
+            model = chunk.model
+        if request_id is None:
+            request_id = "request_id"
+            yield create_stream_chunk(
+                request_id,
+                model,
+                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
+            )
+        new_token = chunk.delta.message.content
+        if new_token:
+            delta = ChatCompletionMessage(
+                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
+                content=new_token,
+                tool_calls=chunk.delta.message.tool_calls,
+            )
+            yield create_stream_chunk(
+                request_id=request_id,
+                model=model,
+                delta=delta,
+                index=chunk.delta.index,
+                finish_reason=chunk.delta.finish_reason,
+            )
+    yield create_stream_chunk(
+        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
+    )
+    yield "[DONE]"
+
+
+async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
+    choice = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatCompletionMessage(
+            **_convert_prompt_message_to_dict(message=response.message)
+        ),
+        finish_reason=Finish.STOP,
+    )
+    usage = UsageInfo(
+        prompt_tokens=response.usage.prompt_tokens,
+        completion_tokens=response.usage.completion_tokens,
+        total_tokens=response.usage.total_tokens,
+    )
+    return ChatCompletionResponse(
+        id="request_id",
+        model=response.model,
+        choices=[choice],
+        usage=usage,
+    )


 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):

@@ -94,21 +306,33 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         )
         self._router.add_api_route(
-            "/v1/models",
+            "/workspaces/current/model-providers",
+            self.workspaces_model_providers,
+            response_model=ProviderListResponse,
+            methods=["GET"],
+        )
+        self._router.add_api_route(
+            "/workspaces/current/models/model-types/{model_type}",
+            self.workspaces_model_types,
+            response_model=ProviderModelTypeResponse,
+            methods=["GET"],
+        )
+        self._router.add_api_route(
+            "/{provider}/v1/models",
             self.list_models,
             response_model=ModelList,
             methods=["GET"],
         )
         self._router.add_api_route(
-            "/v1/embeddings",
+            "/{provider}/v1/embeddings",
             self.create_embeddings,
             response_model=EmbeddingsResponse,
             status_code=status.HTTP_200_OK,
             methods=["POST"],
         )
         self._router.add_api_route(
-            "/v1/chat/completions",
+            "/{provider}/v1/chat/completions",
             self.create_chat_completion,
             response_model=ChatCompletionResponse,
             status_code=status.HTTP_200_OK,

@@ -137,84 +361,111 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         if started_event is not None:
             started_event.set()

-    async def list_models(self, request: Request):
-        pass
+    async def workspaces_model_providers(self, request: Request):
+        provider_list = self.get_provider_list(model_type=request.get("model_type"))
+        return ProviderListResponse(data=provider_list)
+
+    async def workspaces_model_types(self, model_type: str, request: Request):
+        models_by_model_type = self.get_models_by_model_type(model_type=model_type)
+        return ProviderModelTypeResponse(data=models_by_model_type)
+
+    async def list_models(self, provider: str, request: Request):
+        logger.info(f"Received list_models request for provider: {provider}")
+        # Iterate over every ModelType enum member
+        llm_models: list[AIModelEntity] = []
+        for model_type in ModelType.__members__.values():
+            try:
+                provider_model_bundle = (
+                    self._provider_manager.provider_manager.get_provider_model_bundle(
+                        provider=provider, model_type=model_type
+                    )
+                )
+                llm_models.extend(
+                    provider_model_bundle.model_type_instance.predefined_models()
+                )
+            except Exception as e:
+                logger.error(
+                    f"Error while fetching models for provider: {provider}, model_type: {model_type}"
+                )
+                logger.error(e)
+        # Convert list[AIModelEntity] into List[ModelCard]
+        models_list = [
+            ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
+            for model in llm_models
+        ]
+        return ModelList(data=models_list)

     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        client = ZhipuAI(api_key=authorization)
-        # Check whether embeddings_request.input is a list
-        input = None
-        if isinstance(embeddings_request.input, list):
-            tokens = embeddings_request.input
-            try:
-                encoding = tiktoken.encoding_for_model(embeddings_request.model)
-            except KeyError:
-                logger.warning("Warning: model not found. Using cl100k_base encoding.")
-                model = "cl100k_base"
-                encoding = tiktoken.get_encoding(model)
-            for i, token in enumerate(tokens):
-                text = encoding.decode(token)
-                input += text
-        else:
-            input = embeddings_request.input
-        response = client.embeddings.create(
-            model=embeddings_request.model,
-            input=input,
-        )
+        response = None
         return EmbeddingsResponse(**dictify(response))

     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        model_provider_factory.get_providers(provider_name="openai")
-        provider_instance = model_provider_factory.get_provider_instance("openai")
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-        if chat_request.stream:
-            generator = create_stream_chat_completion(model_type_instance, chat_request)
-            return EventSourceResponse(generator, media_type="text/event-stream")
-        else:
-            response = model_type_instance.invoke(
-                model="gpt-4",
-                credentials={
-                    "openai_api_key": "sk-",
-                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-                },
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={
-                    "temperature": 0.7,
-                    "top_p": 1.0,
-                    "top_k": 1,
-                    "plugin_web_search": True,
-                },
-                stop=["you"],
-                stream=False,
-                user="abc-123",
-            )
-            chat_response = ChatCompletionResponse(**dictify(response))
-            return chat_response
+        model_instance = self._provider_manager.get_model_instance(
+            provider=provider, model_type=ModelType.LLM, model=chat_request.model
+        )
+        prompt_messages = [
+            _convert_to_message(message) for message in chat_request.messages
+        ]
+        tools = []
+        if chat_request.tools:
+            tools = [
+                PromptMessageTool(
+                    name=f.function.name,
+                    description=f.function.description,
+                    parameters=f.function.parameters,
+                )
+                for f in chat_request.tools
+            ]
+        if chat_request.functions:
+            tools.extend(
+                [
+                    PromptMessageTool(
+                        name=f.name, description=f.description, parameters=f.parameters
+                    )
+                    for f in chat_request.functions
+                ]
+            )
+        try:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                tools=tools,
+                stop=chat_request.stop,
+                stream=chat_request.stream,
+                user="abc-123",
+            )
+            if chat_request.stream:
+                return EventSourceResponse(
+                    _stream_openai_chat_completion(response),
+                    media_type="text/event-stream",
+                )
+            else:
+                return await _openai_chat_completion(response)
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except InvokeError as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+            )

 def run(

@@ -224,10 +475,6 @@ def run(
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
-        import signal
-
-        # Skip keyboard interrupts; use xoscar's signal handling instead
-        signal.signal(signal.SIGINT, lambda *_: None)
         api = RESTFulOpenAIBootstrapBaseWeb.from_config(
             cfg=cfg.get("run_openai_api", {})
         )
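
With routes now prefixed per provider, any OpenAI-compatible client can target a single provider through its URL prefix. A hedged client-side sketch (host, port, and model are assumptions, and the router is assumed to be mounted at the application root):

from openai import OpenAI  # openai 1.13.3 is pinned in this PR's pyproject

client = OpenAI(
    base_url="http://127.0.0.1:20000/xinference/v1",  # provider prefix in the path
    api_key="EMPTY",  # this server does not validate keys
)
resp = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)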

View File

@ -1,11 +1,28 @@
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
from typing import List, Optional
from fastapi import Request from fastapi import Request
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
class Bootstrap: class Bootstrap:
"""最大的任务队列""" """最大的任务队列"""
_MAX_ONGOING_TASKS: int = 1 _MAX_ONGOING_TASKS: int = 1
@ -13,9 +30,150 @@ class Bootstrap:
"""任务队列""" """任务队列"""
_QUEUE: deque = deque() _QUEUE: deque = deque()
_provider_manager: ModelManager
def __init__(self): def __init__(self):
self._version = "v0.0.1" self._version = "v0.0.1"
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_config(cls, cfg=None): def from_config(cls, cfg=None):
@ -43,17 +201,17 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
super().__init__() super().__init__()
@abstractmethod @abstractmethod
async def list_models(self, request: Request): async def list_models(self, provider: str, request: Request):
pass pass
@abstractmethod @abstractmethod
async def create_embeddings( async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
): ):
pass pass
@abstractmethod @abstractmethod
async def create_chat_completion( async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest self, provider: str, request: Request, chat_request: ChatCompletionRequest
): ):
pass pass

View File

@@ -13,16 +13,74 @@ class Role(str, Enum):
     FUNCTION = "function"
     TOOL = "tool"

+    @classmethod
+    def value_of(cls, origin_role: str) -> "Role":
+        if origin_role == "user":
+            return cls.USER
+        elif origin_role == "assistant":
+            return cls.ASSISTANT
+        elif origin_role == "system":
+            return cls.SYSTEM
+        elif origin_role == "function":
+            return cls.FUNCTION
+        elif origin_role == "tool":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin role {origin_role}")
+
+    def to_origin_role(self) -> str:
+        if self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.SYSTEM:
+            return "system"
+        elif self == self.FUNCTION:
+            return "function"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
+

 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
     TOOL = "tool_calls"

+    @classmethod
+    def value_of(cls, origin_finish: str) -> "Finish":
+        if origin_finish == "stop":
+            return cls.STOP
+        elif origin_finish == "length":
+            return cls.LENGTH
+        elif origin_finish == "tool_calls":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin finish {origin_finish}")
+
+    def to_origin_finish(self) -> str:
+        if self == self.STOP:
+            return "stop"
+        elif self == self.LENGTH:
+            return "length"
+        elif self == self.TOOL:
+            return "tool_calls"
+        else:
+            raise ValueError(f"invalid finish {self}")
+

 class ModelCard(BaseModel):
     id: str
-    object: Literal["model"] = "model"
+    object: Literal[
+        "text-generation",
+        "embeddings",
+        "reranking",
+        "speech2text",
+        "moderation",
+        "tts",
+        "text2img",
+    ] = "llm"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: Literal["owner"] = "owner"

@@ -82,17 +140,17 @@ class ChatCompletionRequest(BaseModel):
     tools: Optional[List[FunctionAvailable]] = None
     functions: Optional[List[FunctionDefinition]] = None
     function_call: Optional[FunctionCallDefinition] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = 0.75
+    top_p: Optional[float] = 0.75
     top_k: Optional[float] = None
     n: int = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[list[str]] = (None,)
+    max_tokens: Optional[int] = 256
+    stop: Optional[list[str]] = None
     stream: Optional[bool] = False

     def to_model_parameters_dict(self, *args, **kwargs):
         # Call the parent dict() method and exclude the tools-related fields
-        helper.dump_model
         return super().dict(
             exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
         )
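
The new value_of/to_origin_* helpers give a lossless round trip between the protocol enums and their wire strings; a small sketch:

from model_providers.core.bootstrap.openai_protocol import Finish, Role

assert Role.value_of("assistant") is Role.ASSISTANT
assert Role.ASSISTANT.to_origin_role() == "assistant"
assert Finish.value_of("tool_calls") is Finish.TOOL
assert Finish.TOOL.to_origin_finish() == "tool_calls"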

View File

@@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
         else:
             return None

+    def is_custom_configuration_available(self) -> bool:
+        """
+        Check custom configuration available.
+
+        :return:
+        """
+        return (
+            self.custom_configuration.provider is not None
+            or len(self.custom_configuration.models) > 0
+        )
+
     def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
         """
         Get custom credentials.

View File

@@ -6,12 +6,82 @@ from pydantic import BaseModel
 from model_providers.core.model_runtime.entities.model_entities import ModelType

+class ProviderType(Enum):
+    CUSTOM = "custom"
+    SYSTEM = "system"
+
+    @staticmethod
+    def value_of(value):
+        for member in ProviderType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class ProviderQuotaType(Enum):
+    PAID = "paid"
+    """hosted paid quota"""
+
+    FREE = "free"
+    """third-party free quota"""
+
+    TRIAL = "trial"
+    """hosted trial quota"""
+
+    @staticmethod
+    def value_of(value):
+        for member in ProviderQuotaType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class QuotaUnit(Enum):
+    TIMES = "times"
+    TOKENS = "tokens"
+    CREDITS = "credits"
+
+
+class SystemConfigurationStatus(Enum):
+    """
+    Enum class for system configuration status.
+    """
+
+    ACTIVE = "active"
+    QUOTA_EXCEEDED = "quota-exceeded"
+    UNSUPPORTED = "unsupported"
+
+
 class RestrictModel(BaseModel):
     model: str
     base_model_name: Optional[str] = None
     model_type: ModelType

+class QuotaConfiguration(BaseModel):
+    """
+    Model class for provider quota configuration.
+    """
+
+    quota_type: ProviderQuotaType
+    quota_unit: QuotaUnit
+    quota_limit: int
+    quota_used: int
+    is_valid: bool
+    restrict_models: list[RestrictModel] = []
+
+
+class SystemConfiguration(BaseModel):
+    """
+    Model class for provider system configuration.
+    """
+
+    enabled: bool
+    current_quota_type: Optional[ProviderQuotaType] = None
+    quota_configurations: list[QuotaConfiguration] = []
+    credentials: Optional[dict] = None
+
+
 class CustomProviderConfiguration(BaseModel):
     """
     Model class for provider custom configuration.

View File

@@ -245,6 +245,10 @@ class ModelManager:
             provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
         )

+    @property
+    def provider_manager(self) -> ProviderManager:
+        return self._provider_manager
+
     def get_model_instance(
         self, provider: str, model_type: ModelType, model: str
     ) -> ModelInstance:

View File

@@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
                 return mode
         raise ValueError(f"invalid prompt message type value {value}")

+    def to_origin_role(self) -> str:
+        """
+        Get origin role from prompt message role.
+
+        :return: origin role
+        """
+        if self == self.SYSTEM:
+            return "system"
+        elif self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
+

 class PromptMessageTool(BaseModel):
     """

View File

@@ -1,6 +1,6 @@
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, List, Optional

 from pydantic import BaseModel

View File

@@ -1,7 +1,7 @@
 import importlib
 import logging
 import os
-from typing import Optional
+from typing import Optional, Union

 from pydantic import BaseModel

@@ -49,7 +49,9 @@ class ModelProviderFactory:
         if init_cache:
             self.get_providers()

-    def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
+    def get_providers(
+        self, provider_name: Union[str, set] = ""
+    ) -> list[ProviderEntity]:
         """
         Get all providers
         :return: list of providers

@@ -60,20 +62,36 @@ class ModelProviderFactory:
         # traverse all model_provider_extensions
         providers = []
         for name, model_provider_extension in model_provider_extensions.items():
-            if provider_name in (name, ""):
-                # get model_provider instance
-                model_provider_instance = model_provider_extension.provider_instance
+            if isinstance(provider_name, str):
+                if provider_name in (name, ""):
+                    # get model_provider instance
+                    model_provider_instance = model_provider_extension.provider_instance

-                # get provider schema
-                provider_schema = model_provider_instance.get_provider_schema()
+                    # get provider schema
+                    provider_schema = model_provider_instance.get_provider_schema()

-                for model_type in provider_schema.supported_model_types:
-                    # get predefined models for given model type
-                    models = model_provider_instance.models(model_type)
-                    if models:
-                        provider_schema.models.extend(models)
+                    for model_type in provider_schema.supported_model_types:
+                        # get predefined models for given model type
+                        models = model_provider_instance.models(model_type)
+                        if models:
+                            provider_schema.models.extend(models)

-                providers.append(provider_schema)
+                    providers.append(provider_schema)
+            elif isinstance(provider_name, set):
+                if name in provider_name:
+                    # get model_provider instance
+                    model_provider_instance = model_provider_extension.provider_instance
+
+                    # get provider schema
+                    provider_schema = model_provider_instance.get_provider_schema()
+
+                    for model_type in provider_schema.supported_model_types:
+                        # get predefined models for given model type
+                        models = model_provider_instance.models(model_type)
+                        if models:
+                            provider_schema.models.extend(models)
+
+                    providers.append(provider_schema)

         # return providers
         return providers

View File

@@ -0,0 +1,43 @@
model: chatglm3-6b
label:
zh_Hans: chatglm3-6b
en_US: chatglm3-6b
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.002'
unit: '0.001'
currency: USD

View File

@@ -1,7 +1,7 @@
 import json
 from collections import defaultdict
 from json import JSONDecodeError
-from typing import Optional
+from typing import Optional, Union

 from sqlalchemy.exc import IntegrityError

@@ -45,7 +45,7 @@ class ProviderManager:
             provider_name_to_provider_model_records_dict
         )

-    def get_configurations(self, provider: str) -> ProviderConfigurations:
+    def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
         """
         Get model provider configurations.

@@ -155,7 +155,7 @@ class ProviderManager:
         default_model = {}

         # Get provider configurations
-        provider_configurations = self.get_configurations()
+        provider_configurations = self.get_configurations(provider="openai")

         # get available models from provider_configurations
         available_models = provider_configurations.get_models(

@@ -212,7 +212,7 @@ class ProviderManager:
         :return:
         """
         # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
+        provider_credential_secret_variables = self._extract_variables(
             provider_entity.provider_credential_schema.credential_form_schemas
             if provider_entity.provider_credential_schema
             else []

@@ -229,7 +229,7 @@ class ProviderManager:
         )

         # Get provider model credential secret variables
-        model_credential_secret_variables = self._extract_secret_variables(
+        model_credential_variables = self._extract_variables(
             provider_entity.model_credential_schema.credential_form_schemas
             if provider_entity.model_credential_schema
             else []

@@ -242,7 +242,7 @@ class ProviderManager:
                 continue

             provider_model_credentials = {}
-            for variable in model_credential_secret_variables:
+            for variable in model_credential_variables:
                 if variable in provider_model_record.get("model_credentials"):
                     try:
                         provider_model_credentials[

@@ -253,7 +253,7 @@ class ProviderManager:
             custom_model_configurations.append(
                 CustomModelConfiguration(
-                    model=provider_model_record.get("model_name"),
+                    model=provider_model_record.get("model"),
                     model_type=ModelType.value_of(
                         provider_model_record.get("model_type")
                     ),

@@ -265,18 +265,17 @@ class ProviderManager:
             provider=custom_provider_configuration, models=custom_model_configurations
         )

-    def _extract_secret_variables(
+    def _extract_variables(
         self, credential_form_schemas: list[CredentialFormSchema]
     ) -> list[str]:
         """
-        Extract secret input form variables.
+        Extract input form variables.

         :param credential_form_schemas:
         :return:
         """
-        secret_input_form_variables = []
+        input_form_variables = []
         for credential_form_schema in credential_form_schemas:
-            if credential_form_schema.type == FormType.SECRET_INPUT:
-                secret_input_form_variables.append(credential_form_schema.variable)
+            input_form_variables.append(credential_form_schema.variable)

-        return secret_input_form_variables
+        return input_form_variables

View File

@@ -0,0 +1,88 @@
import logging
import os
import time
logger = logging.getLogger(__name__)
class LoggerNameFilter(logging.Filter):
def filter(self, record):
# return record.name.startswith("loom_core") or record.name in "ERROR" or (
# record.name.startswith("uvicorn.error")
# and record.getMessage().startswith("Uvicorn running on")
# )
return True
def get_log_file(log_path: str, sub_dir: str):
"""
sub_dir should contain a timestamp.
"""
log_dir = os.path.join(log_path, sub_dir)
# A new directory should be created on each run, hence `exist_ok=False`
os.makedirs(log_dir, exist_ok=False)
return os.path.join(log_dir, "loom_core.log")
def get_config_dict(
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
# for windows, the path should be a raw string.
log_file_path = (
log_file_path.encode("unicode-escape").decode()
if os.name == "nt"
else log_file_path
)
log_level = log_level.upper()
config_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"formatter": {
"format": (
"%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
)
},
},
"filters": {
"logger_name_filter": {
"()": __name__ + ".LoggerNameFilter",
},
},
"handlers": {
"stream_handler": {
"class": "logging.StreamHandler",
"formatter": "formatter",
"level": log_level,
# "stream": "ext://sys.stdout",
# "filters": ["logger_name_filter"],
},
"file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "formatter",
"level": log_level,
"filename": log_file_path,
"mode": "a",
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"loggers": {
"loom_core": {
"handlers": ["stream_handler", "file_handler"],
"level": log_level,
"propagate": False,
}
},
"root": {
"level": log_level,
"handlers": ["stream_handler", "file_handler"],
},
}
return config_dict
def get_timestamp_ms():
t = time.time()
return int(round(t * 1000))
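
These helpers are combined as in the new conftest fixtures; a brief sketch (log sizes are illustrative):

import logging.config

from model_providers.core.utils.utils import (
    get_config_dict,
    get_log_file,
    get_timestamp_ms,
)

# get_log_file creates a fresh timestamped directory (exist_ok=False)
logging_conf = get_config_dict(
    "DEBUG",
    get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
    log_backup_count=3,
    log_max_bytes=10 * 1024 * 1024,
)
logging.config.dictConfig(logging_conf)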

View File

@@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
 pyyaml = "6.0.1"
 pydantic = "2.6.4"
 redis = "4.5.4"
+# config manage
+omegaconf = "2.0.6"
 # modle_runtime
 openai = "1.13.3"
 tiktoken = "0.5.2"

@@ -24,18 +26,14 @@ boto3 = "1.28.17"
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
 pytest = "^7.3.0"
-pytest-cov = "^4.0.0"
-pytest-dotenv = "^0.5.2"
-duckdb-engine = "^0.7.0"
-pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"
-responses = "^0.22.0"
-pytest-asyncio = "^0.20.3"
-lark = "^1.1.5"
-pandas = "^2.0.0"
-pytest-mock = "^3.10.0"
-pytest-socket = "^0.6.0"
+pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+grandalf = "^0.8"
+pytest-profiling = "^1.7.0"
+responses = "^0.25.0"

@@ -46,12 +44,6 @@ optional = true
 ruff = "^0.1.5"

-[tool.poetry.group.codespell]
-optional = true
-
-[tool.poetry.group.codespell.dependencies]
-codespell = "^2.2.0"
-
 [tool.poetry.group.dev]
 optional = true

@@ -186,7 +178,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 # --snapshot-warn-unused      Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"

 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [

View File

@@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@@ -0,0 +1,33 @@
import asyncio
import logging
import pytest
from model_providers import BootstrapWebBuilder
logger = logging.getLogger(__name__)
@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise

View File

@@ -0,0 +1,104 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture
def logging_conf() -> dict:
return get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)

View File

@@ -0,0 +1,43 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# Load the configuration file
cfg = OmegaConf.load(
"/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
# Convert the configuration
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# Create the provider manager
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.LLM
)
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.TEXT_EMBEDDING
)
predefined_models = (
provider_model_bundle_emb.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")