From f031ebc19e40c61c369a07cf706ec2748b4cbb87 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Sat, 13 Jan 2024 16:11:30 +0800
Subject: [PATCH] ChatOpenAI: compute token lengths per model when checking the
 model's context window. Every model tokenizes differently, so the token count
 is calculated here; on first initialization the OpenAI class would otherwise
 force the gpt-3.5 encoding.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/minx_chat_openai.py | 51 ++++++++++++++++++++++++++++++++
 server/utils.py            | 59 ++++++++++++++------------------------
 2 files changed, 73 insertions(+), 37 deletions(-)
 create mode 100644 server/minx_chat_openai.py

diff --git a/server/minx_chat_openai.py b/server/minx_chat_openai.py
new file mode 100644
index 00000000..b5362bba
--- /dev/null
+++ b/server/minx_chat_openai.py
@@ -0,0 +1,51 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Tuple
+)
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    import tiktoken
+
+
+class MinxChatOpenAI:
+
+    @staticmethod
+    def import_tiktoken() -> Any:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_token_ids. "
+                "Please install it with `pip install tiktoken`."
+            )
+        return tiktoken
+
+    @staticmethod
+    def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
+        tiktoken_ = MinxChatOpenAI.import_tiktoken()
+        if self.tiktoken_model_name is not None:
+            model = self.tiktoken_model_name
+        else:
+            model = self.model_name
+            if model == "gpt-3.5-turbo":
+                # gpt-3.5-turbo may change over time.
+                # Returning num tokens assuming gpt-3.5-turbo-0301.
+                model = "gpt-3.5-turbo-0301"
+            elif model == "gpt-4":
+                # gpt-4 may change over time.
+                # Returning num tokens assuming gpt-4-0314.
+                model = "gpt-4-0314"
+        # Returns the number of tokens used by a list of messages.
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except Exception as e:
+            logger.warning("Warning: model not found. Using cl100k_base encoding.")
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
diff --git a/server/utils.py b/server/utils.py
index 5512ee64..270c5158 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -12,10 +12,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Optional,
+    Callable,
+    Generator,
+    Dict,
+    Any,
+    Awaitable,
+    Union,
+    Tuple
+)
 import logging
 import torch
+from server.minx_chat_openai import MinxChatOpenAI
+


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
     """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
@@ -44,7 +57,7 @@ def get_ChatOpenAI(
     config = get_model_worker_config(model_name)
     if model_name == "openai-api":
         model_name = config.get("model_name")
-
+    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
     model = ChatOpenAI(
         streaming=streaming,
         verbose=verbose,
@@ -153,6 +166,7 @@ class ChatMessage(BaseModel):

 def torch_gc():
     try:
+        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
@@ -500,58 +514,29 @@ def set_httpx_config(


 # 自动检查torch可用的设备。分布式部署时，不运行LLM的机器上可以不装torch
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
-
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
+        import torch
         if torch.cuda.is_available():
             return "cuda"
-        if is_mps_available():
+        if torch.backends.mps.is_available():
             return "mps"
     except:
         pass
     return "cpu"


-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+    if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        return detect_device()
     return device


-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu"]:
-        return detect_device()
     return device
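
Note for reviewers: MinxChatOpenAI.get_encoding_model deliberately takes `self` as its first parameter even though it is a @staticmethod, because get_ChatOpenAI assigns it to ChatOpenAI._get_encoding_model, where it is then bound like a normal instance method. Below is a minimal, hypothetical usage sketch (not part of the patch) showing how the patched lookup resolves an encoding and counts tokens for a context-length check; the _FakeChatModel shim and count_tokens helper are illustrative names only, and the example assumes tiktoken is installed.

# Illustrative sketch only; _FakeChatModel and count_tokens are hypothetical
# helpers, not part of this patch. Requires `pip install tiktoken`.
from server.minx_chat_openai import MinxChatOpenAI


class _FakeChatModel:
    """Stand-in exposing the two attributes get_encoding_model reads."""
    tiktoken_model_name = None    # no explicit tiktoken override configured
    model_name = "gpt-3.5-turbo"  # resolved to gpt-3.5-turbo-0301 internally


def count_tokens(text: str) -> int:
    # get_encoding_model expects to be bound as an instance method, so the
    # model object is passed explicitly as its first argument here.
    _, encoding = MinxChatOpenAI.get_encoding_model(_FakeChatModel())
    return len(encoding.encode(text))


if __name__ == "__main__":
    # e.g. compare this count against the model's context window before
    # dispatching a prompt.
    print(count_tokens("How many tokens does this prompt use?"))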