diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py index 04e8fe2b..ebbdf71e 100644 --- a/model-providers/model_providers/__init__.py +++ b/model-providers/model_providers/__init__.py @@ -1,3 +1,5 @@ +from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb +from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.model_manager import ModelManager from omegaconf import OmegaConf, DictConfig @@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig): return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict -model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml") -provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( - cfg=model_providers_cfg) -# 基于配置管理器创建的模型实例 -provider_manager = ModelManager( - provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, - provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict -) +class BootstrapWebBuilder: + """ + 创建一个模型实例创建工具 + """ + _model_providers_cfg_path: str + _host: str + _port: int + + def model_providers_cfg_path(self, model_providers_cfg_path: str): + self._model_providers_cfg_path = model_providers_cfg_path + return self + + def host(self, host: str): + self._host = host + return self + + def port(self, port: int): + self._port = port + return self + + def build(self) -> OpenAIBootstrapBaseWeb: + assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None + # 读取配置文件 + cfg = OmegaConf.load(self._model_providers_cfg_path) + # 转换配置文件 + provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( + cfg) + # 创建模型管理器 + provider_manager = ModelManager(provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict) + # 创建web服务 + restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={ + "host": self._host, + "port": self._port + + }) + restful.provider_manager = provider_manager + return restful diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 56adcce9..b1032861 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import ( FunctionAvailable, ModelList, ) +from model_providers.core.model_manager import ModelManager, ModelInstance from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.model_providers import model_provider_factory -from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( - LargeLanguageModel, -) from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) -async def create_stream_chat_completion( - model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest -): - try: - response = model_type_instance.invoke( - model=chat_request.model, - credentials={ - "openai_api_key": "sk-", - "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), - "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), - }, - prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={**chat_request.to_model_parameters_dict()}, - stop=chat_request.stop, - stream=chat_request.stream, - user="abc-123", - ) - return response - - except Exception as e: - logger.exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): """ Bootstrap Server Lifecycle @@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) self._router.add_api_route( - "/v1/models", + "/{provider}/v1/models", self.list_models, response_model=ModelList, methods=["GET"], ) self._router.add_api_route( - "/v1/embeddings", + "/{provider}/v1/embeddings", self.create_embeddings, response_model=EmbeddingsResponse, status_code=status.HTTP_200_OK, methods=["POST"], ) self._router.add_api_route( - "/v1/chat/completions", + "/{provider}/v1/chat/completions", self.create_chat_completion, response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK, @@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): if started_event is not None: started_event.set() - async def list_models(self, request: Request): + async def list_models(self, provider: str, request: Request): pass async def create_embeddings( - self, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" ) - if os.environ["API_KEY"] is None: - authorization = request.headers.get("Authorization") - authorization = authorization.split("Bearer ")[-1] - else: - authorization = os.environ["API_KEY"] - client = ZhipuAI(api_key=authorization) - # 判断embeddings_request.input是否为list - input = None - if isinstance(embeddings_request.input, list): - tokens = embeddings_request.input - try: - encoding = tiktoken.encoding_for_model(embeddings_request.model) - except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") - model = "cl100k_base" - encoding = tiktoken.get_encoding(model) - for i, token in enumerate(tokens): - text = encoding.decode(token) - input += text - else: - input = embeddings_request.input - - response = client.embeddings.create( - model=embeddings_request.model, - input=input, - ) + response = None return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" ) - if os.environ["API_KEY"] is None: - authorization = request.headers.get("Authorization") - authorization = authorization.split("Bearer ")[-1] - else: - authorization = os.environ["API_KEY"] - model_provider_factory.get_providers(provider_name="openai") - provider_instance = model_provider_factory.get_provider_instance("openai") - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + model_instance = self._provider_manager.get_model_instance( + provider=provider, model_type=ModelType.LLM, model=chat_request.model + ) if chat_request.stream: - generator = create_stream_chat_completion(model_type_instance, chat_request) - return EventSourceResponse(generator, media_type="text/event-stream") - else: - response = model_type_instance.invoke( - model="gpt-4", - credentials={ - "openai_api_key": "sk-", - "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), - "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), - }, + # Invoke model + + response = model_instance.invoke_llm( prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], - model_parameters={ - "temperature": 0.7, - "top_p": 1.0, - "top_k": 1, - "plugin_web_search": True, - }, - stop=["you"], - stream=False, + model_parameters={**chat_request.to_model_parameters_dict()}, + stop=chat_request.stop, + stream=chat_request.stream, + user="abc-123", + ) + + return EventSourceResponse(response, media_type="text/event-stream") + else: + # Invoke model + + response = model_instance.invoke_llm( + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={**chat_request.to_model_parameters_dict()}, + stop=chat_request.stop, + stream=chat_request.stream, user="abc-123", ) @@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: - import signal - - # 跳过键盘中断,使用xoscar的信号处理 - signal.signal(signal.SIGINT, lambda *_: None) api = RESTFulOpenAIBootstrapBaseWeb.from_config( cfg=cfg.get("run_openai_api", {}) ) diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index b2da1a0b..795eb565 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -3,9 +3,11 @@ from collections import deque from fastapi import Request +from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest +from model_providers.core.model_manager import ModelManager + class Bootstrap: - """最大的任务队列""" _MAX_ONGOING_TASKS: int = 1 @@ -39,21 +41,31 @@ class Bootstrap: class OpenAIBootstrapBaseWeb(Bootstrap): + _provider_manager: ModelManager + def __init__(self): super().__init__() + @property + def provider_manager(self) -> ModelManager: + return self._provider_manager + + @provider_manager.setter + def provider_manager(self, provider_manager: ModelManager): + self._provider_manager = provider_manager + @abstractmethod - async def list_models(self, request: Request): + async def list_models(self, provider: str, request: Request): pass @abstractmethod async def create_embeddings( - self, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): pass @abstractmethod async def create_chat_completion( - self, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): pass diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 1d7354cf..30d57c81 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel): def to_model_parameters_dict(self, *args, **kwargs): # 调用父类的to_dict方法,并排除tools字段 - helper.dump_model + return super().dict( exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs )