import asyncio
import json
import logging
import logging.config
import multiprocessing as mp
import os
import pprint
import threading
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

import tiktoken
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server

from model_providers.bootstrap_web.entities.model_provider_entities import (
    ProviderListResponse,
    ProviderModelTypeResponse,
)
from model_providers.bootstrap_web.message_convert import (
    convert_to_message,
    openai_chat_completion,
    openai_embedding_text,
    stream_openai_chat_completion,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ModelCard,
    ModelList,
)
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.model_runtime.entities.message_entities import (
    PromptMessageTool,
)
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
from model_providers.core.utils.generic import dictify

logger = logging.getLogger(__name__)


class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
    Bootstrap Server Lifecycle
    """

    def __init__(self, host: str, port: int):
        super().__init__()
        self._host = host
        self._port = port
        self._router = APIRouter()
        self._app = FastAPI()
        self._logging_conf = None
        self._server = None
        self._server_thread = None

    def logging_conf(self, logging_conf: Optional[dict] = None):
        self._logging_conf = logging_conf

    @classmethod
    def from_config(cls, cfg=None):
        host = cfg.get("host", "127.0.0.1")
        port = cfg.get("port", 20000)

        logger.info(
            f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}"
        )
        return cls(host=host, port=port)
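
    # Illustrative note (not in the original file): `from_config` expects a plain
    # mapping such as {"host": "0.0.0.0", "port": 20000}; missing keys fall back to
    # the defaults 127.0.0.1 and 20000 used above.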

    def run(self):
        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._router.add_api_route(
            "/workspaces/current/model-providers",
            self.workspaces_model_providers,
            response_model=ProviderListResponse,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/workspaces/current/models/model-types/{model_type}",
            self.workspaces_model_types,
            response_model=ProviderModelTypeResponse,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/{provider}/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )

        self._router.add_api_route(
            "/{provider}/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/{provider}/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )

        self._app.include_router(self._router)

        config = Config(
            app=self._app,
            host=self._host,
            port=self._port,
            log_config=self._logging_conf,
        )
        self._server = Server(config)

        def run_server():
            self._server.shutdown_timeout = 2  # set to 2 seconds

            self._server.run()

        self._server_thread = threading.Thread(target=run_server)
        self._server_thread.start()
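
    # Illustrative summary (not in the original file): the routes above expose an
    # OpenAI-style HTTP surface per configured provider, e.g. for a provider named
    # "zhipuai":
    #   GET  /zhipuai/v1/models
    #   POST /zhipuai/v1/embeddings        body: {"model": "...", "input": "..."}
    #   POST /zhipuai/v1/chat/completions  body: {"model": "...", "messages": [...]}
    # plus the workspace endpoints /workspaces/current/model-providers and
    # /workspaces/current/models/model-types/{model_type}.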

    def destroy(self):
        logger.info("Shutting down server")
        # set the exit flag; uvicorn's serve loop then stops the server
        # (Server.shutdown() is a coroutine, so it cannot be called synchronously here)
        self._server.should_exit = True

        self.join()

    def join(self):
        self._server_thread.join()

    def set_app_event(self, started_event: mp.Event = None):
        @self._app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()

    async def workspaces_model_providers(self, request: Request):
        provider_list = ProvidersWrapper(
            provider_manager=self._provider_manager.provider_manager
        ).get_provider_list(model_type=request.query_params.get("model_type"))
        return ProviderListResponse(data=provider_list)

    async def workspaces_model_types(self, model_type: str, request: Request):
        models_by_model_type = ProvidersWrapper(
            provider_manager=self._provider_manager.provider_manager
        ).get_models_by_model_type(model_type=model_type)
        return ProviderModelTypeResponse(data=models_by_model_type)

    async def list_models(self, provider: str, request: Request):
        logger.info(f"Received list_models request for provider: {provider}")
        # iterate over all ModelType enum members
        ai_models: List[AIModelEntity] = []
        for model_type in ModelType.__members__.values():
            try:
                provider_model_bundle_llm = (
                    self._provider_manager.provider_manager.get_provider_model_bundle(
                        provider=provider, model_type=model_type
                    )
                )
                for model in (
                    provider_model_bundle_llm.configuration.custom_configuration.models
                ):
                    if model.model_type == model_type:
                        ai_models.append(
                            provider_model_bundle_llm.model_type_instance.get_model_schema(
                                model=model.model,
                                credentials=model.credentials,
                            )
                        )

                # fetch the provider's predefined models as well
                ai_models.extend(
                    provider_model_bundle_llm.model_type_instance.predefined_models()
                )
            except Exception as e:
                logger.warning(
                    f"Error while fetching models for provider: {provider}, model_type: {model_type}"
                )

        logger.info(f"ai_models: {ai_models}")

        # convert List[AIModelEntity] into List[ModelCard]
        models_list = [
            ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
            for model in ai_models
        ]

        return ModelList(data=models_list)

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        logger.info(
            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
        )
        try:
            model_instance = self._provider_manager.get_model_instance(
                provider=provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=embeddings_request.model,
            )

            # embeddings_request.input may be a string or a list of strings / token ids (List[int])
            input = ""
            if isinstance(embeddings_request.input, list):
                tokens = embeddings_request.input
                try:
                    encoding = tiktoken.encoding_for_model(embeddings_request.model)
                except KeyError:
                    logger.warning(
                        "Warning: model not found. Using cl100k_base encoding."
                    )
                    model = "cl100k_base"
                    encoding = tiktoken.get_encoding(model)
                for i, token in enumerate(tokens):
                    # token ids are decoded back to text; plain strings are concatenated as-is
                    if isinstance(token, int):
                        text = encoding.decode([token])
                        input += text
                    else:
                        input += token

            else:
                input = embeddings_request.input

            response = model_instance.invoke_text_embedding(
                texts=[input], user="abc-123"
            )
            return await openai_embedding_text(response)

        except ValueError as e:
            logger.error(f"Error while creating embeddings: {str(e)}")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
        except InvokeError as e:
            logger.error(f"Error while creating embeddings: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
            )
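
    # Illustrative only (not in the original file): a typical body for the embeddings
    # route above is {"model": "<embedding model name>", "input": "hello"} or, with
    # token ids, {"model": "<embedding model name>", "input": [9906, 1917]}; the
    # concrete model name depends on the configured provider.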

    async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        logger.info(
            f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
        )

        model_instance = self._provider_manager.get_model_instance(
            provider=provider, model_type=ModelType.LLM, model=chat_request.model
        )
        prompt_messages = [
            convert_to_message(message) for message in chat_request.messages
        ]

        tools = []
        if chat_request.tools:
            tools = [
                PromptMessageTool(
                    name=f.function.name,
                    description=f.function.description,
                    parameters=f.function.parameters,
                )
                for f in chat_request.tools
            ]
        if chat_request.functions:
            tools.extend(
                [
                    PromptMessageTool(
                        name=f.name, description=f.description, parameters=f.parameters
                    )
                    for f in chat_request.functions
                ]
            )

        try:
            response = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters={**chat_request.to_model_parameters_dict()},
                tools=tools,
                stop=chat_request.stop,
                stream=chat_request.stream,
                user="abc-123",
            )

            if chat_request.stream:
                return EventSourceResponse(
                    stream_openai_chat_completion(response),
                    media_type="text/event-stream",
                )
            else:
                return await openai_chat_completion(response)
        except ValueError as e:
            logger.error(f"Error while creating chat completion: {str(e)}")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

        except InvokeError as e:
            logger.error(f"Error while creating chat completion: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
            )
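
    # Illustrative only (not in the original file): the chat completions route accepts
    # OpenAI-style payloads, e.g.
    #   {"model": "<llm name>", "messages": [{"role": "user", "content": "hi"}], "stream": true}
    # With "stream": true the handler returns an SSE stream; otherwise a single
    # ChatCompletionResponse JSON body.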


def run(
    cfg: Dict,
    logging_conf: Optional[dict] = None,
    started_event: mp.Event = None,
):
    logging.config.dictConfig(logging_conf)  # type: ignore
    try:
        api = RESTFulOpenAIBootstrapBaseWeb.from_config(
            cfg=cfg.get("run_openai_api", {})
        )
        api.set_app_event(started_event=started_event)
        api.logging_conf(logging_conf=logging_conf)
        api.run()

        async def pool_join_thread():
            api.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
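

# Minimal local-run sketch (not part of the original module): it assumes the config
# dict carries a "run_openai_api" section as read by run() above, and passes a bare
# dictConfig so logging setup does not fail when no external config is supplied.
if __name__ == "__main__":
    _example_logging_conf = {"version": 1, "disable_existing_loggers": False}
    run(
        cfg={"run_openai_api": {"host": "127.0.0.1", "port": 20000}},
        logging_conf=_example_logging_conf,
    )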