make format

glide-the 2024-03-31 17:55:57 +08:00
parent 2f1c9bfd11
commit 056b15b99b
2 changed files with 111 additions and 55 deletions

View File: model_providers/bootstrap_web/common.py

@@ -2,8 +2,11 @@ import typing
 from subprocess import Popen
 from typing import Optional
 
-from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \
-    ChatCompletionStreamResponse, Finish
+from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionStreamResponse,
+    ChatCompletionStreamResponseChoice,
+    Finish,
+)
 from model_providers.core.utils.generic import jsonify
 
 if typing.TYPE_CHECKING:
@@ -11,12 +14,14 @@ if typing.TYPE_CHECKING:
 
 
 def create_stream_chunk(
     request_id: str,
     model: str,
     delta: "ChatCompletionMessage",
     index: Optional[int] = 0,
     finish_reason: Optional[Finish] = None,
 ) -> str:
-    choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+    choice = ChatCompletionStreamResponseChoice(
+        index=index, delta=delta, finish_reason=finish_reason
+    )
     chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
     return jsonify(chunk)
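
For reference, a minimal sketch of how the reformatted helper is called; the request id, model name, and content below are hypothetical, and the return value is one JSON-serialized ChatCompletionStreamResponse suitable for an SSE "data:" field:

# Hypothetical values; create_stream_chunk and the protocol types come from
# the modules shown in this commit.
chunk_json = create_stream_chunk(
    request_id="chatcmpl-123",  # hypothetical id
    model="my-model",           # hypothetical model name
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
    index=0,
    finish_reason=None,         # callers pass Finish.STOP on the final chunk
)
# chunk_json is a JSON string carrying id, model, and a single choices[0].delta.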

View File

@@ -5,7 +5,18 @@ import multiprocessing as mp
 import os
 import pprint
 import threading
-from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
@@ -15,21 +26,36 @@ from uvicorn import Config, Server
 
 from model_providers.bootstrap_web.common import create_stream_chunk
 from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseChoice,
     ChatCompletionStreamResponse,
+    ChatMessage,
     EmbeddingsRequest,
     EmbeddingsResponse,
+    Finish,
     FunctionAvailable,
     ModelCard,
-    ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo,
+    ModelList,
+    Role,
+    UsageInfo,
+)
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
 )
-from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from model_providers.core.model_runtime.entities.message_entities import (
-    UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage,
-    PromptMessageContent, PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent,
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
 )
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -81,7 +107,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
 
 
 def _create_template_from_message_type(
     message_type: str, template: Union[str, list]
 ) -> PromptMessage:
     """Create a message prompt template from a message type and template string.
@@ -102,9 +128,7 @@ def _create_template_from_message_type(
                 text: str = tmpl
             else:
                 text = cast(dict, tmpl)["text"]  # type: ignore[assignment]  # noqa: E501
-            content.append(
-                TextPromptMessageContent(data=text)
-            )
+            content.append(TextPromptMessageContent(data=text))
         elif isinstance(tmpl, dict) and "image_url" in tmpl:
             img_template = cast(dict, tmpl)["image_url"]
             if isinstance(img_template, str):
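
The branches above distinguish three template item shapes; a hedged illustration of inputs the helper accepts (all values below are invented):

# Illustrative only: each list item mirrors one branch in the hunk above.
template = [
    "Describe the image:",                         # bare str -> TextPromptMessageContent
    {"text": "Answer in one sentence."},           # {"text": ...} -> TextPromptMessageContent
    {"image_url": "https://example.com/cat.png"},  # {"image_url": ...} -> ImagePromptMessageContent
]
message = _create_template_from_message_type("user", template)  # "user" is an assumed role string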
@@ -142,7 +166,7 @@ def _create_template_from_message_type(
 
 
 def _convert_to_message(
     message: MessageLikeRepresentation,
 ) -> Union[PromptMessage]:
     """Instantiate a message from a variety of message formats.
@@ -161,7 +185,9 @@ def _convert_to_message(
         an instance of a message or a message template
     """
     if isinstance(message, ChatMessage):
-        _message = _create_template_from_message_type(message.role.to_origin_role(), message.content)
+        _message = _create_template_from_message_type(
+            message.role.to_origin_role(), message.content
+        )
     elif isinstance(message, PromptMessage):
         _message = message
@@ -174,16 +200,16 @@ def _convert_to_message(
         if isinstance(message_type_str, str):
             _message = _create_template_from_message_type(message_type_str, template)
         else:
-            raise ValueError(
-                f"Expected message type string, got {message_type_str}"
-            )
+            raise ValueError(f"Expected message type string, got {message_type_str}")
     else:
         raise NotImplementedError(f"Unsupported message type: {type(message)}")
 
     return _message
 
 
-async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]:
+async def _stream_openai_chat_completion(
+    response: Generator,
+) -> AsyncGenerator[str, None]:
     request_id, model = None, None
     for chunk in response:
         if not isinstance(chunk, LLMResultChunk):
@@ -194,27 +220,41 @@ async def _stream_openai_chat_completion(
             model = chunk.model
             if request_id is None:
                 request_id = "request_id"
-            yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content=""))
+            yield create_stream_chunk(
+                request_id,
+                model,
+                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
+            )
 
         new_token = chunk.delta.message.content
         if new_token:
-            delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()),
-                                          content=new_token,
-                                          tool_calls=chunk.delta.message.tool_calls)
-            yield create_stream_chunk(request_id=request_id,
-                                      model=model, delta=delta,
-                                      index=chunk.delta.index,
-                                      finish_reason=chunk.delta.finish_reason)
+            delta = ChatCompletionMessage(
+                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
+                content=new_token,
+                tool_calls=chunk.delta.message.tool_calls,
+            )
+            yield create_stream_chunk(
+                request_id=request_id,
+                model=model,
+                delta=delta,
+                index=chunk.delta.index,
+                finish_reason=chunk.delta.finish_reason,
+            )
 
-    yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP)
+    yield create_stream_chunk(
+        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
+    )
     yield "[DONE]"
 
 
 async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
     choice = ChatCompletionResponseChoice(
-        index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)),
-        finish_reason=Finish.STOP
+        index=0,
+        message=ChatCompletionMessage(
+            **_convert_prompt_message_to_dict(message=response.message)
+        ),
+        finish_reason=Finish.STOP,
     )
     usage = UsageInfo(
         prompt_tokens=response.usage.prompt_tokens,
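
Reading the reformatted generator top to bottom, a consumer sees a fixed event sequence; a sketch of draining it (the wrapper below is an assumption, not part of the commit):

import asyncio

async def drain(sync_llm_chunks):
    # Event order implied by the generator above:
    #   1. a priming chunk with role=assistant and empty content,
    #   2. one chunk per non-empty delta token,
    #   3. a closing chunk with finish_reason=stop,
    #   4. the literal sentinel "[DONE]".
    async for event in _stream_openai_chat_completion(sync_llm_chunks):
        print(event)

# asyncio.run(drain(response))  # where response is a Generator of LLMResultChunk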
@@ -335,7 +375,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return ModelList(data=models_list)
 
     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@@ -345,7 +385,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return EmbeddingsResponse(**dictify(response))
 
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -354,22 +394,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         model_instance = self._provider_manager.get_model_instance(
             provider=provider, model_type=ModelType.LLM, model=chat_request.model
         )
-        prompt_messages = [_convert_to_message(message) for message in chat_request.messages]
+        prompt_messages = [
+            _convert_to_message(message) for message in chat_request.messages
+        ]
 
-        tools = [PromptMessageTool(name=f.function.name,
-                                   description=f.function.description,
-                                   parameters=f.function.parameters
-                                   )
-                 for f in chat_request.tools]
+        tools = [
+            PromptMessageTool(
+                name=f.function.name,
+                description=f.function.description,
+                parameters=f.function.parameters,
+            )
+            for f in chat_request.tools
+        ]
         if chat_request.functions:
-            tools.extend([PromptMessageTool(name=f.name,
-                                            description=f.description,
-                                            parameters=f.parameters
-                                            ) for f in chat_request.functions])
+            tools.extend(
+                [
+                    PromptMessageTool(
+                        name=f.name, description=f.description, parameters=f.parameters
+                    )
+                    for f in chat_request.functions
+                ]
+            )
         try:
             response = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters={**chat_request.to_model_parameters_dict()},
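
Both the OpenAI-style tools array and the legacy functions field are normalized into the same PromptMessageTool shape; an illustrative request fragment (all field values are invented):

# Hypothetical payload fragment; f.function.name / f.name etc. map as above.
request_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }
]
# After conversion each entry becomes:
# PromptMessageTool(name="get_weather", description="...", parameters={...})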
@@ -380,21 +427,25 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
             )
 
             if chat_request.stream:
-                return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream")
+                return EventSourceResponse(
+                    _stream_openai_chat_completion(response),
+                    media_type="text/event-stream",
+                )
             else:
                 return await _openai_chat_completion(response)
         except ValueError as e:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
         except InvokeError as e:
-            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+            )
 
 
 def run(
     cfg: Dict,
     logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
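
A hedged sketch of driving run() from a parent process, inferred only from the signature above; the cfg keys shown are hypothetical:

import multiprocessing as mp

if __name__ == "__main__":
    started = mp.Event()
    proc = mp.Process(
        target=run,
        kwargs={
            "cfg": {"host": "127.0.0.1", "port": 20000},  # hypothetical keys
            "logging_conf": {"version": 1},               # minimal dictConfig payload
            "started_event": started,
        },
    )
    proc.start()
    started.wait()  # presumably set by run() once the server is up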