Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-19 21:37:20 +08:00
Use BootstrapWebBuilder to adapt the loading of RESTFulOpenAIBootstrapBaseWeb
This commit is contained in:
parent 3ed9162392 · commit 032dc8f58d
@@ -1,3 +1,5 @@
+from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
+from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.model_manager import ModelManager

 from omegaconf import OmegaConf, DictConfig
@@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig):
     return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict


-model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
-provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
-    cfg=model_providers_cfg)
-# Model instance created from the configuration manager
-provider_manager = ModelManager(
-    provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
-    provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
-)
+class BootstrapWebBuilder:
+    """
+    A builder for creating model web-service instances
+    """
+    _model_providers_cfg_path: str
+    _host: str
+    _port: int
+
+    def model_providers_cfg_path(self, model_providers_cfg_path: str):
+        self._model_providers_cfg_path = model_providers_cfg_path
+        return self
+
+    def host(self, host: str):
+        self._host = host
+        return self
+
+    def port(self, port: int):
+        self._port = port
+        return self
+
+    def build(self) -> OpenAIBootstrapBaseWeb:
+        assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
+        # Load the configuration file
+        cfg = OmegaConf.load(self._model_providers_cfg_path)
+        # Convert the configuration
+        provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
+            cfg)
+        # Create the model manager
+        provider_manager = ModelManager(provider_name_to_provider_records_dict,
+                                        provider_name_to_provider_model_records_dict)
+        # Create the web service
+        restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
+            "host": self._host,
+            "port": self._port
+        })
+        restful.provider_manager = provider_manager
+        return restful
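Note: a minimal usage sketch of the new builder. The config path, host, and port below are illustrative assumptions, not values from the commit:

restful = (
    BootstrapWebBuilder()
    .model_providers_cfg_path("model_providers.yaml")  # assumed path to a valid providers YAML
    .host("127.0.0.1")
    .port(20000)
    .build()
)
# restful now carries its own ModelManager, injected via the provider_manager setter.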
@@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import (
     FunctionAvailable,
     ModelList,
 )
+from model_providers.core.model_manager import ModelManager, ModelInstance
 from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.model_providers import model_provider_factory
-from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
-    LargeLanguageModel,
-)
 from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)


-async def create_stream_chat_completion(
-    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
-):
-    try:
-        response = model_type_instance.invoke(
-            model=chat_request.model,
-            credentials={
-                "openai_api_key": "sk-",
-                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-            },
-            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-            model_parameters={**chat_request.to_model_parameters_dict()},
-            stop=chat_request.stop,
-            stream=chat_request.stream,
-            user="abc-123",
-        )
-        return response
-
-    except Exception as e:
-        logger.exception(e)
-        raise HTTPException(status_code=500, detail=str(e))
-
-
 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     """
     Bootstrap Server Lifecycle
@@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         )

         self._router.add_api_route(
-            "/v1/models",
+            "/{provider}/v1/models",
             self.list_models,
             response_model=ModelList,
             methods=["GET"],
         )

         self._router.add_api_route(
-            "/v1/embeddings",
+            "/{provider}/v1/embeddings",
             self.create_embeddings,
             response_model=EmbeddingsResponse,
             status_code=status.HTTP_200_OK,
             methods=["POST"],
         )
         self._router.add_api_route(
-            "/v1/chat/completions",
+            "/{provider}/v1/chat/completions",
             self.create_chat_completion,
             response_model=ChatCompletionResponse,
             status_code=status.HTTP_200_OK,
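With this change every endpoint is scoped by a {provider} path parameter. A hedged client-side sketch of hitting the re-prefixed routes (host, port, and the "openai" provider segment are assumptions about the deployment, not values from the commit):

import requests

# List the models exposed for one provider; values are illustrative.
models = requests.get("http://127.0.0.1:20000/openai/v1/models").json()
print(models)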
@@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         if started_event is not None:
             started_event.set()

-    async def list_models(self, request: Request):
+    async def list_models(self, provider: str, request: Request):
         pass

     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        client = ZhipuAI(api_key=authorization)
-        # Check whether embeddings_request.input is a list
-        input = None
-        if isinstance(embeddings_request.input, list):
-            tokens = embeddings_request.input
-            try:
-                encoding = tiktoken.encoding_for_model(embeddings_request.model)
-            except KeyError:
-                logger.warning("Warning: model not found. Using cl100k_base encoding.")
-                model = "cl100k_base"
-                encoding = tiktoken.get_encoding(model)
-            for i, token in enumerate(tokens):
-                text = encoding.decode(token)
-                input += text
-
-        else:
-            input = embeddings_request.input
-
-        response = client.embeddings.create(
-            model=embeddings_request.model,
-            input=input,
-        )
+        response = None
         return EmbeddingsResponse(**dictify(response))

     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
         )
-        if os.environ["API_KEY"] is None:
-            authorization = request.headers.get("Authorization")
-            authorization = authorization.split("Bearer ")[-1]
-        else:
-            authorization = os.environ["API_KEY"]
-        model_provider_factory.get_providers(provider_name="openai")
-        provider_instance = model_provider_factory.get_provider_instance("openai")
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-
+        model_instance = self._provider_manager.get_model_instance(
+            provider=provider, model_type=ModelType.LLM, model=chat_request.model
+        )
         if chat_request.stream:
-            generator = create_stream_chat_completion(model_type_instance, chat_request)
-            return EventSourceResponse(generator, media_type="text/event-stream")
-        else:
-            response = model_type_instance.invoke(
-                model="gpt-4",
-                credentials={
-                    "openai_api_key": "sk-",
-                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
-                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
-                },
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={
-                    "temperature": 0.7,
-                    "top_p": 1.0,
-                    "top_k": 1,
-                    "plugin_web_search": True,
-                },
-                stop=["you"],
-                stream=False,
+            # Invoke model
+            response = model_instance.invoke_llm(
+                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                stop=chat_request.stop,
+                stream=chat_request.stream,
                 user="abc-123",
             )
+            return EventSourceResponse(response, media_type="text/event-stream")
+        else:
+            # Invoke model
+            response = model_instance.invoke_llm(
+                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+                model_parameters={**chat_request.to_model_parameters_dict()},
+                stop=chat_request.stop,
+                stream=chat_request.stream,
+                user="abc-123",
+            )
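In the streaming branch the invoke_llm generator is wrapped in an EventSourceResponse, so clients consume server-sent events. A hedged consumer sketch (the URL, provider segment, and model name are illustrative assumptions):

import requests

with requests.post(
    "http://127.0.0.1:20000/openai/v1/chat/completions",  # assumed deployment URL
    json={
        "model": "gpt-3.5-turbo",  # illustrative model name
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):  # SSE frames carry a 'data:' payload
            print(line[len(b"data:"):].strip())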
@@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):


 def run(
-        cfg: Dict,
-        logging_conf: Optional[dict] = None,
-        started_event: mp.Event = None,
+    cfg: Dict,
+    logging_conf: Optional[dict] = None,
+    started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
         import signal

         # Skip keyboard interrupts; use xoscar's signal handling instead
         signal.signal(signal.SIGINT, lambda *_: None)
         api = RESTFulOpenAIBootstrapBaseWeb.from_config(
             cfg=cfg.get("run_openai_api", {})
         )
@@ -3,9 +3,11 @@ from collections import deque

 from fastapi import Request

 from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
+from model_providers.core.model_manager import ModelManager


 class Bootstrap:

     """Maximum size of the ongoing task queue"""

     _MAX_ONGOING_TASKS: int = 1
@@ -39,21 +41,31 @@ class Bootstrap:


 class OpenAIBootstrapBaseWeb(Bootstrap):
+    _provider_manager: ModelManager
+
     def __init__(self):
         super().__init__()

+    @property
+    def provider_manager(self) -> ModelManager:
+        return self._provider_manager
+
+    @provider_manager.setter
+    def provider_manager(self, provider_manager: ModelManager):
+        self._provider_manager = provider_manager
+
     @abstractmethod
-    async def list_models(self, request: Request):
+    async def list_models(self, provider: str, request: Request):
         pass

     @abstractmethod
     async def create_embeddings(
-        self, request: Request, embeddings_request: EmbeddingsRequest
+        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         pass

     @abstractmethod
     async def create_chat_completion(
-        self, request: Request, chat_request: ChatCompletionRequest
+        self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         pass
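The setter-based injection keeps the web layer decoupled from configuration loading. A minimal sketch of a concrete subclass under the new provider-aware signatures (the class name and empty bodies are hypothetical):

class DemoWeb(OpenAIBootstrapBaseWeb):
    async def list_models(self, provider: str, request: Request):
        pass

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        pass

    async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        pass

web = DemoWeb()
web.provider_manager = some_model_manager  # a ModelManager built elsewhere (assumed)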
@@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel):

     def to_model_parameters_dict(self, *args, **kwargs):
         # Call the parent's dict method, excluding the tools and message fields
-        helper.dump_model
+
         return super().dict(
             exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
         )
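For intuition, a self-contained sketch of the same exclude pattern on a simplified request model (pydantic v1-style .dict(), as in the original call; the field set is illustrative):

from typing import List
from pydantic import BaseModel

class Req(BaseModel):
    # Simplified stand-in for ChatCompletionRequest
    model: str
    messages: List[dict]
    temperature: float = 0.7
    top_p: float = 1.0

    def to_model_parameters_dict(self):
        # Keep only generation parameters; drop the message content.
        return self.dict(exclude={"messages"})

r = Req(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
print(r.to_model_parameters_dict())
# -> {'model': 'gpt-3.5-turbo', 'temperature': 0.7, 'top_p': 1.0}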