使用BootstrapWebBuilder适配RESTFulOpenAIBootstrapBaseWeb加载

This commit is contained in:
glide-the 2024-03-29 18:25:16 +08:00
parent 3ed9162392
commit 032dc8f58d
4 changed files with 90 additions and 107 deletions

View File

@ -1,3 +1,5 @@
from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.model_manager import ModelManager
from omegaconf import OmegaConf, DictConfig
@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig):
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
cfg=model_providers_cfg)
# 基于配置管理器创建的模型实例
provider_manager = ModelManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
)
class BootstrapWebBuilder:
"""
创建一个模型实例创建工具
"""
_model_providers_cfg_path: str
_host: str
_port: int
def model_providers_cfg_path(self, model_providers_cfg_path: str):
self._model_providers_cfg_path = model_providers_cfg_path
return self
def host(self, host: str):
self._host = host
return self
def port(self, port: int):
self._port = port
return self
def build(self) -> OpenAIBootstrapBaseWeb:
assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
# 读取配置文件
cfg = OmegaConf.load(self._model_providers_cfg_path)
# 转换配置文件
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
cfg)
# 创建模型管理器
provider_manager = ModelManager(provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict)
# 创建web服务
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
"host": self._host,
"port": self._port
})
restful.provider_manager = provider_manager
return restful

View File

@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import (
FunctionAvailable,
ModelList,
)
from model_providers.core.model_manager import ModelManager, ModelInstance
from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__)
async def create_stream_chat_completion(
model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
):
try:
response = model_type_instance.invoke(
model=chat_request.model,
credentials={
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={**chat_request.to_model_parameters_dict()},
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
)
return response
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e))
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
"""
Bootstrap Server Lifecycle
@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
self._router.add_api_route(
"/v1/models",
"/{provider}/v1/models",
self.list_models,
response_model=ModelList,
methods=["GET"],
)
self._router.add_api_route(
"/v1/embeddings",
"/{provider}/v1/embeddings",
self.create_embeddings,
response_model=EmbeddingsResponse,
status_code=status.HTTP_200_OK,
methods=["POST"],
)
self._router.add_api_route(
"/v1/chat/completions",
"/{provider}/v1/chat/completions",
self.create_chat_completion,
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
if started_event is not None:
started_event.set()
async def list_models(self, request: Request):
async def list_models(self, provider: str, request: Request):
pass
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
else:
authorization = os.environ["API_KEY"]
client = ZhipuAI(api_key=authorization)
# 判断embeddings_request.input是否为list
input = None
if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input
try:
encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens):
text = encoding.decode(token)
input += text
else:
input = embeddings_request.input
response = client.embeddings.create(
model=embeddings_request.model,
input=input,
)
response = None
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
else:
authorization = os.environ["API_KEY"]
model_provider_factory.get_providers(provider_name="openai")
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_instance = self._provider_manager.get_model_instance(
provider=provider, model_type=ModelType.LLM, model=chat_request.model
)
if chat_request.stream:
generator = create_stream_chat_completion(model_type_instance, chat_request)
return EventSourceResponse(generator, media_type="text/event-stream")
else:
response = model_type_instance.invoke(
model="gpt-4",
credentials={
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
# Invoke model
response = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=["you"],
stream=False,
model_parameters={**chat_request.to_model_parameters_dict()},
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
)
return EventSourceResponse(response, media_type="text/event-stream")
else:
# Invoke model
response = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={**chat_request.to_model_parameters_dict()},
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
)
@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:
import signal
# 跳过键盘中断使用xoscar的信号处理
signal.signal(signal.SIGINT, lambda *_: None)
api = RESTFulOpenAIBootstrapBaseWeb.from_config(
cfg=cfg.get("run_openai_api", {})
)

View File

@ -3,9 +3,11 @@ from collections import deque
from fastapi import Request
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
from model_providers.core.model_manager import ModelManager
class Bootstrap:
"""最大的任务队列"""
_MAX_ONGOING_TASKS: int = 1
@ -39,21 +41,31 @@ class Bootstrap:
class OpenAIBootstrapBaseWeb(Bootstrap):
_provider_manager: ModelManager
def __init__(self):
super().__init__()
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
@abstractmethod
async def list_models(self, request: Request):
async def list_models(self, provider: str, request: Request):
pass
@abstractmethod
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
pass
@abstractmethod
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
pass

View File

@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel):
def to_model_parameters_dict(self, *args, **kwargs):
# 调用父类的to_dict方法并排除tools字段
helper.dump_model
return super().dict(
exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
)