mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00
使用BootstrapWebBuilder适配RESTFulOpenAIBootstrapBaseWeb加载
This commit is contained in:
parent
3ed9162392
commit
032dc8f58d
@ -1,3 +1,5 @@
|
|||||||
|
from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb
|
||||||
|
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
||||||
from model_providers.core.model_manager import ModelManager
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
|
||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
@ -45,11 +47,41 @@ def _to_custom_provide_configuration(cfg: DictConfig):
|
|||||||
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
|
return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
|
||||||
|
|
||||||
|
|
||||||
model_providers_cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers/model_providers.yaml")
|
class BootstrapWebBuilder:
|
||||||
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
"""
|
||||||
cfg=model_providers_cfg)
|
创建一个模型实例创建工具
|
||||||
# 基于配置管理器创建的模型实例
|
"""
|
||||||
provider_manager = ModelManager(
|
_model_providers_cfg_path: str
|
||||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
_host: str
|
||||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict
|
_port: int
|
||||||
)
|
|
||||||
|
def model_providers_cfg_path(self, model_providers_cfg_path: str):
|
||||||
|
self._model_providers_cfg_path = model_providers_cfg_path
|
||||||
|
return self
|
||||||
|
|
||||||
|
def host(self, host: str):
|
||||||
|
self._host = host
|
||||||
|
return self
|
||||||
|
|
||||||
|
def port(self, port: int):
|
||||||
|
self._port = port
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> OpenAIBootstrapBaseWeb:
|
||||||
|
assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None
|
||||||
|
# 读取配置文件
|
||||||
|
cfg = OmegaConf.load(self._model_providers_cfg_path)
|
||||||
|
# 转换配置文件
|
||||||
|
provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration(
|
||||||
|
cfg)
|
||||||
|
# 创建模型管理器
|
||||||
|
provider_manager = ModelManager(provider_name_to_provider_records_dict,
|
||||||
|
provider_name_to_provider_model_records_dict)
|
||||||
|
# 创建web服务
|
||||||
|
restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={
|
||||||
|
"host": self._host,
|
||||||
|
"port": self._port
|
||||||
|
|
||||||
|
})
|
||||||
|
restful.provider_manager = provider_manager
|
||||||
|
return restful
|
||||||
|
|||||||
@ -23,44 +23,16 @@ from model_providers.core.bootstrap.openai_protocol import (
|
|||||||
FunctionAvailable,
|
FunctionAvailable,
|
||||||
ModelList,
|
ModelList,
|
||||||
)
|
)
|
||||||
|
from model_providers.core.model_manager import ModelManager, ModelInstance
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.model_providers import model_provider_factory
|
|
||||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
|
||||||
LargeLanguageModel,
|
|
||||||
)
|
|
||||||
from model_providers.core.utils.generic import dictify, jsonify
|
from model_providers.core.utils.generic import dictify, jsonify
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def create_stream_chat_completion(
|
|
||||||
model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
response = model_type_instance.invoke(
|
|
||||||
model=chat_request.model,
|
|
||||||
credentials={
|
|
||||||
"openai_api_key": "sk-",
|
|
||||||
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
|
|
||||||
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
|
|
||||||
},
|
|
||||||
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
|
||||||
model_parameters={**chat_request.to_model_parameters_dict()},
|
|
||||||
stop=chat_request.stop,
|
|
||||||
stream=chat_request.stream,
|
|
||||||
user="abc-123",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||||
"""
|
"""
|
||||||
Bootstrap Server Lifecycle
|
Bootstrap Server Lifecycle
|
||||||
@ -94,21 +66,21 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._router.add_api_route(
|
self._router.add_api_route(
|
||||||
"/v1/models",
|
"/{provider}/v1/models",
|
||||||
self.list_models,
|
self.list_models,
|
||||||
response_model=ModelList,
|
response_model=ModelList,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self._router.add_api_route(
|
self._router.add_api_route(
|
||||||
"/v1/embeddings",
|
"/{provider}/v1/embeddings",
|
||||||
self.create_embeddings,
|
self.create_embeddings,
|
||||||
response_model=EmbeddingsResponse,
|
response_model=EmbeddingsResponse,
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
)
|
)
|
||||||
self._router.add_api_route(
|
self._router.add_api_route(
|
||||||
"/v1/chat/completions",
|
"/{provider}/v1/chat/completions",
|
||||||
self.create_chat_completion,
|
self.create_chat_completion,
|
||||||
response_model=ChatCompletionResponse,
|
response_model=ChatCompletionResponse,
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
@ -137,78 +109,49 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
if started_event is not None:
|
if started_event is not None:
|
||||||
started_event.set()
|
started_event.set()
|
||||||
|
|
||||||
async def list_models(self, request: Request):
|
async def list_models(self, provider: str, request: Request):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def create_embeddings(
|
async def create_embeddings(
|
||||||
self, request: Request, embeddings_request: EmbeddingsRequest
|
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
||||||
)
|
)
|
||||||
if os.environ["API_KEY"] is None:
|
|
||||||
authorization = request.headers.get("Authorization")
|
|
||||||
authorization = authorization.split("Bearer ")[-1]
|
|
||||||
else:
|
|
||||||
authorization = os.environ["API_KEY"]
|
|
||||||
client = ZhipuAI(api_key=authorization)
|
|
||||||
# 判断embeddings_request.input是否为list
|
|
||||||
input = None
|
|
||||||
if isinstance(embeddings_request.input, list):
|
|
||||||
tokens = embeddings_request.input
|
|
||||||
try:
|
|
||||||
encoding = tiktoken.encoding_for_model(embeddings_request.model)
|
|
||||||
except KeyError:
|
|
||||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
|
||||||
model = "cl100k_base"
|
|
||||||
encoding = tiktoken.get_encoding(model)
|
|
||||||
for i, token in enumerate(tokens):
|
|
||||||
text = encoding.decode(token)
|
|
||||||
input += text
|
|
||||||
|
|
||||||
else:
|
response = None
|
||||||
input = embeddings_request.input
|
|
||||||
|
|
||||||
response = client.embeddings.create(
|
|
||||||
model=embeddings_request.model,
|
|
||||||
input=input,
|
|
||||||
)
|
|
||||||
return EmbeddingsResponse(**dictify(response))
|
return EmbeddingsResponse(**dictify(response))
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self, request: Request, chat_request: ChatCompletionRequest
|
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
|
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
|
||||||
)
|
)
|
||||||
if os.environ["API_KEY"] is None:
|
|
||||||
authorization = request.headers.get("Authorization")
|
model_instance = self._provider_manager.get_model_instance(
|
||||||
authorization = authorization.split("Bearer ")[-1]
|
provider=provider, model_type=ModelType.LLM, model=chat_request.model
|
||||||
else:
|
)
|
||||||
authorization = os.environ["API_KEY"]
|
|
||||||
model_provider_factory.get_providers(provider_name="openai")
|
|
||||||
provider_instance = model_provider_factory.get_provider_instance("openai")
|
|
||||||
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
|
||||||
if chat_request.stream:
|
if chat_request.stream:
|
||||||
generator = create_stream_chat_completion(model_type_instance, chat_request)
|
# Invoke model
|
||||||
return EventSourceResponse(generator, media_type="text/event-stream")
|
|
||||||
else:
|
response = model_instance.invoke_llm(
|
||||||
response = model_type_instance.invoke(
|
|
||||||
model="gpt-4",
|
|
||||||
credentials={
|
|
||||||
"openai_api_key": "sk-",
|
|
||||||
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
|
|
||||||
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
|
|
||||||
},
|
|
||||||
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
||||||
model_parameters={
|
model_parameters={**chat_request.to_model_parameters_dict()},
|
||||||
"temperature": 0.7,
|
stop=chat_request.stop,
|
||||||
"top_p": 1.0,
|
stream=chat_request.stream,
|
||||||
"top_k": 1,
|
user="abc-123",
|
||||||
"plugin_web_search": True,
|
)
|
||||||
},
|
|
||||||
stop=["you"],
|
return EventSourceResponse(response, media_type="text/event-stream")
|
||||||
stream=False,
|
else:
|
||||||
|
# Invoke model
|
||||||
|
|
||||||
|
response = model_instance.invoke_llm(
|
||||||
|
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
||||||
|
model_parameters={**chat_request.to_model_parameters_dict()},
|
||||||
|
stop=chat_request.stop,
|
||||||
|
stream=chat_request.stream,
|
||||||
user="abc-123",
|
user="abc-123",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -218,16 +161,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
cfg: Dict,
|
cfg: Dict,
|
||||||
logging_conf: Optional[dict] = None,
|
logging_conf: Optional[dict] = None,
|
||||||
started_event: mp.Event = None,
|
started_event: mp.Event = None,
|
||||||
):
|
):
|
||||||
logging.config.dictConfig(logging_conf) # type: ignore
|
logging.config.dictConfig(logging_conf) # type: ignore
|
||||||
try:
|
try:
|
||||||
import signal
|
|
||||||
|
|
||||||
# 跳过键盘中断,使用xoscar的信号处理
|
|
||||||
signal.signal(signal.SIGINT, lambda *_: None)
|
|
||||||
api = RESTFulOpenAIBootstrapBaseWeb.from_config(
|
api = RESTFulOpenAIBootstrapBaseWeb.from_config(
|
||||||
cfg=cfg.get("run_openai_api", {})
|
cfg=cfg.get("run_openai_api", {})
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,9 +3,11 @@ from collections import deque
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest
|
||||||
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
|
||||||
|
|
||||||
class Bootstrap:
|
class Bootstrap:
|
||||||
|
|
||||||
"""最大的任务队列"""
|
"""最大的任务队列"""
|
||||||
|
|
||||||
_MAX_ONGOING_TASKS: int = 1
|
_MAX_ONGOING_TASKS: int = 1
|
||||||
@ -39,21 +41,31 @@ class Bootstrap:
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIBootstrapBaseWeb(Bootstrap):
|
class OpenAIBootstrapBaseWeb(Bootstrap):
|
||||||
|
_provider_manager: ModelManager
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_manager(self) -> ModelManager:
|
||||||
|
return self._provider_manager
|
||||||
|
|
||||||
|
@provider_manager.setter
|
||||||
|
def provider_manager(self, provider_manager: ModelManager):
|
||||||
|
self._provider_manager = provider_manager
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def list_models(self, request: Request):
|
async def list_models(self, provider: str, request: Request):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_embeddings(
|
async def create_embeddings(
|
||||||
self, request: Request, embeddings_request: EmbeddingsRequest
|
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self, request: Request, chat_request: ChatCompletionRequest
|
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
def to_model_parameters_dict(self, *args, **kwargs):
|
def to_model_parameters_dict(self, *args, **kwargs):
|
||||||
# 调用父类的to_dict方法,并排除tools字段
|
# 调用父类的to_dict方法,并排除tools字段
|
||||||
helper.dump_model
|
|
||||||
return super().dict(
|
return super().dict(
|
||||||
exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
|
exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user