From 451fef8a31953e238e716537613971cbfa3e234d Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Thu, 28 Mar 2024 20:45:42 +0800
Subject: [PATCH 1/9] Load the user configuration adapter from yaml
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model-providers/model_providers.yaml         | 28 ++++++++++
 model-providers/model_providers/__init__.py  | 56 ++++++++++++++++---
 model-providers/model_providers/__main__.py  | 10 +---
 .../model_providers/core/provider_manager.py |  2 +-
 model-providers/pyproject.toml               |  4 +-
 5 files changed, 80 insertions(+), 20 deletions(-)
 create mode 100644 model-providers/model_providers.yaml

diff --git a/model-providers/model_providers.yaml b/model-providers/model_providers.yaml
new file mode 100644
index 00000000..eb96fba7
--- /dev/null
+++ b/model-providers/model_providers.yaml
@@ -0,0 +1,28 @@
+openai:
+  model_credential:
+    - model: 'gpt-3.5-turbo'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_key: 'sk-'
+        openai_organization: ''
+        openai_api_base: ''
+    - model: 'gpt-4'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_key: 'sk-'
+        openai_organization: ''
+        openai_api_base: ''
+
+  provider_credential:
+    openai_api_key: 'sk-'
+    openai_organization: ''
+    openai_api_base: ''
+
+xinference:
+  model_credential:
+    - model: 'gpt-3.5-turbo'
+      model_type: 'llm'
+      credential:
+        openai_api_key: ''
+        openai_organization: ''
+        openai_api_base: ''
diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py
index 486e453c..04e8fe2b 100644
--- a/model-providers/model_providers/__init__.py
+++ b/model-providers/model_providers/__init__.py
@@ -1,17 +1,55 @@
-from chatchat.configs import MODEL_PLATFORMS
 from model_providers.core.model_manager import ModelManager
 
 
-def _to_custom_provide_configuration():
+from omegaconf import OmegaConf, DictConfig
+
+
+def _to_custom_provide_configuration(cfg: DictConfig):
+    """
+    ```
+    openai:
+      model_credential:
+        - model: 'gpt-3.5-turbo'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+        - model: 'gpt-4'
+          model_credentials:
+            openai_api_key: ''
+            openai_organization: ''
+            openai_api_base: ''
+
+      provider_credential:
+        openai_api_key: ''
+        openai_organization: ''
+        openai_api_base: ''
+
+    ```
+    :param model_providers_cfg:
+    :return:
+    """
     provider_name_to_provider_records_dict = {}
     provider_name_to_provider_model_records_dict = {}
+
+    for key, item in cfg.items():
+        model_credential = item.get('model_credential')
+        provider_credential = item.get('provider_credential')
+        # 转换omegaconf对象为基本属性
+        if model_credential:
+            model_credential = OmegaConf.to_container(model_credential)
+            provider_name_to_provider_model_records_dict[key] = model_credential
+        if provider_credential:
+            provider_credential = OmegaConf.to_container(provider_credential)
+            provider_name_to_provider_records_dict[key] = provider_credential
+
+    return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
+
+
+model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
+provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
+    cfg=model_providers_cfg)
 # 基于配置管理器创建的模型实例
 provider_manager = ModelManager(
-    provider_name_to_provider_records_dict={
-        'openai': {
-            'openai_api_key': "sk-4M9LYF",
-        }
-    },
-    provider_name_to_provider_model_records_dict={}
-)
\ No newline at end of file
+    provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
+    provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
+)
diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py
index ef5fab6e..b58dc035 100644
--- a/model-providers/model_providers/__main__.py
+++ b/model-providers/model_providers/__main__.py
@@ -1,6 +1,7 @@
 import os
 from typing import cast, Generator
 
+from model_providers import provider_manager
 from model_providers.core.model_manager import ModelManager
 from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
 from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
@@ -8,16 +9,7 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType
 
 if __name__ == '__main__':
     # 基于配置管理器创建的模型实例
-    provider_manager = ModelManager(
-        provider_name_to_provider_records_dict={
-            'openai': {
-                'openai_api_key': "sk-4M9LYF",
-            }
-        },
-        provider_name_to_provider_model_records_dict={}
-    )
-    #
     # Invoke model
     model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
 
diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py
index a2016963..f38d976d 100644
--- a/model-providers/model_providers/core/provider_manager.py
+++ b/model-providers/model_providers/core/provider_manager.py
@@ -230,7 +230,7 @@ class ProviderManager:
 
             custom_model_configurations.append(
                 CustomModelConfiguration(
-                    model=provider_model_record.get('model_name'),
+                    model=provider_model_record.get('model'),
                     model_type=ModelType.value_of(provider_model_record.get('model_type')),
                     credentials=provider_model_credentials
                 )
diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml
index 2ab88d8e..87359676 100644
--- a/model-providers/pyproject.toml
+++ b/model-providers/pyproject.toml
@@ -14,6 +14,8 @@ sse-starlette = "^1.8.2"
 pyyaml = "6.0.1"
 pydantic = "2.6.4"
 redis = "4.5.4"
+# config manage
+omegaconf = "2.0.6"
 # modle_runtime
 openai = "1.13.3"
 tiktoken = "0.5.2"
@@ -188,4 +190,4 @@ markers = [
     "scheduled: mark tests to run in scheduled testing",
     "compile: mark placeholder test used to compile integration tests without running them"
 ]
-asyncio_mode = "auto"
\ No newline at end of file
+asyncio_mode = "auto"

From 032dc8f58de754c785b3c375ded52001ed3982a0 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Fri, 29 Mar 2024 18:25:16 +0800
Subject: [PATCH 2/9] Adapt RESTFulOpenAIBootstrapBaseWeb loading with
 BootstrapWebBuilder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model-providers/model_providers/__init__.py   |  48 +++++--
 .../bootstrap_web/openai_bootstrap_web.py     | 127 +++++-------------
 .../model_providers/core/bootstrap/base.py    |  20 ++-
 .../core/bootstrap/openai_protocol.py         |   2 +-
 4 files changed, 90 insertions(+), 107 deletions(-)

diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py
index 04e8fe2b..ebbdf71e 100644
--- a/model-providers/model_providers/__init__.py
+++ b/model-providers/model_providers/__init__.py
@@ -1,3 +1,5 @@
+from model_providers.bootstrap_web.openai_bootstrap_web import 
RESTFulOpenAIBootstrapBaseWeb +from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.model_manager import ModelManager from omegaconf import OmegaConf, DictConfig @@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig): return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict -model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml") -provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( - cfg=model_providers_cfg) -# 基于配置管理器创建的模型实例 -provider_manager = ModelManager( - provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, - provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict -) +class BootstrapWebBuilder: + """ + 创建一个模型实例创建工具 + """ + _model_providers_cfg_path: str + _host: str + _port: int + + def model_providers_cfg_path(self, model_providers_cfg_path: str): + self._model_providers_cfg_path = model_providers_cfg_path + return self + + def host(self, host: str): + self._host = host + return self + + def port(self, port: int): + self._port = port + return self + + def build(self) -> OpenAIBootstrapBaseWeb: + assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None + # 读取配置文件 + cfg = OmegaConf.load(self._model_providers_cfg_path) + # 转换配置文件 + provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( + cfg) + # 创建模型管理器 + provider_manager = ModelManager(provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict) + # 创建web服务 + restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={ + "host": self._host, + "port": self._port + + }) + restful.provider_manager = provider_manager + return restful diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 56adcce9..b1032861 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import ( FunctionAvailable, ModelList, ) +from model_providers.core.model_manager import ModelManager, ModelInstance from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.model_providers import model_provider_factory -from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( - LargeLanguageModel, -) from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) -async def create_stream_chat_completion( - model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest -): - try: - response = model_type_instance.invoke( - model=chat_request.model, - credentials={ - "openai_api_key": "sk-", - "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), - "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), - }, - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={**chat_request.to_model_parameters_dict()}, - 
stop=chat_request.stop, - stream=chat_request.stream, - user="abc-123", - ) - return response - - except Exception as e: - logger.exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): """ Bootstrap Server Lifecycle @@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) self._router.add_api_route( - "/v1/models", + "/{provider}/v1/models", self.list_models, response_model=ModelList, methods=["GET"], ) self._router.add_api_route( - "/v1/embeddings", + "/{provider}/v1/embeddings", self.create_embeddings, response_model=EmbeddingsResponse, status_code=status.HTTP_200_OK, methods=["POST"], ) self._router.add_api_route( - "/v1/chat/completions", + "/{provider}/v1/chat/completions", self.create_chat_completion, response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK, @@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): if started_event is not None: started_event.set() - async def list_models(self, request: Request): + async def list_models(self, provider: str, request: Request): pass async def create_embeddings( - self, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" ) - if os.environ["API_KEY"] is None: - authorization = request.headers.get("Authorization") - authorization = authorization.split("Bearer ")[-1] - else: - authorization = os.environ["API_KEY"] - client = ZhipuAI(api_key=authorization) - # 判断embeddings_request.input是否为list - input = None - if isinstance(embeddings_request.input, list): - tokens = embeddings_request.input - try: - encoding = tiktoken.encoding_for_model(embeddings_request.model) - except KeyError: - logger.warning("Warning: model not found. 
Using cl100k_base encoding.") - model = "cl100k_base" - encoding = tiktoken.get_encoding(model) - for i, token in enumerate(tokens): - text = encoding.decode(token) - input += text - else: - input = embeddings_request.input - - response = client.embeddings.create( - model=embeddings_request.model, - input=input, - ) + response = None return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" ) - if os.environ["API_KEY"] is None: - authorization = request.headers.get("Authorization") - authorization = authorization.split("Bearer ")[-1] - else: - authorization = os.environ["API_KEY"] - model_provider_factory.get_providers(provider_name="openai") - provider_instance = model_provider_factory.get_provider_instance("openai") - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + model_instance = self._provider_manager.get_model_instance( + provider=provider, model_type=ModelType.LLM, model=chat_request.model + ) if chat_request.stream: - generator = create_stream_chat_completion(model_type_instance, chat_request) - return EventSourceResponse(generator, media_type="text/event-stream") - else: - response = model_type_instance.invoke( - model="gpt-4", - credentials={ - "openai_api_key": "sk-", - "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), - "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), - }, + # Invoke model + + response = model_instance.invoke_llm( prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={ - "temperature": 0.7, - "top_p": 1.0, - "top_k": 1, - "plugin_web_search": True, - }, - stop=["you"], - stream=False, + model_parameters={**chat_request.to_model_parameters_dict()}, + stop=chat_request.stop, + stream=chat_request.stream, + user="abc-123", + ) + + return EventSourceResponse(response, media_type="text/event-stream") + else: + # Invoke model + + response = model_instance.invoke_llm( + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={**chat_request.to_model_parameters_dict()}, + stop=chat_request.stop, + stream=chat_request.stream, user="abc-123", ) @@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: - import signal - - # 跳过键盘中断,使用xoscar的信号处理 - signal.signal(signal.SIGINT, lambda *_: None) api = RESTFulOpenAIBootstrapBaseWeb.from_config( cfg=cfg.get("run_openai_api", {}) ) diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index b2da1a0b..795eb565 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -3,9 +3,11 @@ from collections import deque from fastapi import Request +from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest +from model_providers.core.model_manager import ModelManager + class Bootstrap: - """最大的任务队列""" _MAX_ONGOING_TASKS: int = 1 @@ -39,21 +41,31 @@ class Bootstrap: class OpenAIBootstrapBaseWeb(Bootstrap): + _provider_manager: ModelManager + def __init__(self): 
super().__init__()
+
+    @property
+    def provider_manager(self) -> ModelManager:
+        return self._provider_manager
+
+    @provider_manager.setter
+    def provider_manager(self, provider_manager: ModelManager):
+        self._provider_manager = provider_manager
+
     @abstractmethod
-    async def list_models(self, request: Request):
+    async def list_models(self, provider: str, request: Request):
         pass
 
     @abstractmethod
     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         pass
 
     @abstractmethod
     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         pass
diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py
index 1d7354cf..30d57c81 100644
--- a/model-providers/model_providers/core/bootstrap/openai_protocol.py
+++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py
@@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel):
 
     def to_model_parameters_dict(self, *args, **kwargs):
         # 调用父类的to_dict方法，并排除tools字段
-        helper.dump_model
+
         return super().dict(
             exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
         )

From f005ea3298b678bba168aa082a5d3064a100a445 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Sun, 31 Mar 2024 15:08:30 +0800
Subject: [PATCH 3/9] Adapt the model list endpoint
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 chatchat-server/pyproject.toml                |  16 ++-
 model-providers/model_providers/__main__.py   |  50 ----------
 .../bootstrap_web/openai_bootstrap_web.py     |  25 ++++-
 .../core/bootstrap/openai_protocol.py         |   2 +-
 .../model_providers/core/model_manager.py     |   4 +
 .../model_runtime/entities/model_entities.py  |   3 +-
 .../model_providers/core/utils/utils.py       |  89 +++++++++++++++++
 model-providers/pyproject.toml                |  18 ++--
 .../tests/server_unit_test/conftest.py        |  99 +++++++++++++++++++
 .../server_unit_test/test_init_server.py      |  28 ++++++
 model-providers/tests/unit_test/conftest.py   |  99 +++++++++++++++++++
 .../unit_test/test_provider_manager_models.py |  39 ++++++++
 12 files changed, 396 insertions(+), 76 deletions(-)
 delete mode 100644 model-providers/model_providers/__main__.py
 create mode 100644 model-providers/model_providers/core/utils/utils.py
 create mode 100644 model-providers/tests/server_unit_test/conftest.py
 create mode 100644 model-providers/tests/server_unit_test/test_init_server.py
 create mode 100644 model-providers/tests/unit_test/conftest.py
 create mode 100644 model-providers/tests/unit_test/test_provider_manager_models.py

diff --git a/chatchat-server/pyproject.toml b/chatchat-server/pyproject.toml
index 82cd1fa9..fa7cc8fe 100644
--- a/chatchat-server/pyproject.toml
+++ b/chatchat-server/pyproject.toml
@@ -57,18 +57,14 @@ optional = true
 # dependencies used for running tests (e.g., pytest, freezegun, response).
 # Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0" -pytest-cov = "^4.0.0" -pytest-dotenv = "^0.5.2" -duckdb-engine = "^0.7.0" -pytest-watcher = "^0.2.6" freezegun = "^1.2.2" -responses = "^0.22.0" -pytest-asyncio = "^0.20.3" -lark = "^1.1.5" -pandas = "^2.0.0" -pytest-mock = "^3.10.0" -pytest-socket = "^0.6.0" +pytest-mock = "^3.10.0" syrupy = "^4.0.2" +pytest-watcher = "^0.3.4" +pytest-asyncio = "^0.21.1" +grandalf = "^0.8" +pytest-profiling = "^1.7.0" +responses = "^0.25.0" model-providers = { path = "../model-providers", develop = true } diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py deleted file mode 100644 index decebcc4..00000000 --- a/model-providers/model_providers/__main__.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -from typing import Generator, cast - -from model_providers import provider_manager -from model_providers.core.model_manager import ModelManager -from model_providers.core.model_runtime.entities.llm_entities import ( - LLMResultChunk, - LLMResultChunkDelta, -) -from model_providers.core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - UserPromptMessage, -) -from model_providers.core.model_runtime.entities.model_entities import ModelType - -if __name__ == '__main__': - # 基于配置管理器创建的模型实例 - - # Invoke model - model_instance = provider_manager.get_model_instance( - provider="openai", model_type=ModelType.LLM, model="gpt-4" - ) - - response = model_instance.invoke_llm( - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={ - "temperature": 0.7, - "top_p": 1.0, - "top_k": 1, - "plugin_web_search": True, - }, - stop=["you"], - stream=True, - user="abc-123", - ) - - assert isinstance(response, Generator) - total_message = "" - for chunk in response: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - total_message += chunk.delta.message.content - assert ( - len(chunk.delta.message.content) > 0 - if not chunk.delta.finish_reason - else True - ) - print(total_message) - assert "参考资料" in total_message diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index b1032861..e47ed678 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -21,13 +21,13 @@ from model_providers.core.bootstrap.openai_protocol import ( EmbeddingsRequest, EmbeddingsResponse, FunctionAvailable, - ModelList, + ModelList, ModelCard, ) from model_providers.core.model_manager import ModelManager, ModelInstance from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) @@ -110,7 +110,26 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): started_event.set() async def list_models(self, provider: str, request: Request): - pass + logger.info(f"Received list_models request for provider: {provider}") + # 返回ModelType所有的枚举 + llm_models: list[AIModelEntity] = [] + for model_type in ModelType.__members__.values(): + try: + provider_model_bundle = 
self._provider_manager.provider_manager.get_provider_model_bundle( + provider=provider, model_type=model_type + ) + llm_models.extend(provider_model_bundle.model_type_instance.predefined_models()) + except Exception as e: + logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}") + logger.error(e) + + # models list[AIModelEntity]转换称List[ModelCard] + + models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models] + + return ModelList( + data=models_list + ) async def create_embeddings( self, provider: str, request: Request, embeddings_request: EmbeddingsRequest diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 30d57c81..558dc95b 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -22,7 +22,7 @@ class Finish(str, Enum): class ModelCard(BaseModel): id: str - object: Literal["model"] = "model" + object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm" created: int = Field(default_factory=lambda: int(time.time())) owned_by: Literal["owner"] = "owner" diff --git a/model-providers/model_providers/core/model_manager.py b/model-providers/model_providers/core/model_manager.py index 5e2c69bd..af896423 100644 --- a/model-providers/model_providers/core/model_manager.py +++ b/model-providers/model_providers/core/model_manager.py @@ -245,6 +245,10 @@ class ModelManager: provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, ) + @property + def provider_manager(self) -> ProviderManager: + return self._provider_manager + def get_model_instance( self, provider: str, model_type: ModelType, model: str ) -> ModelInstance: diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index 307d459a..ae5938a4 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, List from pydantic import BaseModel @@ -74,6 +74,7 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. diff --git a/model-providers/model_providers/core/utils/utils.py b/model-providers/model_providers/core/utils/utils.py new file mode 100644 index 00000000..2bc3fde8 --- /dev/null +++ b/model-providers/model_providers/core/utils/utils.py @@ -0,0 +1,89 @@ +import logging + +import time +import os + +logger = logging.getLogger(__name__) + + +class LoggerNameFilter(logging.Filter): + def filter(self, record): + # return record.name.startswith("loom_core") or record.name in "ERROR" or ( + # record.name.startswith("uvicorn.error") + # and record.getMessage().startswith("Uvicorn running on") + # ) + return True + + +def get_log_file(log_path: str, sub_dir: str): + """ + sub_dir should contain a timestamp. 
+ """ + log_dir = os.path.join(log_path, sub_dir) + # Here should be creating a new directory each time, so `exist_ok=False` + os.makedirs(log_dir, exist_ok=False) + return os.path.join(log_dir, "loom_core.log") + + +def get_config_dict( + log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int +) -> dict: + # for windows, the path should be a raw string. + log_file_path = ( + log_file_path.encode("unicode-escape").decode() + if os.name == "nt" + else log_file_path + ) + log_level = log_level.upper() + config_dict = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "formatter": { + "format": ( + "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s" + ) + }, + }, + "filters": { + "logger_name_filter": { + "()": __name__ + ".LoggerNameFilter", + }, + }, + "handlers": { + "stream_handler": { + "class": "logging.StreamHandler", + "formatter": "formatter", + "level": log_level, + # "stream": "ext://sys.stdout", + # "filters": ["logger_name_filter"], + }, + "file_handler": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "formatter", + "level": log_level, + "filename": log_file_path, + "mode": "a", + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf8", + }, + }, + "loggers": { + "loom_core": { + "handlers": ["stream_handler", "file_handler"], + "level": log_level, + "propagate": False, + } + }, + "root": { + "level": log_level, + "handlers": ["stream_handler", "file_handler"], + }, + } + return config_dict + + +def get_timestamp_ms(): + t = time.time() + return int(round(t * 1000)) diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml index 44fc1529..859006de 100644 --- a/model-providers/pyproject.toml +++ b/model-providers/pyproject.toml @@ -26,18 +26,14 @@ boto3 = "1.28.17" # dependencies used for running tests (e.g., pytest, freezegun, response). # Any dependencies that do not meet that criteria will be removed. pytest = "^7.3.0" -pytest-cov = "^4.0.0" -pytest-dotenv = "^0.5.2" -duckdb-engine = "^0.7.0" -pytest-watcher = "^0.2.6" freezegun = "^1.2.2" -responses = "^0.22.0" -pytest-asyncio = "^0.20.3" -lark = "^1.1.5" -pandas = "^2.0.0" -pytest-mock = "^3.10.0" -pytest-socket = "^0.6.0" +pytest-mock = "^3.10.0" syrupy = "^4.0.2" +pytest-watcher = "^0.3.4" +pytest-asyncio = "^0.21.1" +grandalf = "^0.8" +pytest-profiling = "^1.7.0" +responses = "^0.25.0" @@ -182,7 +178,7 @@ build-backend = "poetry.core.masonry.api" # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. -addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" +addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv" # Registering custom markers. 
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/server_unit_test/conftest.py new file mode 100644 index 00000000..dc1af083 --- /dev/null +++ b/model-providers/tests/server_unit_test/conftest.py @@ -0,0 +1,99 @@ +"""Configuration for unit tests.""" +from importlib import util +from typing import Dict, Sequence, List +import logging +import pytest +from pytest import Config, Function, Parser + +from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. 
+ item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + +@pytest.fixture +def logging_conf() -> dict: + return get_config_dict( + "DEBUG", + get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), + 122, + 111, + ) diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py new file mode 100644 index 00000000..d4145c77 --- /dev/null +++ b/model-providers/tests/server_unit_test/test_init_server.py @@ -0,0 +1,28 @@ +from model_providers import BootstrapWebBuilder +import logging +import asyncio + +import pytest +logger = logging.getLogger(__name__) + + +@pytest.mark.requires("fastapi") +def test_init_server(logging_conf: dict) -> None: + try: + boot = BootstrapWebBuilder() \ + .model_providers_cfg_path( + model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml") \ + .host(host="127.0.0.1") \ + .port(port=20000) \ + .build() + boot.set_app_event(started_event=None) + boot.serve(logging_conf=logging_conf) + + async def pool_join_thread(): + await boot.join() + + asyncio.run(pool_join_thread()) + except SystemExit: + logger.info("SystemExit raised, exiting") + raise diff --git a/model-providers/tests/unit_test/conftest.py b/model-providers/tests/unit_test/conftest.py new file mode 100644 index 00000000..dc1af083 --- /dev/null +++ b/model-providers/tests/unit_test/conftest.py @@ -0,0 +1,99 @@ +"""Configuration for unit tests.""" +from importlib import util +from typing import Dict, Sequence, List +import logging +import pytest +from pytest import Config, Function, Parser + +from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. 
+ # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. + item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + +@pytest.fixture +def logging_conf() -> dict: + return get_config_dict( + "DEBUG", + get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), + 122, + 111, + ) diff --git a/model-providers/tests/unit_test/test_provider_manager_models.py b/model-providers/tests/unit_test/test_provider_manager_models.py new file mode 100644 index 00000000..c7afd8dc --- /dev/null +++ b/model-providers/tests/unit_test/test_provider_manager_models.py @@ -0,0 +1,39 @@ +from omegaconf import OmegaConf + +from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration +import logging +import asyncio + +import pytest + +from model_providers.core.model_manager import ModelManager +from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.provider_manager import ProviderManager + +logger = logging.getLogger(__name__) + + +def test_provider_manager_models(logging_conf: dict) -> None: + + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml") + # 转换配置文件 + provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( + cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + provider_model_bundle_llm = provider_manager.get_provider_model_bundle( + provider="openai", model_type=ModelType.LLM + ) + provider_model_bundle_emb = provider_manager.get_provider_model_bundle( + provider="openai", model_type=ModelType.TEXT_EMBEDDING + ) + predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models() + + logger.info(f"predefined_models: {predefined_models}") From a2df71d9ea6a303bda387bc6233c77aac4ce6902 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 31 Mar 2024 15:12:20 +0800 Subject: [PATCH 4/9] make format --- 
model-providers/model_providers/__init__.py | 44 ++++++++++++------- .../bootstrap_web/openai_bootstrap_web.py | 43 +++++++++++------- .../model_providers/core/bootstrap/base.py | 9 ++-- .../core/bootstrap/openai_protocol.py | 10 ++++- .../model_runtime/entities/model_entities.py | 3 +- .../model_providers/core/provider_manager.py | 8 ++-- .../model_providers/core/utils/utils.py | 5 +-- .../tests/server_unit_test/conftest.py | 11 +++-- .../server_unit_test/test_init_server.py | 19 +++++--- model-providers/tests/unit_test/conftest.py | 11 +++-- .../unit_test/test_provider_manager_models.py | 26 ++++++----- 11 files changed, 121 insertions(+), 68 deletions(-) diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py index ebbdf71e..1be0f328 100644 --- a/model-providers/model_providers/__init__.py +++ b/model-providers/model_providers/__init__.py @@ -1,9 +1,11 @@ -from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb +from omegaconf import DictConfig, OmegaConf + +from model_providers.bootstrap_web.openai_bootstrap_web import ( + RESTFulOpenAIBootstrapBaseWeb, +) from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.model_manager import ModelManager -from omegaconf import OmegaConf, DictConfig - def _to_custom_provide_configuration(cfg: DictConfig): """ @@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig): provider_name_to_provider_model_records_dict = {} for key, item in cfg.items(): - model_credential = item.get('model_credential') - provider_credential = item.get('provider_credential') + model_credential = item.get("model_credential") + provider_credential = item.get("provider_credential") # 转换omegaconf对象为基本属性 if model_credential: model_credential = OmegaConf.to_container(model_credential) @@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig): provider_credential = OmegaConf.to_container(provider_credential) provider_name_to_provider_records_dict[key] = provider_credential - return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict + return ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) class BootstrapWebBuilder: """ 创建一个模型实例创建工具 """ + _model_providers_cfg_path: str _host: str _port: int @@ -68,20 +74,26 @@ class BootstrapWebBuilder: return self def build(self) -> OpenAIBootstrapBaseWeb: - assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None + assert ( + self._model_providers_cfg_path is not None + and self._host is not None + and self._port is not None + ) # 读取配置文件 cfg = OmegaConf.load(self._model_providers_cfg_path) # 转换配置文件 - provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( - cfg) + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) # 创建模型管理器 - provider_manager = ModelManager(provider_name_to_provider_records_dict, - provider_name_to_provider_model_records_dict) + provider_manager = ModelManager( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) # 创建web服务 - restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={ - "host": self._host, - "port": self._port - - }) + restful = RESTFulOpenAIBootstrapBaseWeb.from_config( + cfg={"host": self._host, "port": self._port} + ) restful.provider_manager = 
provider_manager return restful diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index e47ed678..6e692cee 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import ( EmbeddingsRequest, EmbeddingsResponse, FunctionAvailable, - ModelList, ModelCard, + ModelCard, + ModelList, ) -from model_providers.core.model_manager import ModelManager, ModelInstance +from model_providers.core.model_manager import ModelInstance, ModelManager from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, +) from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) @@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): llm_models: list[AIModelEntity] = [] for model_type in ModelType.__members__.values(): try: - provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle( - provider=provider, model_type=model_type + provider_model_bundle = ( + self._provider_manager.provider_manager.get_provider_model_bundle( + provider=provider, model_type=model_type + ) + ) + llm_models.extend( + provider_model_bundle.model_type_instance.predefined_models() ) - llm_models.extend(provider_model_bundle.model_type_instance.predefined_models()) except Exception as e: - logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}") + logger.error( + f"Error while fetching models for provider: {provider}, model_type: {model_type}" + ) logger.error(e) # models list[AIModelEntity]转换称List[ModelCard] - models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models] + models_list = [ + ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) + for model in llm_models + ] - return ModelList( - data=models_list - ) + return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index 795eb565..8572323e 100644 --- 
a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -3,7 +3,10 @@ from collections import deque from fastapi import Request -from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionRequest, + EmbeddingsRequest, +) from model_providers.core.model_manager import ModelManager @@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap): @abstractmethod async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): pass @abstractmethod async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): pass diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 558dc95b..2753ad5d 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -22,7 +22,15 @@ class Finish(str, Enum): class ModelCard(BaseModel): id: str - object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm" + object: Literal[ + "text-generation", + "embeddings", + "reranking", + "speech2text", + "moderation", + "tts", + "text2img", + ] = "llm" created: int = Field(default_factory=lambda: int(time.time())) owned_by: Literal["owner"] = "owner" diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index ae5938a4..0fea8c1d 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import Enum -from typing import Any, Optional, List +from typing import Any, List, Optional from pydantic import BaseModel @@ -74,7 +74,6 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") - class FetchFrom(Enum): """ Enum class for fetch from. 
diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index 5cb5ef43..5a99e314 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -253,9 +253,11 @@ class ProviderManager: custom_model_configurations.append( CustomModelConfiguration( - model=provider_model_record.get('model'), - model_type=ModelType.value_of(provider_model_record.get('model_type')), - credentials=provider_model_credentials + model=provider_model_record.get("model"), + model_type=ModelType.value_of( + provider_model_record.get("model_type") + ), + credentials=provider_model_credentials, ) ) diff --git a/model-providers/model_providers/core/utils/utils.py b/model-providers/model_providers/core/utils/utils.py index 2bc3fde8..dbfd7c6b 100644 --- a/model-providers/model_providers/core/utils/utils.py +++ b/model-providers/model_providers/core/utils/utils.py @@ -1,7 +1,6 @@ import logging - -import time import os +import time logger = logging.getLogger(__name__) @@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str): def get_config_dict( - log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int + log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int ) -> dict: # for windows, the path should be a raw string. log_file_path = ( diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/server_unit_test/conftest.py index dc1af083..eea02a65 100644 --- a/model-providers/tests/server_unit_test/conftest.py +++ b/model-providers/tests/server_unit_test/conftest.py @@ -1,11 +1,16 @@ """Configuration for unit tests.""" -from importlib import util -from typing import Dict, Sequence, List import logging +from importlib import util +from typing import Dict, List, Sequence + import pytest from pytest import Config, Function, Parser -from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file +from model_providers.core.utils.utils import ( + get_config_dict, + get_log_file, + get_timestamp_ms, +) def pytest_addoption(parser: Parser) -> None: diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py index d4145c77..96210b89 100644 --- a/model-providers/tests/server_unit_test/test_init_server.py +++ b/model-providers/tests/server_unit_test/test_init_server.py @@ -1,21 +1,26 @@ -from model_providers import BootstrapWebBuilder -import logging import asyncio +import logging import pytest + +from model_providers import BootstrapWebBuilder + logger = logging.getLogger(__name__) @pytest.mark.requires("fastapi") def test_init_server(logging_conf: dict) -> None: try: - boot = BootstrapWebBuilder() \ + boot = ( + BootstrapWebBuilder() .model_providers_cfg_path( - model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" - "/model_providers.yaml") \ - .host(host="127.0.0.1") \ - .port(port=20000) \ + model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml" + ) + .host(host="127.0.0.1") + .port(port=20000) .build() + ) boot.set_app_event(started_event=None) boot.serve(logging_conf=logging_conf) diff --git a/model-providers/tests/unit_test/conftest.py b/model-providers/tests/unit_test/conftest.py index dc1af083..eea02a65 100644 --- a/model-providers/tests/unit_test/conftest.py +++ 
b/model-providers/tests/unit_test/conftest.py
@@ -1,11 +1,16 @@
 """Configuration for unit tests."""
-from importlib import util
-from typing import Dict, Sequence, List
 import logging
+from importlib import util
+from typing import Dict, List, Sequence
+
 import pytest
 from pytest import Config, Function, Parser
 
-from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
 
 
 def pytest_addoption(parser: Parser) -> None:
diff --git a/model-providers/tests/unit_test/test_provider_manager_models.py b/model-providers/tests/unit_test/test_provider_manager_models.py
index c7afd8dc..023c48ec 100644
--- a/model-providers/tests/unit_test/test_provider_manager_models.py
+++ b/model-providers/tests/unit_test/test_provider_manager_models.py
@@ -1,11 +1,10 @@
+import asyncio
+import logging
+
+import pytest
 from omegaconf import OmegaConf
 
 from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
-import logging
-import asyncio
-
-import pytest
-
 from model_providers.core.model_manager import ModelManager
 from model_providers.core.model_runtime.entities.model_entities import ModelType
 from model_providers.core.provider_manager import ProviderManager
@@ -14,14 +13,17 @@ logger = logging.getLogger(__name__)
 
 
 def test_provider_manager_models(logging_conf: dict) -> None:
-
     logging.config.dictConfig(logging_conf)  # type: ignore
     # 读取配置文件
-    cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
-                         "/model_providers.yaml")
+    cfg = OmegaConf.load(
+        "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
+        "/model_providers.yaml"
+    )
     # 转换配置文件
-    provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
-        cfg)
+    (
+        provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict,
+    ) = _to_custom_provide_configuration(cfg)
     # 创建模型管理器
     provider_manager = ProviderManager(
         provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
@@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None:
     provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
         provider="openai", model_type=ModelType.TEXT_EMBEDDING
     )
-    predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models()
+    predefined_models = (
+        provider_model_bundle_emb.model_type_instance.predefined_models()
+    )
 
     logger.info(f"predefined_models: {predefined_models}")

From 2f1c9bfd1197f972657a6bc71ca22d3fea5963a6 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Sun, 31 Mar 2024 17:55:32 +0800
Subject: [PATCH 5/9] Adapt the chat_completions request and response payloads
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../model_providers/bootstrap_web/common.py   |  22 ++
 .../bootstrap_web/openai_bootstrap_web.py     | 253 ++++++++++++++++--
 .../core/bootstrap/openai_protocol.py         |  52 +++-
 .../entities/message_entities.py              |  17 ++
 4 files changed, 317 insertions(+), 27 deletions(-)
 create mode 100644 model-providers/model_providers/bootstrap_web/common.py

diff --git a/model-providers/model_providers/bootstrap_web/common.py b/model-providers/model_providers/bootstrap_web/common.py
new file mode 100644
index 00000000..0566e2cd
--- /dev/null
+++ b/model-providers/model_providers/bootstrap_web/common.py
@@ -0,0 +1,22 @@
+import typing +from subprocess import Popen +from typing import Optional + +from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \ + ChatCompletionStreamResponse, Finish +from model_providers.core.utils.generic import jsonify + +if typing.TYPE_CHECKING: + from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage + + +def create_stream_chunk( + request_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional[Finish] = None, +) -> str: + choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice]) + return jsonify(chunk) diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 6e692cee..7a348002 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -5,14 +5,14 @@ import multiprocessing as mp import os import pprint import threading -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator -import tiktoken from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware from sse_starlette import EventSourceResponse from uvicorn import Config, Server +from model_providers.bootstrap_web.common import create_stream_chunk from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.openai_protocol import ( ChatCompletionRequest, @@ -22,20 +22,212 @@ from model_providers.core.bootstrap.openai_protocol import ( EmbeddingsResponse, FunctionAvailable, ModelCard, - ModelList, + ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo, ) -from model_providers.core.model_manager import ModelInstance, ModelManager +from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk + from model_providers.core.model_runtime.entities.message_entities import ( - UserPromptMessage, + UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage, + PromptMessageContent, PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent, + PromptMessageTool, ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, ModelType, ) +from model_providers.core.model_runtime.errors.invoke import InvokeError from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) +MessageLike = Union[ChatMessage, PromptMessage] + +MessageLikeRepresentation = Union[ + MessageLike, + Tuple[Union[str, Type], Union[str, List[dict], List[object]]], + str, +] + + +def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError("User message content must be str") + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} 
+ if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + # check if last message is user message + message = cast(ToolPromptMessage, message) + message_dict = {"role": "function", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + +def _create_template_from_message_type( + message_type: str, template: Union[str, list] +) -> PromptMessage: + """Create a message prompt template from a message type and template string. + + Args: + message_type: str the type of the message template (e.g., "human", "ai", etc.) + template: str the template string. + + Returns: + a message prompt template of the appropriate type. + """ + if isinstance(template, str): + content = template + elif isinstance(template, list): + content = [] + for tmpl in template: + if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if isinstance(tmpl, str): + text: str = tmpl + else: + text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 + content.append( + TextPromptMessageContent(data=text) + ) + elif isinstance(tmpl, dict) and "image_url" in tmpl: + img_template = cast(dict, tmpl)["image_url"] + if isinstance(img_template, str): + img_template_obj = ImagePromptMessageContent(data=img_template) + elif isinstance(img_template, dict): + img_template = dict(img_template) + if "url" in img_template: + url = img_template["url"] + else: + url = None + img_template_obj = ImagePromptMessageContent(data=url) + else: + raise ValueError() + content.append(img_template_obj) + else: + raise ValueError() + else: + raise ValueError() + + if message_type in ("human", "user"): + _message = UserPromptMessage(content=content) + elif message_type in ("ai", "assistant"): + _message = AssistantPromptMessage(content=content) + elif message_type == "system": + _message = SystemPromptMessage(content=content) + elif message_type in ("function", "tool"): + _message = ToolPromptMessage(content=content) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human'," + f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'." + ) + + return _message + + +def _convert_to_message( + message: MessageLikeRepresentation, +) -> Union[PromptMessage]: + """Instantiate a message from a variety of message formats. 
+ + The message format can be one of the following: + + - BaseMessagePromptTemplate + - BaseMessage + - 2-tuple of (role string, template); e.g., ("human", "{user_input}") + - 2-tuple of (message class, template) + - string: shorthand for ("human", template); e.g., "{user_input}" + + Args: + message: a representation of a message in one of the supported formats + + Returns: + an instance of a message or a message template + """ + if isinstance(message, ChatMessage): + _message = _create_template_from_message_type(message.role.to_origin_role(), message.content) + + elif isinstance(message, PromptMessage): + _message = message + elif isinstance(message, str): + _message = _create_template_from_message_type("human", message) + elif isinstance(message, tuple): + if len(message) != 2: + raise ValueError(f"Expected 2-tuple of (role, template), got {message}") + message_type_str, template = message + if isinstance(message_type_str, str): + _message = _create_template_from_message_type(message_type_str, template) + else: + raise ValueError( + f"Expected message type string, got {message_type_str}" + ) + else: + raise NotImplementedError(f"Unsupported message type: {type(message)}") + + return _message + + +async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]: + request_id, model = None, None + for chunk in response: + if not isinstance(chunk, LLMResultChunk): + yield "[ERROR]" + return + + if model is None: + model = chunk.model + if request_id is None: + request_id = "request_id" + yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content="")) + + new_token = chunk.delta.message.content + + if new_token: + delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()), + content=new_token, + tool_calls=chunk.delta.message.tool_calls) + yield create_stream_chunk(request_id=request_id, + model=model, delta=delta, + index=chunk.delta.index, + finish_reason=chunk.delta.finish_reason) + + yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP) + yield "[DONE]" + + +async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse: + choice = ChatCompletionResponseChoice( + index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)), + finish_reason=Finish.STOP + ) + usage = UsageInfo( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + ) + return ChatCompletionResponse( + id="request_id", + model=response.model, + choices=[choice], + usage=usage, + ) + class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): """ @@ -143,7 +335,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -153,7 +345,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ 
-162,38 +354,47 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): model_instance = self._provider_manager.get_model_instance( provider=provider, model_type=ModelType.LLM, model=chat_request.model ) - if chat_request.stream: - # Invoke model + prompt_messages = [_convert_to_message(message) for message in chat_request.messages] + + tools = [PromptMessageTool(name=f.function.name, + description=f.function.description, + parameters=f.function.parameters + ) + + for f in chat_request.tools] + if chat_request.functions: + tools.extend([PromptMessageTool(name=f.name, + description=f.description, + parameters=f.parameters + ) for f in chat_request.functions]) + + try: response = model_instance.invoke_llm( - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + prompt_messages=prompt_messages, model_parameters={**chat_request.to_model_parameters_dict()}, + tools=tools, stop=chat_request.stop, stream=chat_request.stream, user="abc-123", ) - return EventSourceResponse(response, media_type="text/event-stream") - else: - # Invoke model + if chat_request.stream: - response = model_instance.invoke_llm( - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={**chat_request.to_model_parameters_dict()}, - stop=chat_request.stop, - stream=chat_request.stream, - user="abc-123", - ) + return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream") + else: + return await _openai_chat_completion(response) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - chat_response = ChatCompletionResponse(**dictify(response)) - - return chat_response + except InvokeError as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 2753ad5d..ec5ddc3f 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -13,12 +13,62 @@ class Role(str, Enum): FUNCTION = "function" TOOL = "tool" + @classmethod + def value_of(cls, origin_role: str) -> "Role": + if origin_role == "user": + return cls.USER + elif origin_role == "assistant": + return cls.ASSISTANT + elif origin_role == "system": + return cls.SYSTEM + elif origin_role == "function": + return cls.FUNCTION + elif origin_role == "tool": + return cls.TOOL + else: + raise ValueError(f"invalid origin role {origin_role}") + + def to_origin_role(self) -> str: + if self == self.USER: + return "user" + elif self == self.ASSISTANT: + return "assistant" + elif self == self.SYSTEM: + return "system" + elif self == self.FUNCTION: + return "function" + elif self == self.TOOL: + return "tool" + else: + raise ValueError(f"invalid role {self}") + class Finish(str, Enum): STOP = "stop" LENGTH = "length" TOOL = "tool_calls" + @classmethod + def value_of(cls, origin_finish: str) -> "Finish": + if origin_finish == "stop": + return cls.STOP + elif origin_finish == "length": + return cls.LENGTH + elif origin_finish == "tool_calls": + return cls.TOOL + else: + raise ValueError(f"invalid origin finish {origin_finish}") + + def to_origin_finish(self) 
-> str: + if self == self.STOP: + return "stop" + elif self == self.LENGTH: + return "length" + elif self == self.TOOL: + return "tool_calls" + else: + raise ValueError(f"invalid finish {self}") + class ModelCard(BaseModel): id: str @@ -95,7 +145,7 @@ class ChatCompletionRequest(BaseModel): top_k: Optional[float] = None n: int = 1 max_tokens: Optional[int] = None - stop: Optional[list[str]] = (None,) + stop: Optional[list[str]] = None stream: Optional[bool] = False def to_model_parameters_dict(self, *args, **kwargs): diff --git a/model-providers/model_providers/core/model_runtime/entities/message_entities.py b/model-providers/model_providers/core/model_runtime/entities/message_entities.py index c9a823c0..a66294ad 100644 --- a/model-providers/model_providers/core/model_runtime/entities/message_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/message_entities.py @@ -28,6 +28,23 @@ class PromptMessageRole(Enum): return mode raise ValueError(f"invalid prompt message type value {value}") + def to_origin_role(self) -> str: + """ + Get origin role from prompt message role. + + :return: origin role + """ + if self == self.SYSTEM: + return "system" + elif self == self.USER: + return "user" + elif self == self.ASSISTANT: + return "assistant" + elif self == self.TOOL: + return "tool" + else: + raise ValueError(f"invalid role {self}") + class PromptMessageTool(BaseModel): """ From 056b15b99b1ac969b40f5357a88332f12103d4b6 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 31 Mar 2024 17:55:57 +0800 Subject: [PATCH 6/9] make format --- .../model_providers/bootstrap_web/common.py | 21 ++- .../bootstrap_web/openai_bootstrap_web.py | 145 ++++++++++++------ 2 files changed, 111 insertions(+), 55 deletions(-) diff --git a/model-providers/model_providers/bootstrap_web/common.py b/model-providers/model_providers/bootstrap_web/common.py index 0566e2cd..a06a3064 100644 --- a/model-providers/model_providers/bootstrap_web/common.py +++ b/model-providers/model_providers/bootstrap_web/common.py @@ -2,8 +2,11 @@ import typing from subprocess import Popen from typing import Optional -from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \ - ChatCompletionStreamResponse, Finish +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + Finish, +) from model_providers.core.utils.generic import jsonify if typing.TYPE_CHECKING: @@ -11,12 +14,14 @@ if typing.TYPE_CHECKING: def create_stream_chunk( - request_id: str, - model: str, - delta: "ChatCompletionMessage", - index: Optional[int] = 0, - finish_reason: Optional[Finish] = None, + request_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional[Finish] = None, ) -> str: - choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + choice = ChatCompletionStreamResponseChoice( + index=index, delta=delta, finish_reason=finish_reason + ) chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice]) return jsonify(chunk) diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 7a348002..31b2cb77 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -5,7 +5,18 @@ import multiprocessing as mp 
import os import pprint import threading -from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -15,21 +26,36 @@ from uvicorn import Config, Server from model_providers.bootstrap_web.common import create_stream_chunk from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionStreamResponse, + ChatMessage, EmbeddingsRequest, EmbeddingsResponse, + Finish, FunctionAvailable, ModelCard, - ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo, + ModelList, + Role, + UsageInfo, +) +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, ) -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk - from model_providers.core.model_runtime.entities.message_entities import ( - UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage, - PromptMessageContent, PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent, + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -81,7 +107,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: def _create_template_from_message_type( - message_type: str, template: Union[str, list] + message_type: str, template: Union[str, list] ) -> PromptMessage: """Create a message prompt template from a message type and template string. @@ -102,9 +128,7 @@ def _create_template_from_message_type( text: str = tmpl else: text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 - content.append( - TextPromptMessageContent(data=text) - ) + content.append(TextPromptMessageContent(data=text)) elif isinstance(tmpl, dict) and "image_url" in tmpl: img_template = cast(dict, tmpl)["image_url"] if isinstance(img_template, str): @@ -142,7 +166,7 @@ def _create_template_from_message_type( def _convert_to_message( - message: MessageLikeRepresentation, + message: MessageLikeRepresentation, ) -> Union[PromptMessage]: """Instantiate a message from a variety of message formats. 
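As an aside for reviewers, here is a minimal sketch of the message shapes the helper above accepts. The import path comes from this patch; running it assumes the model_providers package is importable:

```python
# Sketch: _convert_to_message normalizes several input shapes
# into PromptMessage objects.
from model_providers.bootstrap_web.openai_bootstrap_web import _convert_to_message

prompt_messages = [
    _convert_to_message("What's the weather like today?"),  # str -> user message
    _convert_to_message(("system", "You are a helpful assistant.")),  # (role, template)
]
```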
@@ -161,7 +185,9 @@ def _convert_to_message( an instance of a message or a message template """ if isinstance(message, ChatMessage): - _message = _create_template_from_message_type(message.role.to_origin_role(), message.content) + _message = _create_template_from_message_type( + message.role.to_origin_role(), message.content + ) elif isinstance(message, PromptMessage): _message = message @@ -174,16 +200,16 @@ def _convert_to_message( if isinstance(message_type_str, str): _message = _create_template_from_message_type(message_type_str, template) else: - raise ValueError( - f"Expected message type string, got {message_type_str}" - ) + raise ValueError(f"Expected message type string, got {message_type_str}") else: raise NotImplementedError(f"Unsupported message type: {type(message)}") return _message -async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]: +async def _stream_openai_chat_completion( + response: Generator, +) -> AsyncGenerator[str, None]: request_id, model = None, None for chunk in response: if not isinstance(chunk, LLMResultChunk): @@ -194,27 +220,41 @@ async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[ model = chunk.model if request_id is None: request_id = "request_id" - yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content="")) + yield create_stream_chunk( + request_id, + model, + ChatCompletionMessage(role=Role.ASSISTANT, content=""), + ) new_token = chunk.delta.message.content if new_token: - delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()), - content=new_token, - tool_calls=chunk.delta.message.tool_calls) - yield create_stream_chunk(request_id=request_id, - model=model, delta=delta, - index=chunk.delta.index, - finish_reason=chunk.delta.finish_reason) + delta = ChatCompletionMessage( + role=Role.value_of(chunk.delta.message.role.to_origin_role()), + content=new_token, + tool_calls=chunk.delta.message.tool_calls, + ) + yield create_stream_chunk( + request_id=request_id, + model=model, + delta=delta, + index=chunk.delta.index, + finish_reason=chunk.delta.finish_reason, + ) - yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP) + yield create_stream_chunk( + request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP + ) yield "[DONE]" async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse: choice = ChatCompletionResponseChoice( - index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)), - finish_reason=Finish.STOP + index=0, + message=ChatCompletionMessage( + **_convert_prompt_message_to_dict(message=response.message) + ), + finish_reason=Finish.STOP, ) usage = UsageInfo( prompt_tokens=response.usage.prompt_tokens, @@ -335,7 +375,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -345,7 +385,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: 
ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -354,22 +394,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): model_instance = self._provider_manager.get_model_instance( provider=provider, model_type=ModelType.LLM, model=chat_request.model ) - prompt_messages = [_convert_to_message(message) for message in chat_request.messages] + prompt_messages = [ + _convert_to_message(message) for message in chat_request.messages + ] - tools = [PromptMessageTool(name=f.function.name, - description=f.function.description, - parameters=f.function.parameters - ) - - for f in chat_request.tools] + tools = [ + PromptMessageTool( + name=f.function.name, + description=f.function.description, + parameters=f.function.parameters, + ) + for f in chat_request.tools + ] if chat_request.functions: - tools.extend([PromptMessageTool(name=f.name, - description=f.description, - parameters=f.parameters - ) for f in chat_request.functions]) + tools.extend( + [ + PromptMessageTool( + name=f.name, description=f.description, parameters=f.parameters + ) + for f in chat_request.functions + ] + ) try: - response = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={**chat_request.to_model_parameters_dict()}, @@ -380,21 +427,25 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) if chat_request.stream: - - return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream") + return EventSourceResponse( + _stream_openai_chat_completion(response), + media_type="text/event-stream", + ) else: return await _openai_chat_completion(response) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except InvokeError as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: From a4017791200e92fe136565127b5388d1f4170a93 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 31 Mar 2024 18:53:45 +0800 Subject: [PATCH 7/9] =?UTF-8?q?xinference=20=E6=8F=92=E4=BB=B6=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model-providers/model_providers.yaml | 11 ++--- .../xinference/llm/_position.yaml | 1 + .../xinference/llm/chatglm3-6b.yaml | 43 +++++++++++++++++++ .../model_providers/core/provider_manager.py | 17 ++++---- 4 files changed, 58 insertions(+), 14 deletions(-) create mode 100644 model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/_position.yaml create mode 100644 model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/chatglm3-6b.yaml diff --git a/model-providers/model_providers.yaml b/model-providers/model_providers.yaml index eb96fba7..86171fb1 100644 --- a/model-providers/model_providers.yaml +++ b/model-providers/model_providers.yaml @@ -20,9 +20,10 @@ openai: xinference: model_credential: - - model: 'gpt-3.5-turbo' + - model: 'chatglm3-6b' model_type: 'llm' - credential: - openai_api_key: '' - openai_organization: '' - openai_api_base: '' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 
'gpt-3.5-turbo' + + diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/_position.yaml b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/_position.yaml new file mode 100644 index 00000000..36640c5e --- /dev/null +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/_position.yaml @@ -0,0 +1 @@ +- chatglm3-6b diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/chatglm3-6b.yaml b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/chatglm3-6b.yaml new file mode 100644 index 00000000..371b0126 --- /dev/null +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/chatglm3-6b.yaml @@ -0,0 +1,43 @@ +model: chatglm3-6b +label: + zh_Hans: chatglm3-6b + en_US: chatglm3-6b +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '0.001' + output: '0.002' + unit: '0.001' + currency: USD diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index 5a99e314..5dda1158 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -212,7 +212,7 @@ class ProviderManager: :return: """ # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( + provider_credential_secret_variables = self._extract_variables( provider_entity.provider_credential_schema.credential_form_schemas if provider_entity.provider_credential_schema else [] @@ -229,7 +229,7 @@ class ProviderManager: ) # Get provider model credential secret variables - model_credential_secret_variables = self._extract_secret_variables( + model_credential_variables = self._extract_variables( provider_entity.model_credential_schema.credential_form_schemas if provider_entity.model_credential_schema else [] @@ -242,7 +242,7 @@ class ProviderManager: continue provider_model_credentials = {} - for variable in model_credential_secret_variables: + for variable in model_credential_variables: if variable in provider_model_record.get("model_credentials"): try: provider_model_credentials[ @@ -265,18 +265,17 @@ class ProviderManager: provider=custom_provider_configuration, models=custom_model_configurations ) - def _extract_secret_variables( + def _extract_variables( self, credential_form_schemas: list[CredentialFormSchema] ) -> list[str]: """ - Extract secret input form variables. + Extract input form variables. 
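+        Unlike the former _extract_secret_variables, every declared form
+        variable is collected, not just SECRET_INPUT fields.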
:param credential_form_schemas: :return: """ - secret_input_form_variables = [] + input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type == FormType.SECRET_INPUT: - secret_input_form_variables.append(credential_form_schema.variable) + input_form_variables.append(credential_form_schema.variable) - return secret_input_form_variables + return input_form_variables From 3c4e8dadd669a70694ae095d9b1d8409c82252f2 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 31 Mar 2024 19:45:55 +0800 Subject: [PATCH 8/9] =?UTF-8?q?=E4=B8=80=E4=BA=9B=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model-providers/model_providers.yaml | 2 +- .../bootstrap_web/openai_bootstrap_web.py | 18 ++++++++++-------- .../core/bootstrap/openai_protocol.py | 6 +++--- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/model-providers/model_providers.yaml b/model-providers/model_providers.yaml index 86171fb1..d88736b3 100644 --- a/model-providers/model_providers.yaml +++ b/model-providers/model_providers.yaml @@ -24,6 +24,6 @@ xinference: model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'gpt-3.5-turbo' + model_uid: 'chatglm3-6b' diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 31b2cb77..9e20d7aa 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -398,14 +398,16 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): _convert_to_message(message) for message in chat_request.messages ] - tools = [ - PromptMessageTool( - name=f.function.name, - description=f.function.description, - parameters=f.function.parameters, - ) - for f in chat_request.tools - ] + tools = [] + if chat_request.tools: + tools = [ + PromptMessageTool( + name=f.function.name, + description=f.function.description, + parameters=f.function.parameters, + ) + for f in chat_request.tools + ] if chat_request.functions: tools.extend( [ diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index ec5ddc3f..2945c0ba 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -140,11 +140,11 @@ class ChatCompletionRequest(BaseModel): tools: Optional[List[FunctionAvailable]] = None functions: Optional[List[FunctionDefinition]] = None function_call: Optional[FunctionCallDefinition] = None - temperature: Optional[float] = None - top_p: Optional[float] = None + temperature: Optional[float] = 0.75 + top_p: Optional[float] = 0.75 top_k: Optional[float] = None n: int = 1 - max_tokens: Optional[int] = None + max_tokens: Optional[int] = 256 stop: Optional[list[str]] = None stream: Optional[bool] = False From a1fe8d714ff8ba37ab091798e1de9173596eac03 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 1 Apr 2024 20:09:12 +0800 Subject: [PATCH 9/9] =?UTF-8?q?provider=5Fconfiguration.py=20=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E6=89=80=E6=9C=89=E7=9A=84=E5=B9=B3=E5=8F=B0=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=EF=BC=8C=E5=8C=85=E5=90=AB=E8=AE=A1=E8=B4=B9=E7=AD=96?= =?UTF-8?q?=E7=95=A5=E5=92=8C=E9=85=8D=E7=BD=AEschema=5Fvalidators(?= 
=?UTF-8?q?=E5=8F=82=E6=95=B0=E5=BF=85=E5=A1=AB=E4=BF=A1=E6=81=AF=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C=E8=A7=84=E5=88=99)=20/workspaces/current/model-provid?= =?UTF-8?q?ers=20=E6=9F=A5=E8=AF=A2=E5=B9=B3=E5=8F=B0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=86=E7=B1=BB=E7=9A=84=E8=AF=A6=E7=BB=86=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=EF=BC=8C=E5=8C=85=E5=90=AB=E4=BA=86=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=B1=BB=E5=9E=8B=EF=BC=8C=E6=A8=A1=E5=9E=8B=E5=8F=82?= =?UTF-8?q?=E6=95=B0=EF=BC=8C=E6=A8=A1=E5=9E=8B=E7=8A=B6=E6=80=81=20worksp?= =?UTF-8?q?aces/current/models/model-types/{model=5Ftype}?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../bootstrap_web/entities/__init__.py | 0 .../entities/model_provider_entities.py | 183 ++++++++++++++++++ .../bootstrap_web/openai_bootstrap_web.py | 24 +++ .../model_providers/core/bootstrap/base.py | 163 +++++++++++++++- .../core/entities/provider_configuration.py | 10 + .../core/entities/provider_entities.py | 70 +++++++ .../model_providers/model_provider_factory.py | 44 +++-- .../model_providers/core/provider_manager.py | 6 +- 8 files changed, 474 insertions(+), 26 deletions(-) create mode 100644 model-providers/model_providers/bootstrap_web/entities/__init__.py create mode 100644 model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py diff --git a/model-providers/model_providers/bootstrap_web/entities/__init__.py b/model-providers/model_providers/bootstrap_web/entities/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py new file mode 100644 index 00000000..e7899a0c --- /dev/null +++ b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py @@ -0,0 +1,183 @@ +from enum import Enum +from typing import List, Literal, Optional + +from pydantic import BaseModel + +from model_providers.core.entities.model_entities import ( + ModelStatus, + ModelWithProviderEntity, +) +from model_providers.core.entities.provider_entities import ( + ProviderQuotaType, + ProviderType, + QuotaConfiguration, +) +from model_providers.core.model_runtime.entities.common_entities import I18nObject +from model_providers.core.model_runtime.entities.model_entities import ( + ModelType, + ProviderModel, +) +from model_providers.core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) + + +class CustomConfigurationStatus(Enum): + """ + Enum class for custom configuration status. + """ + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" + + +class CustomConfigurationResponse(BaseModel): + """ + Model class for provider custom configuration response. + """ + + status: CustomConfigurationStatus + + +class SystemConfigurationResponse(BaseModel): + """ + Model class for provider system configuration response. + """ + + enabled: bool + current_quota_type: Optional[ProviderQuotaType] = None + quota_configurations: list[QuotaConfiguration] = [] + + +class ProviderResponse(BaseModel): + """ + Model class for provider response. 
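+
+    Aggregates the provider's labels, credential schemas, and the status of
+    its custom and system configuration.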
+ """ + + provider: str + label: I18nObject + description: Optional[I18nObject] = None + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + background: Optional[str] = None + help: Optional[ProviderHelpEntity] = None + supported_model_types: list[ModelType] + configurate_methods: list[ConfigurateMethod] + provider_credential_schema: Optional[ProviderCredentialSchema] = None + model_credential_schema: Optional[ModelCredentialSchema] = None + preferred_provider_type: ProviderType + custom_configuration: CustomConfigurationResponse + system_configuration: SystemConfigurationResponse + + def __init__(self, **data) -> None: + super().__init__(**data) + # + # url_prefix = (current_app.config.get("CONSOLE_API_URL") + # + f"/console/api/workspaces/current/model-providers/{self.provider}") + # if self.icon_small is not None: + # self.icon_small = I18nObject( + # en_US=f"{url_prefix}/icon_small/en_US", + # zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + # ) + # + # if self.icon_large is not None: + # self.icon_large = I18nObject( + # en_US=f"{url_prefix}/icon_large/en_US", + # zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + # ) + + +class ProviderListResponse(BaseModel): + object: Literal["list"] = "list" + data: List[ProviderResponse] = [] + + +class ModelResponse(ProviderModel): + """ + Model class for model response. + """ + + status: ModelStatus + + +class ProviderWithModelsResponse(BaseModel): + """ + Model class for provider with models response. + """ + + provider: str + label: I18nObject + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + status: CustomConfigurationStatus + models: list[ModelResponse] + + def __init__(self, **data) -> None: + super().__init__(**data) + + # url_prefix = (current_app.config.get("CONSOLE_API_URL") + # + f"/console/api/workspaces/current/model-providers/{self.provider}") + # if self.icon_small is not None: + # self.icon_small = I18nObject( + # en_US=f"{url_prefix}/icon_small/en_US", + # zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + # ) + # + # if self.icon_large is not None: + # self.icon_large = I18nObject( + # en_US=f"{url_prefix}/icon_large/en_US", + # zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + # ) + + +class ProviderModelTypeResponse(BaseModel): + object: Literal["list"] = "list" + data: List[ProviderWithModelsResponse] = [] + + +class SimpleProviderEntityResponse(SimpleProviderEntity): + """ + Simple provider entity response. + """ + + def __init__(self, **data) -> None: + super().__init__(**data) + + # url_prefix = (current_app.config.get("CONSOLE_API_URL") + # + f"/console/api/workspaces/current/model-providers/{self.provider}") + # if self.icon_small is not None: + # self.icon_small = I18nObject( + # en_US=f"{url_prefix}/icon_small/en_US", + # zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + # ) + # + # if self.icon_large is not None: + # self.icon_large = I18nObject( + # en_US=f"{url_prefix}/icon_large/en_US", + # zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + # ) + + +class DefaultModelResponse(BaseModel): + """ + Default model entity. + """ + + model: str + model_type: ModelType + provider: SimpleProviderEntityResponse + + +class ModelWithProviderEntityResponse(ModelWithProviderEntity): + """ + Model with provider entity. 
+ """ + + provider: SimpleProviderEntityResponse + + def __init__(self, model: ModelWithProviderEntity) -> None: + super().__init__(**model.dict()) diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 9e20d7aa..398b0bc9 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -24,6 +24,10 @@ from sse_starlette import EventSourceResponse from uvicorn import Config, Server from model_providers.bootstrap_web.common import create_stream_chunk +from model_providers.bootstrap_web.entities.model_provider_entities import ( + ProviderListResponse, + ProviderModelTypeResponse, +) from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.openai_protocol import ( ChatCompletionMessage, @@ -301,6 +305,18 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): allow_headers=["*"], ) + self._router.add_api_route( + "/workspaces/current/model-providers", + self.workspaces_model_providers, + response_model=ProviderListResponse, + methods=["GET"], + ) + self._router.add_api_route( + "/workspaces/current/models/model-types/{model_type}", + self.workspaces_model_types, + response_model=ProviderModelTypeResponse, + methods=["GET"], + ) self._router.add_api_route( "/{provider}/v1/models", self.list_models, @@ -345,6 +361,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): if started_event is not None: started_event.set() + async def workspaces_model_providers(self, request: Request): + provider_list = self.get_provider_list(model_type=request.get("model_type")) + return ProviderListResponse(data=provider_list) + + async def workspaces_model_types(self, model_type: str, request: Request): + models_by_model_type = self.get_models_by_model_type(model_type=model_type) + return ProviderModelTypeResponse(data=models_by_model_type) + async def list_models(self, provider: str, request: Request): logger.info(f"Received list_models request for provider: {provider}") # 返回ModelType所有的枚举 diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index 8572323e..f74c5dd8 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -1,13 +1,25 @@ from abc import abstractmethod from collections import deque +from typing import List, Optional from fastapi import Request +from model_providers.bootstrap_web.entities.model_provider_entities import ( + CustomConfigurationResponse, + CustomConfigurationStatus, + ModelResponse, + ProviderResponse, + ProviderWithModelsResponse, + SystemConfigurationResponse, +) from model_providers.core.bootstrap.openai_protocol import ( ChatCompletionRequest, EmbeddingsRequest, ) +from model_providers.core.entities.model_entities import ModelStatus +from model_providers.core.entities.provider_entities import ProviderType from model_providers.core.model_manager import ModelManager +from model_providers.core.model_runtime.entities.model_entities import ModelType class Bootstrap: @@ -18,9 +30,150 @@ class Bootstrap: """任务队列""" _QUEUE: deque = deque() + _provider_manager: ModelManager + def __init__(self): self._version = "v0.0.1" + @property + def provider_manager(self) -> ModelManager: + return self._provider_manager + + @provider_manager.setter + def provider_manager(self, provider_manager: 
+        self._provider_manager = provider_manager
+
+    def get_provider_list(
+        self, model_type: Optional[str] = None
+    ) -> List[ProviderResponse]:
+        """
+        Get provider list.
+
+        :param model_type: model type
+        :return:
+        """
+        # Merge the keys of both credential dicts to cover every configured provider
+        provider = set(
+            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
+        )
+        provider.update(
+            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
+        )
+        # Get all provider configurations of the current workspace
+        provider_configurations = (
+            self.provider_manager.provider_manager.get_configurations(provider=provider)
+        )
+
+        provider_responses = []
+        for provider_configuration in provider_configurations.values():
+            if model_type:
+                model_type_entity = ModelType.value_of(model_type)
+                if (
+                    model_type_entity
+                    not in provider_configuration.provider.supported_model_types
+                ):
+                    continue
+
+            provider_response = ProviderResponse(
+                provider=provider_configuration.provider.provider,
+                label=provider_configuration.provider.label,
+                description=provider_configuration.provider.description,
+                icon_small=provider_configuration.provider.icon_small,
+                icon_large=provider_configuration.provider.icon_large,
+                background=provider_configuration.provider.background,
+                help=provider_configuration.provider.help,
+                supported_model_types=provider_configuration.provider.supported_model_types,
+                configurate_methods=provider_configuration.provider.configurate_methods,
+                provider_credential_schema=provider_configuration.provider.provider_credential_schema,
+                model_credential_schema=provider_configuration.provider.model_credential_schema,
+                preferred_provider_type=ProviderType.value_of("custom"),
+                custom_configuration=CustomConfigurationResponse(
+                    status=CustomConfigurationStatus.ACTIVE
+                    if provider_configuration.is_custom_configuration_available()
+                    else CustomConfigurationStatus.NO_CONFIGURE
+                ),
+                system_configuration=SystemConfigurationResponse(enabled=False),
+            )
+
+            provider_responses.append(provider_response)
+
+        return provider_responses
+
+    def get_models_by_model_type(
+        self, model_type: str
+    ) -> List[ProviderWithModelsResponse]:
+        """
+        Get models by model type.
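+        Models are grouped per provider and each entry carries its status.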
+
+
+        :param model_type: model type
+        :return:
+        """
+        # Merge the keys of both credential dicts to cover every configured provider
+        provider = set(
+            self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
+        )
+        provider.update(
+            self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
+        )
+        # Get all provider configurations of the current workspace
+        provider_configurations = (
+            self.provider_manager.provider_manager.get_configurations(provider=provider)
+        )
+
+        # Get provider available models
+        models = provider_configurations.get_models(
+            model_type=ModelType.value_of(model_type)
+        )
+
+        # Group models by provider
+        provider_models = {}
+        for model in models:
+            if model.provider.provider not in provider_models:
+                provider_models[model.provider.provider] = []
+
+            if model.deprecated:
+                continue
+
+            provider_models[model.provider.provider].append(model)
+
+        # Convert to a ProviderWithModelsResponse list
+        providers_with_models: list[ProviderWithModelsResponse] = []
+        for provider, models in provider_models.items():
+            if not models:
+                continue
+
+            first_model = models[0]
+
+            has_active_models = any(
+                [model.status == ModelStatus.ACTIVE for model in models]
+            )
+
+            providers_with_models.append(
+                ProviderWithModelsResponse(
+                    provider=provider,
+                    label=first_model.provider.label,
+                    icon_small=first_model.provider.icon_small,
+                    icon_large=first_model.provider.icon_large,
+                    status=CustomConfigurationStatus.ACTIVE
+                    if has_active_models
+                    else CustomConfigurationStatus.NO_CONFIGURE,
+                    models=[
+                        ModelResponse(
+                            model=model.model,
+                            label=model.label,
+                            model_type=model.model_type,
+                            features=model.features,
+                            fetch_from=model.fetch_from,
+                            model_properties=model.model_properties,
+                            status=model.status,
+                        )
+                        for model in models
+                    ],
+                )
+            )
+
+        return providers_with_models
+
     @classmethod
     @abstractmethod
     def from_config(cls, cfg=None):
@@ -44,19 +197,9 @@ class Bootstrap:


 class OpenAIBootstrapBaseWeb(Bootstrap):
-    _provider_manager: ModelManager
-
     def __init__(self):
         super().__init__()

-    @property
-    def provider_manager(self) -> ModelManager:
-        return self._provider_manager
-
-    @provider_manager.setter
-    def provider_manager(self, provider_manager: ModelManager):
-        self._provider_manager = provider_manager
-
     @abstractmethod
     async def list_models(self, provider: str, request: Request):
         pass
diff --git a/model-providers/model_providers/core/entities/provider_configuration.py b/model-providers/model_providers/core/entities/provider_configuration.py
index 947a2900..22d42587 100644
--- a/model-providers/model_providers/core/entities/provider_configuration.py
+++ b/model-providers/model_providers/core/entities/provider_configuration.py
@@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
         else:
             return None

+    def is_custom_configuration_available(self) -> bool:
+        """
+        Check whether a custom configuration is available.
+        :return:
+        """
+        return (
+            self.custom_configuration.provider is not None
+            or len(self.custom_configuration.models) > 0
+        )
+
     def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
         """
         Get custom credentials.
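For reviewers trying the two workspace routes registered above, here is a minimal sketch of a client session. The host and port are assumptions (the router mount prefix may also differ per deployment), and the field names rely on FastAPI's default serialization of the response models; both paths come straight from this patch:

```python
# Sketch only: assumes the OpenAI-compatible bootstrap server from this
# series is listening on 127.0.0.1:20000 with the router mounted at root.
import requests

BASE = "http://127.0.0.1:20000"  # assumed host/port

# Every configured provider, with credential schemas and configuration status.
providers = requests.get(f"{BASE}/workspaces/current/model-providers").json()
print([p["provider"] for p in providers["data"]])

# Providers and their models for one model type, e.g. "llm".
by_type = requests.get(f"{BASE}/workspaces/current/models/model-types/llm").json()
for entry in by_type["data"]:
    for model in entry["models"]:
        print(entry["provider"], model["model"], model["status"])
```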
diff --git a/model-providers/model_providers/core/entities/provider_entities.py b/model-providers/model_providers/core/entities/provider_entities.py index 0f6ebd49..ba4a3fb1 100644 --- a/model-providers/model_providers/core/entities/provider_entities.py +++ b/model-providers/model_providers/core/entities/provider_entities.py @@ -6,12 +6,82 @@ from pydantic import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelType +class ProviderType(Enum): + CUSTOM = "custom" + SYSTEM = "system" + + @staticmethod + def value_of(value): + for member in ProviderType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class ProviderQuotaType(Enum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value): + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class QuotaUnit(Enum): + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" + + +class SystemConfigurationStatus(Enum): + """ + Enum class for system configuration status. + """ + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" + + class RestrictModel(BaseModel): model: str base_model_name: Optional[str] = None model_type: ModelType +class QuotaConfiguration(BaseModel): + """ + Model class for provider quota configuration. + """ + + quota_type: ProviderQuotaType + quota_unit: QuotaUnit + quota_limit: int + quota_used: int + is_valid: bool + restrict_models: list[RestrictModel] = [] + + +class SystemConfiguration(BaseModel): + """ + Model class for provider system configuration. + """ + + enabled: bool + current_quota_type: Optional[ProviderQuotaType] = None + quota_configurations: list[QuotaConfiguration] = [] + credentials: Optional[dict] = None + + class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. 
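To make the new quota entities concrete, a small usage sketch follows; the quota numbers are invented for illustration:

```python
from model_providers.core.entities.provider_entities import (
    ProviderQuotaType,
    QuotaConfiguration,
    QuotaUnit,
    SystemConfiguration,
)

# Hypothetical hosted trial quota: 200k tokens granted, 50k already used.
trial = QuotaConfiguration(
    quota_type=ProviderQuotaType.TRIAL,
    quota_unit=QuotaUnit.TOKENS,
    quota_limit=200_000,
    quota_used=50_000,
    is_valid=True,
)

system_configuration = SystemConfiguration(
    enabled=True,
    current_quota_type=ProviderQuotaType.TRIAL,
    quota_configurations=[trial],
)

# value_of resolves a serialized value back to the enum member.
assert ProviderQuotaType.value_of("trial") is ProviderQuotaType.TRIAL
```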
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py index bb3c4c91..fbec3157 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py @@ -1,7 +1,7 @@ import importlib import logging import os -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel @@ -49,7 +49,9 @@ class ModelProviderFactory: if init_cache: self.get_providers() - def get_providers(self, provider_name: str = "") -> list[ProviderEntity]: + def get_providers( + self, provider_name: Union[str, set] = "" + ) -> list[ProviderEntity]: """ Get all providers :return: list of providers @@ -60,20 +62,36 @@ class ModelProviderFactory: # traverse all model_provider_extensions providers = [] for name, model_provider_extension in model_provider_extensions.items(): - if provider_name in (name, ""): - # get model_provider instance - model_provider_instance = model_provider_extension.provider_instance + if isinstance(provider_name, str): + if provider_name in (name, ""): + # get model_provider instance + model_provider_instance = model_provider_extension.provider_instance - # get provider schema - provider_schema = model_provider_instance.get_provider_schema() + # get provider schema + provider_schema = model_provider_instance.get_provider_schema() - for model_type in provider_schema.supported_model_types: - # get predefined models for given model type - models = model_provider_instance.models(model_type) - if models: - provider_schema.models.extend(models) + for model_type in provider_schema.supported_model_types: + # get predefined models for given model type + models = model_provider_instance.models(model_type) + if models: + provider_schema.models.extend(models) - providers.append(provider_schema) + providers.append(provider_schema) + elif isinstance(provider_name, set): + if name in provider_name: + # get model_provider instance + model_provider_instance = model_provider_extension.provider_instance + + # get provider schema + provider_schema = model_provider_instance.get_provider_schema() + + for model_type in provider_schema.supported_model_types: + # get predefined models for given model type + models = model_provider_instance.models(model_type) + if models: + provider_schema.models.extend(models) + + providers.append(provider_schema) # return providers return providers diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index 5dda1158..e8703e45 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional +from typing import Optional, Union from sqlalchemy.exc import IntegrityError @@ -45,7 +45,7 @@ class ProviderManager: provider_name_to_provider_model_records_dict ) - def get_configurations(self, provider: str) -> ProviderConfigurations: + def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations: """ Get model provider configurations. 
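With the widened signatures above, both call styles below are valid; this is a sketch, and `manager` is an assumed, already-initialized ProviderManager:

```python
# Sketch: get_configurations now accepts one provider name or a set of names.
single = manager.get_configurations(provider="openai")
several = manager.get_configurations(provider={"openai", "xinference"})
```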
@@ -155,7 +155,7 @@ class ProviderManager: default_model = {} # Get provider configurations - provider_configurations = self.get_configurations() + provider_configurations = self.get_configurations(provider="openai") # get available models from provider_configurations available_models = provider_configurations.get_models(