diff --git a/server/minx_chat_openai.py b/server/minx_chat_openai.py
new file mode 100644
index 00000000..b5362bba
--- /dev/null
+++ b/server/minx_chat_openai.py
@@ -0,0 +1,51 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Tuple
+)
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    import tiktoken
+
+
+class MinxChatOpenAI:
+
+    @staticmethod
+    def import_tiktoken() -> Any:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_token_ids. "
+                "Please install it with `pip install tiktoken`."
+            )
+        return tiktoken
+
+    @staticmethod
+    def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
+        tiktoken_ = MinxChatOpenAI.import_tiktoken()
+        if self.tiktoken_model_name is not None:
+            model = self.tiktoken_model_name
+        else:
+            model = self.model_name
+            if model == "gpt-3.5-turbo":
+                # gpt-3.5-turbo may change over time.
+                # Returning num tokens assuming gpt-3.5-turbo-0301.
+                model = "gpt-3.5-turbo-0301"
+            elif model == "gpt-4":
+                # gpt-4 may change over time.
+                # Returning num tokens assuming gpt-4-0314.
+                model = "gpt-4-0314"
+        # Look up the tiktoken encoding for the resolved model name.
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except Exception:
+            logger.warning("Warning: model not found. Using cl100k_base encoding.")
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
diff --git a/server/utils.py b/server/utils.py
index 5512ee64..270c5158 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -12,10 +12,23 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Optional,
+    Callable,
+    Generator,
+    Dict,
+    Any,
+    Awaitable,
+    Union,
+    Tuple
+)
 import logging
 import torch
+from server.minx_chat_openai import MinxChatOpenAI
+
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
     """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
@@ -44,7 +57,7 @@ def get_ChatOpenAI(
     config = get_model_worker_config(model_name)
     if model_name == "openai-api":
         model_name = config.get("model_name")
-
+    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
     model = ChatOpenAI(
         streaming=streaming,
         verbose=verbose,
@@ -153,6 +166,7 @@ class ChatMessage(BaseModel):
 
 def torch_gc():
     try:
+        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
@@ -500,58 +514,29 @@ def set_httpx_config(
 
 
 # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
-
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
+        import torch
         if torch.cuda.is_available():
            return "cuda"
-        if is_mps_available():
+        if torch.backends.mps.is_available():
            return "mps"
     except:
         pass
     return "cpu"
 
 
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+    if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        return detect_device()
     return device
 
 
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu"]:
-        return detect_device()
     return device
 
 
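For reference (not part of the patch): a minimal sketch of how the monkey-patched encoding lookup is exercised, assuming langchain and tiktoken are installed; the model name and API key below are placeholders.

    # Sketch only: reproduces the wiring get_ChatOpenAI performs in server/utils.py.
    from langchain.chat_models import ChatOpenAI
    from server.minx_chat_openai import MinxChatOpenAI

    # Assigning the plain function to the class makes it an instance method,
    # so the `self` parameter of get_encoding_model is bound to the ChatOpenAI object.
    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model

    llm = ChatOpenAI(model_name="my-local-model", openai_api_key="EMPTY")  # placeholder values
    # For model names tiktoken does not recognize, the patched lookup falls back
    # to the cl100k_base encoding instead of raising, so token counting still works.
    print(llm.get_token_ids("hello world"))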