Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-02-04 05:33:12 +08:00
ChatOpenAI checks whether the token count exceeds the model's context length, but each model uses a different tokenization algorithm, so the token-length calculation should be implemented here ourselves.
On first initialization, the openai class forces the use of 3.5.
parent 0a37fe93b8 · commit f031ebc19e
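
For background, a minimal sketch (not part of this commit) of why a per-model token count matters: the same text encodes to a different number of tokens under different tiktoken encodings, so a context-length check has to use the encoding that matches the serving model. The encoding names below are standard tiktoken encodings, chosen only for illustration.

    # Sketch, assuming tiktoken is installed: the same string yields different
    # token counts per encoding, which is why the context-length check needs a
    # per-model calculation.
    import tiktoken

    text = "你好,世界!Hello, world!"
    for name in ("cl100k_base", "p50k_base", "gpt2"):
        enc = tiktoken.get_encoding(name)
        print(name, len(enc.encode(text)))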
server/minx_chat_openai.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import (
    TYPE_CHECKING,
    Any,
    Tuple
)
import sys
import logging

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    import tiktoken


class MinxChatOpenAI:

    @staticmethod
    def import_tiktoken() -> Any:
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_token_ids. "
                "Please install it with `pip install tiktoken`."
            )
        return tiktoken

    @staticmethod
    def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
        tiktoken_ = MinxChatOpenAI.import_tiktoken()
        if self.tiktoken_model_name is not None:
            model = self.tiktoken_model_name
        else:
            model = self.model_name
            if model == "gpt-3.5-turbo":
                # gpt-3.5-turbo may change over time.
                # Returning num tokens assuming gpt-3.5-turbo-0301.
                model = "gpt-3.5-turbo-0301"
            elif model == "gpt-4":
                # gpt-4 may change over time.
                # Returning num tokens assuming gpt-4-0314.
                model = "gpt-4-0314"
        # Returns the number of tokens used by a list of messages.
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except Exception as e:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)
        return model, encoding
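
A hedged usage sketch, not part of the commit: calling the static method directly against a minimal stand-in object shows the fallback path for a non-OpenAI model name. SimpleNamespace and the model name "chatglm2-6b" are illustrative choices only.

    # Sketch, assuming tiktoken is installed; the stand-in object mimics the two
    # attributes that get_encoding_model() reads from a ChatOpenAI instance.
    from types import SimpleNamespace
    from server.minx_chat_openai import MinxChatOpenAI

    fake_llm = SimpleNamespace(tiktoken_model_name=None, model_name="chatglm2-6b")
    model, encoding = MinxChatOpenAI.get_encoding_model(fake_llm)
    print(model)                              # "cl100k_base" after the fallback warning
    print(len(encoding.encode("你好,世界")))   # token count under the fallback encoding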
@@ -12,10 +12,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Optional,
+    Callable,
+    Generator,
+    Dict,
+    Any,
+    Awaitable,
+    Union,
+    Tuple
+)
 import logging
 import torch
 
+from server.minx_chat_openai import MinxChatOpenAI
+
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
     """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
@@ -44,7 +57,7 @@ def get_ChatOpenAI(
     config = get_model_worker_config(model_name)
     if model_name == "openai-api":
         model_name = config.get("model_name")
-
+    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
     model = ChatOpenAI(
         streaming=streaming,
         verbose=verbose,
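
Because the assignment replaces _get_encoding_model on the class itself, every ChatOpenAI created through get_ChatOpenAI afterwards resolves to the patched version. A hedged sketch of the effect, assuming (as in the langchain 0.0.x line used here) that get_num_tokens routes through _get_encoding_model; the model name and endpoint are placeholders, not values from the commit.

    # Sketch only; "chatglm2-6b" and the local endpoint are placeholders.
    from langchain.chat_models import ChatOpenAI
    from server.minx_chat_openai import MinxChatOpenAI

    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model  # patch first

    llm = ChatOpenAI(
        model_name="chatglm2-6b",
        openai_api_key="EMPTY",
        openai_api_base="http://127.0.0.1:20000/v1",
    )
    print(llm.get_num_tokens("检查上下文长度"))  # counted via the cl100k_base fallback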
@@ -153,6 +166,7 @@ class ChatMessage(BaseModel):
 
 def torch_gc():
     try:
+        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
@@ -500,58 +514,29 @@ def set_httpx_config(
 # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
 
 
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
-
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
+        import torch
         if torch.cuda.is_available():
             return "cuda"
-        if is_mps_available():
+        if torch.backends.mps.is_available():
             return "mps"
     except:
         pass
     return "cpu"
 
 
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+    if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        return detect_device()
     return device
 
 
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu"]:
-        return detect_device()
     return device
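
For reference, a hedged sketch of how the simplified helpers behave on a machine without torch; the import path assumes these functions live in server/utils.py (the patched file's path is not shown in this view), and "auto" is only an example of an unrecognized value.

    # Sketch; the module path and the "auto" value are assumptions for illustration.
    from server.utils import detect_device, llm_device, embedding_device

    print(detect_device())         # "cpu": the failed torch import is swallowed
    print(llm_device("cuda"))      # "cuda": recognized literals pass through
    print(llm_device("auto"))      # unrecognized -> detect_device(), e.g. "cpu"
    print(embedding_device(None))  # EMBEDDING_DEVICE if valid, else detect_device()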