Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-06 23:15:53 +08:00)
Adapt the chat_completions API request/response payloads
This commit is contained in:
parent a2df71d9ea
commit 2f1c9bfd11
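In practice the adapted endpoint accepts standard OpenAI-style chat completion payloads. A minimal client-side sketch follows; the host, port, route layout, and model name are assumptions and not taken from this commit, only the field names mirror ChatCompletionRequest in the diff below:

import requests

payload = {
    "model": "my-model",  # placeholder model id
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather like in Beijing today?"},
    ],
    "stream": False,
    "stop": None,
    "max_tokens": 512,
}
# Assumed URL: create_chat_completion() takes a provider path segment,
# but the exact route prefix is not shown in this diff.
resp = requests.post("http://127.0.0.1:20000/openai/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])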
model-providers/model_providers/bootstrap_web/common.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import typing
+from subprocess import Popen
+from typing import Optional
+
+from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \
+    ChatCompletionStreamResponse, Finish
+from model_providers.core.utils.generic import jsonify
+
+if typing.TYPE_CHECKING:
+    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
+
+
+def create_stream_chunk(
+    request_id: str,
+    model: str,
+    delta: "ChatCompletionMessage",
+    index: Optional[int] = 0,
+    finish_reason: Optional[Finish] = None,
+) -> str:
+    choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
+    return jsonify(chunk)
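As a rough usage sketch (not part of this commit), each streamed SSE event body comes from a call like the following; the request id and model name are placeholders, and ChatCompletionMessage/Role are the classes from openai_protocol used by the streaming handler later in this diff:

from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage, Role

chunk_json = create_stream_chunk(
    request_id="request_id",  # placeholder id, mirroring the handler below
    model="my-model",         # placeholder model name
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
)
# chunk_json is the JSON string for one ChatCompletionStreamResponse chunk,
# suitable as the data field of a server-sent event.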
@@ -5,14 +5,14 @@ import multiprocessing as mp
 import os
 import pprint
 import threading
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator
 
-import tiktoken
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette import EventSourceResponse
 from uvicorn import Config, Server
 
+from model_providers.bootstrap_web.common import create_stream_chunk
 from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.bootstrap.openai_protocol import (
     ChatCompletionRequest,
@@ -22,20 +22,212 @@ from model_providers.core.bootstrap.openai_protocol import (
     EmbeddingsResponse,
     FunctionAvailable,
     ModelCard,
-    ModelList,
+    ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo,
 )
-from model_providers.core.model_manager import ModelInstance, ModelManager
+from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from model_providers.core.model_runtime.entities.message_entities import (
-    UserPromptMessage,
+    UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage,
+    PromptMessageContent, PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent,
+    PromptMessageTool,
 )
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
     ModelType,
 )
+from model_providers.core.model_runtime.errors.invoke import InvokeError
 from model_providers.core.utils.generic import dictify, jsonify
 
 logger = logging.getLogger(__name__)
 
 
+MessageLike = Union[ChatMessage, PromptMessage]
+
+MessageLikeRepresentation = Union[
+    MessageLike,
+    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
+    str,
+]
+
+
+def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    """
+    Convert PromptMessage to dict for OpenAI Compatibility API
+    """
+    if isinstance(message, UserPromptMessage):
+        message = cast(UserPromptMessage, message)
+        if isinstance(message.content, str):
+            message_dict = {"role": "user", "content": message.content}
+        else:
+            raise ValueError("User message content must be str")
+    elif isinstance(message, AssistantPromptMessage):
+        message = cast(AssistantPromptMessage, message)
+        message_dict = {"role": "assistant", "content": message.content}
+        if message.tool_calls and len(message.tool_calls) > 0:
+            message_dict["function_call"] = {
+                "name": message.tool_calls[0].function.name,
+                "arguments": message.tool_calls[0].function.arguments,
+            }
+    elif isinstance(message, SystemPromptMessage):
+        message = cast(SystemPromptMessage, message)
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolPromptMessage):
+        # check if last message is user message
+        message = cast(ToolPromptMessage, message)
+        message_dict = {"role": "function", "content": message.content}
+    else:
+        raise ValueError(f"Unknown message type {type(message)}")
+
+    return message_dict
+
+
+def _create_template_from_message_type(
+    message_type: str, template: Union[str, list]
+) -> PromptMessage:
+    """Create a message prompt template from a message type and template string.
+
+    Args:
+        message_type: str the type of the message template (e.g., "human", "ai", etc.)
+        template: str the template string.
+
+    Returns:
+        a message prompt template of the appropriate type.
+    """
+    if isinstance(template, str):
+        content = template
+    elif isinstance(template, list):
+        content = []
+        for tmpl in template:
+            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
+                if isinstance(tmpl, str):
+                    text: str = tmpl
+                else:
+                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment]  # noqa: E501
+                content.append(
+                    TextPromptMessageContent(data=text)
+                )
+            elif isinstance(tmpl, dict) and "image_url" in tmpl:
+                img_template = cast(dict, tmpl)["image_url"]
+                if isinstance(img_template, str):
+                    img_template_obj = ImagePromptMessageContent(data=img_template)
+                elif isinstance(img_template, dict):
+                    img_template = dict(img_template)
+                    if "url" in img_template:
+                        url = img_template["url"]
+                    else:
+                        url = None
+                    img_template_obj = ImagePromptMessageContent(data=url)
+                else:
+                    raise ValueError()
+                content.append(img_template_obj)
+            else:
+                raise ValueError()
+    else:
+        raise ValueError()
+
+    if message_type in ("human", "user"):
+        _message = UserPromptMessage(content=content)
+    elif message_type in ("ai", "assistant"):
+        _message = AssistantPromptMessage(content=content)
+    elif message_type == "system":
+        _message = SystemPromptMessage(content=content)
+    elif message_type in ("function", "tool"):
+        _message = ToolPromptMessage(content=content)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
+        )
+
+    return _message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> Union[PromptMessage]:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - 2-tuple of (message class, template)
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message or a message template
+    """
+    if isinstance(message, ChatMessage):
+        _message = _create_template_from_message_type(message.role.to_origin_role(), message.content)
+    elif isinstance(message, PromptMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_template_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        if isinstance(message_type_str, str):
+            _message = _create_template_from_message_type(message_type_str, template)
+        else:
+            raise ValueError(
+                f"Expected message type string, got {message_type_str}"
+            )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]:
+    request_id, model = None, None
+    for chunk in response:
+        if not isinstance(chunk, LLMResultChunk):
+            yield "[ERROR]"
+            return
+
+        if model is None:
+            model = chunk.model
+        if request_id is None:
+            request_id = "request_id"
+            yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content=""))
+
+        new_token = chunk.delta.message.content
+
+        if new_token:
+            delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()),
+                                          content=new_token,
+                                          tool_calls=chunk.delta.message.tool_calls)
+            yield create_stream_chunk(request_id=request_id,
+                                      model=model, delta=delta,
+                                      index=chunk.delta.index,
+                                      finish_reason=chunk.delta.finish_reason)
+
+    yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP)
+    yield "[DONE]"
+
+
+async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
+    choice = ChatCompletionResponseChoice(
+        index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)),
+        finish_reason=Finish.STOP
+    )
+    usage = UsageInfo(
+        prompt_tokens=response.usage.prompt_tokens,
+        completion_tokens=response.usage.completion_tokens,
+        total_tokens=response.usage.total_tokens,
+    )
+    return ChatCompletionResponse(
+        id="request_id",
+        model=response.model,
+        choices=[choice],
+        usage=usage,
+    )
+
+
 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     """
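For reference, the stream produced by _stream_openai_chat_completion follows the usual OpenAI SSE shape; a minimal sketch of consuming it from the same module (chunk values are illustrative):

# Event order emitted for one completion:
#   1. an empty assistant "role" chunk (delta.role="assistant", delta.content="")
#   2. one chunk per new token, carrying delta.content (and tool_calls if present)
#   3. a final chunk with an empty delta and finish_reason="stop"
#   4. the literal sentinel string "[DONE]"
import asyncio

async def dump_stream(llm_chunks):
    # llm_chunks: the generator of LLMResultChunk objects returned by
    # model_instance.invoke_llm(stream=True)
    async for event in _stream_openai_chat_completion(llm_chunks):
        print(event)

# asyncio.run(dump_stream(response))  # response obtained as in create_chat_completion below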
@@ -143,7 +335,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return ModelList(data=models_list)
 
     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@@ -153,7 +345,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return EmbeddingsResponse(**dictify(response))
 
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -162,38 +354,47 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         model_instance = self._provider_manager.get_model_instance(
             provider=provider, model_type=ModelType.LLM, model=chat_request.model
         )
-        if chat_request.stream:
-            # Invoke model
-            response = model_instance.invoke_llm(
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={**chat_request.to_model_parameters_dict()},
-                stop=chat_request.stop,
-                stream=chat_request.stream,
-                user="abc-123",
-            )
-
-            return EventSourceResponse(response, media_type="text/event-stream")
-        else:
-            # Invoke model
-            response = model_instance.invoke_llm(
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={**chat_request.to_model_parameters_dict()},
-                stop=chat_request.stop,
-                stream=chat_request.stream,
-                user="abc-123",
-            )
-
-            chat_response = ChatCompletionResponse(**dictify(response))
-            return chat_response
+        prompt_messages = [_convert_to_message(message) for message in chat_request.messages]
+
+        tools = [PromptMessageTool(name=f.function.name,
+                                   description=f.function.description,
+                                   parameters=f.function.parameters
+                                   )
+                 for f in chat_request.tools]
+        if chat_request.functions:
+            tools.extend([PromptMessageTool(name=f.name,
+                                            description=f.description,
+                                            parameters=f.parameters
+                                            ) for f in chat_request.functions])
+
+        try:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                tools=tools,
+                stop=chat_request.stop,
+                stream=chat_request.stream,
+                user="abc-123",
+            )
+
+            if chat_request.stream:
+                return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream")
+            else:
+                return await _openai_chat_completion(response)
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except InvokeError as e:
+            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
 
 
 def run(
     cfg: Dict,
     logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
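For the non-streaming path, _openai_chat_completion above assembles a body of roughly this shape (a sketch only; the ids, model name, and token counts are placeholders, and the exact schema is defined by ChatCompletionResponse in openai_protocol):

example_response = {
    "id": "request_id",
    "model": "my-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "It is sunny in Beijing today."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 9, "total_tokens": 21},
}
# ValueError raised while building the request maps to HTTP 400, and InvokeError
# from the provider maps to HTTP 500, per the except clauses above.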
@@ -13,12 +13,62 @@ class Role(str, Enum):
     FUNCTION = "function"
     TOOL = "tool"
 
+    @classmethod
+    def value_of(cls, origin_role: str) -> "Role":
+        if origin_role == "user":
+            return cls.USER
+        elif origin_role == "assistant":
+            return cls.ASSISTANT
+        elif origin_role == "system":
+            return cls.SYSTEM
+        elif origin_role == "function":
+            return cls.FUNCTION
+        elif origin_role == "tool":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin role {origin_role}")
+
+    def to_origin_role(self) -> str:
+        if self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.SYSTEM:
+            return "system"
+        elif self == self.FUNCTION:
+            return "function"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
+
 
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
     TOOL = "tool_calls"
 
+    @classmethod
+    def value_of(cls, origin_finish: str) -> "Finish":
+        if origin_finish == "stop":
+            return cls.STOP
+        elif origin_finish == "length":
+            return cls.LENGTH
+        elif origin_finish == "tool_calls":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin finish {origin_finish}")
+
+    def to_origin_finish(self) -> str:
+        if self == self.STOP:
+            return "stop"
+        elif self == self.LENGTH:
+            return "length"
+        elif self == self.TOOL:
+            return "tool_calls"
+        else:
+            raise ValueError(f"invalid finish {self}")
+
 
 class ModelCard(BaseModel):
     id: str
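These helpers round-trip between the wire-format role/finish strings and the enum members; a quick sketch using the methods added above:

assert Role.value_of("assistant") is Role.ASSISTANT
assert Role.ASSISTANT.to_origin_role() == "assistant"
assert Finish.value_of("tool_calls") is Finish.TOOL
assert Finish.TOOL.to_origin_finish() == "tool_calls"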
@@ -95,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[list[str]] = (None,)
+    stop: Optional[list[str]] = None
     stream: Optional[bool] = False
 
     def to_model_parameters_dict(self, *args, **kwargs):
@@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
                 return mode
         raise ValueError(f"invalid prompt message type value {value}")
 
+    def to_origin_role(self) -> str:
+        """
+        Get origin role from prompt message role.
+
+        :return: origin role
+        """
+        if self == self.SYSTEM:
+            return "system"
+        elif self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
+
 
 class PromptMessageTool(BaseModel):
     """