Use BootstrapWebBuilder to adapt the loading of RESTFulOpenAIBootstrapBaseWeb

glide-the committed 2024-03-29 18:25:16 +08:00
parent 3ed9162392
commit 032dc8f58d
4 changed files with 90 additions and 107 deletions

View File

@@ -1,3 +1,5 @@
+from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
+from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.model_manager import ModelManager
 from omegaconf import OmegaConf, DictConfig
@@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig):
     return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict

-model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
-provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
-    cfg=model_providers_cfg)
-# Model instances created from the configuration manager
-provider_manager = ModelManager(
-    provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
-    provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
-)
+class BootstrapWebBuilder:
+    """
+    A builder for creating model service instances
+    """
+
+    _model_providers_cfg_path: str
+    _host: str
+    _port: int
+
+    def model_providers_cfg_path(self, model_providers_cfg_path: str):
+        self._model_providers_cfg_path = model_providers_cfg_path
+        return self
+
+    def host(self, host: str):
+        self._host = host
+        return self
+
+    def port(self, port: int):
+        self._port = port
+        return self
+
+    def build(self) -> OpenAIBootstrapBaseWeb:
+        assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
+        # Load the configuration file
+        cfg = OmegaConf.load(self._model_providers_cfg_path)
+        # Convert the configuration into provider records
+        provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
+            cfg)
+        # Create the model manager
+        provider_manager = ModelManager(provider_name_to_provider_records_dict,
+                                        provider_name_to_provider_model_records_dict)
+        # Create the web service
+        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
+            "host": self._host,
+            "port": self._port
+        })
+        restful.provider_manager = provider_manager
+        return restful
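
Taken together, the new class gives a fluent entry point. A minimal usage sketch, assuming an import path for BootstrapWebBuilder (the edited file's name is not shown in this view) and placeholder config path, host, and port:

# NOTE: the import path is a guess; the file being edited is not named in this diff.
from model_providers.bootstrap_web.entrypoint import BootstrapWebBuilder

boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path("model_providers.yaml")  # placeholder path
    .host("127.0.0.1")                                 # placeholder host
    .port(20000)                                       # placeholder port
    .build()
)
# `boot` is a RESTFulOpenAIBootstrapBaseWeb with its provider_manager already wired in.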

View File

@@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import (
     FunctionAvailable,
     ModelList,
 )
+from model_providers.core.model_manager import ModelManager, ModelInstance
 from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.model_providers import model_provider_factory
-from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
-    LargeLanguageModel,
-)
 from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)

-async def create_stream_chat_completion(
-    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
-):
-    try:
-        response = model_type_instance.invoke(
-            model=chat_request.model,
-            credentials={
-                "openai_api_key": "sk-",
-                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-            },
-            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-            model_parameters={**chat_request.to_model_parameters_dict()},
-            stop=chat_request.stop,
-            stream=chat_request.stream,
-            user="abc-123",
-        )
-        return response
-    except Exception as e:
-        logger.exception(e)
-        raise HTTPException(status_code=500, detail=str(e))

 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     """
     Bootstrap Server Lifecycle
@@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         )
         self._router.add_api_route(
-            "/v1/models",
+            "/{provider}/v1/models",
             self.list_models,
             response_model=ModelList,
             methods=["GET"],
         )
         self._router.add_api_route(
-            "/v1/embeddings",
+            "/{provider}/v1/embeddings",
             self.create_embeddings,
             response_model=EmbeddingsResponse,
             status_code=status.HTTP_200_OK,
             methods=["POST"],
         )
         self._router.add_api_route(
-            "/v1/chat/completions",
+            "/{provider}/v1/chat/completions",
             self.create_chat_completion,
             response_model=ChatCompletionResponse,
             status_code=status.HTTP_200_OK,
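
With the provider name lifted into the path, a client chooses the backing provider per request. A hedged sketch of a non-streaming call (host, port, provider segment, and model are placeholders, not values from the commit):

import requests

# Illustrative values; the provider segment must match a provider configured
# in model_providers.yaml.
resp = requests.post(
    "http://127.0.0.1:20000/openai/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.json())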
@@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         if started_event is not None:
             started_event.set()

-    async def list_models(self, request: Request):
+    async def list_models(self, provider: str, request: Request):
         pass

     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        client = ZhipuAI(api_key=authorization)
-        # Check whether embeddings_request.input is a list
-        input = None
-        if isinstance(embeddings_request.input, list):
-            tokens = embeddings_request.input
-            try:
-                encoding = tiktoken.encoding_for_model(embeddings_request.model)
-            except KeyError:
-                logger.warning("Warning: model not found. Using cl100k_base encoding.")
-                model = "cl100k_base"
-                encoding = tiktoken.get_encoding(model)
-            for i, token in enumerate(tokens):
-                text = encoding.decode(token)
-                input += text
-        else:
-            input = embeddings_request.input
-        response = client.embeddings.create(
-            model=embeddings_request.model,
-            input=input,
-        )
+        response = None
         return EmbeddingsResponse(**dictify(response))

     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        model_provider_factory.get_providers(provider_name="openai")
-        provider_instance = model_provider_factory.get_provider_instance("openai")
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+        model_instance = self._provider_manager.get_model_instance(
+            provider=provider, model_type=ModelType.LLM, model=chat_request.model
+        )
+
         if chat_request.stream:
-            generator = create_stream_chat_completion(model_type_instance, chat_request)
-            return EventSourceResponse(generator, media_type="text/event-stream")
-        else:
-            response = model_type_instance.invoke(
-                model="gpt-4",
-                credentials={
-                    "openai_api_key": "sk-",
-                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-                },
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={
-                    "temperature": 0.7,
-                    "top_p": 1.0,
-                    "top_k": 1,
-                    "plugin_web_search": True,
-                },
-                stop=["you"],
-                stream=False,
+            # Invoke model
+            response = model_instance.invoke_llm(
+                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                stop=chat_request.stop,
+                stream=chat_request.stream,
+                user="abc-123",
+            )
+
+            return EventSourceResponse(response, media_type="text/event-stream")
+        else:
+            # Invoke model
+            response = model_instance.invoke_llm(
+                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                stop=chat_request.stop,
+                stream=chat_request.stream,
                 user="abc-123",
             )
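
Since the streaming branch now wraps the invoke_llm generator in an EventSourceResponse, a client consumes it as server-sent events. A minimal consumption sketch (URL, provider segment, and model are placeholders):

import requests

# Illustrative values; adjust host, port, provider, and model to your deployment.
with requests.post(
    "http://127.0.0.1:20000/openai/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # EventSourceResponse frames each chunk as an SSE "data:" line.
        if line.startswith(b"data:"):
            print(line[len(b"data:"):].strip().decode("utf-8"))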
@@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
 def run(
     cfg: Dict,
     logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
-        import signal
-
-        # Skip keyboard interrupts; rely on xoscar's signal handling
-        signal.signal(signal.SIGINT, lambda *_: None)
         api = RESTFulOpenAIBootstrapBaseWeb.from_config(
             cfg=cfg.get("run_openai_api", {})
         )

View File

@@ -3,9 +3,11 @@ from collections import deque
 from fastapi import Request

+from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
+from model_providers.core.model_manager import ModelManager

 class Bootstrap:
     """Maximum task queue"""

     _MAX_ONGOING_TASKS: int = 1
@@ -39,21 +41,31 @@ class Bootstrap:

 class OpenAIBootstrapBaseWeb(Bootstrap):
+    _provider_manager: ModelManager
+
     def __init__(self):
         super().__init__()

+    @property
+    def provider_manager(self) -> ModelManager:
+        return self._provider_manager
+
+    @provider_manager.setter
+    def provider_manager(self, provider_manager: ModelManager):
+        self._provider_manager = provider_manager
+
     @abstractmethod
-    async def list_models(self, request: Request):
+    async def list_models(self, provider: str, request: Request):
         pass

     @abstractmethod
     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         pass

     @abstractmethod
     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         pass
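
For orientation, a schematic sketch of what a concrete subclass now has to implement; the class name and return values below are invented for illustration and are not part of the commit:

from fastapi import Request

from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionRequest,
    EmbeddingsRequest,
)

class EchoBootstrapWeb(OpenAIBootstrapBaseWeb):
    # Hypothetical subclass: every handler receives the provider segment from the URL.
    async def list_models(self, provider: str, request: Request):
        return {"object": "list", "data": []}

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        return {"provider": provider, "model": embeddings_request.model}

    async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        return {"provider": provider, "model": chat_request.model}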

View File

@@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel):

     def to_model_parameters_dict(self, *args, **kwargs):
         # Call the parent class's dict method, excluding the tool- and message-related fields
-        helper.dump_model
         return super().dict(
             exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
         )
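
The exclusion pattern above is plain pydantic v1 behavior. A self-contained sketch of the same idea; the model and its fields below are invented for illustration:

from pydantic import BaseModel

class DemoRequest(BaseModel):
    # Invented fields for illustration only.
    model: str
    messages: list
    temperature: float = 0.7

    def to_model_parameters_dict(self):
        # Drop fields that are not model sampling parameters.
        return super().dict(exclude={"messages"})

req = DemoRequest(model="demo", messages=[{"role": "user", "content": "hi"}])
print(req.to_model_parameters_dict())  # {'model': 'demo', 'temperature': 0.7}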