make torch & transformers optional

import pydantic BaseModel & Field from langchain.pydantic_v1 instead of pydantic.v1
This commit is contained in:
liunux4odoo 2024-02-13 21:08:15 +08:00
parent 73eb5e2e32
commit 65466007ae
16 changed files with 101 additions and 115 deletions
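Both changes follow one pattern: heavy dependencies are imported lazily at the call sites that need them, and pydantic symbols come from LangChain's compatibility shim so no specific pydantic major version has to be pinned. A minimal sketch of the lazy-import pattern (the function name is illustrative, not from the repo):

    def load_vision_model(model_path: str):
        try:
            # torch/transformers are imported only when the feature is used,
            # so the server starts even if they are not installed
            import torch
            from transformers import AutoModelForCausalLM
        except ImportError as e:
            raise RuntimeError("install torch and transformers to use vision tools") from e
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
        return model.eval()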

View File

@@ -26,6 +26,7 @@ SUPPORT_AGENT_MODELS = [
"openai-api",
"Qwen-14B-Chat",
"Qwen-7B-Chat",
"qwen-turbo",
]
@@ -83,6 +84,9 @@ MODEL_PLATFORMS = [
"llm_models": [
"gpt-3.5-turbo",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "https://api.openai.com/v1",
"api_key": "sk-",
"api_proxy": "",
@@ -112,8 +116,16 @@ MODEL_PLATFORMS = [
"platform_type": "oneapi",
"api_key": "",
"llm_models": [
"chatglm3-6b",
"qwen-turbo",
"qwen-plus",
"chatglm_turbo",
"chatglm_std",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "http://127.0.0.1:3000/v1",
"api_key": "sk-xxx",
},
{
@@ -123,6 +135,11 @@ MODEL_PLATFORMS = [
"llm_models": [
"chatglm3-6b",
],
"embed_models": [],
"image_models": [],
"multimodal_models": [],
"api_base_url": "http://127.0.0.1:7860/v1",
"api_key": "EMPTY",
},
]
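Each platform entry now declares explicit embed_models, image_models, and multimodal_models lists alongside llm_models. A hypothetical filled-in entry (the embedding model name is illustrative):

    {
        "platform_name": "openai",
        "platform_type": "openai",
        "llm_models": ["gpt-3.5-turbo"],
        "embed_models": ["text-embedding-ada-002"],  # illustrative, not from this commit
        "image_models": [],
        "multimodal_models": [],
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "sk-",
        "api_proxy": "",
    }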

View File

@@ -1,46 +1,24 @@
# API requirements
# Torch requirements; install the CUDA version manually from https://pytorch.org/
torch>=2.1.2
torchvision>=0.16.2
torchaudio>=2.1.2
# Langchain 0.1.x requirements
langchain>=0.1.0
langchain_openai>=0.0.2
langchain-community>=0.0.11
langchainhub>=0.1.14
pydantic==1.10.13
fschat==0.2.35
openai==1.9.0
fastapi==0.109.0
sse_starlette==1.8.2
nltk==3.8.1
langchain==0.1.5
langchainhub==0.1.14
langchain-community==0.0.17
langchain-openai==0.0.5
langchain-experimental==0.0.50
fastapi==0.109.2
sse_starlette~=1.8.2
nltk~=3.8.1
uvicorn>=0.27.0.post1
starlette==0.35.0
unstructured[all-docs] # ==0.11.8
unstructured[all-docs]~=0.12.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.25
faiss-cpu==1.7.4
accelerate==0.24.1
spacy==3.7.2
PyMuPDF==1.23.16
rapidocr_onnxruntime==1.3.8
requests==2.31.0
pathlib==1.0.1
pytest==7.4.3
numexpr==2.8.6
strsimpy==0.2.1
markdownify==0.11.6
tiktoken==0.5.2
tqdm==4.66.1
websockets==12.0
numpy==1.24.4
pandas==2.0.3
einops==0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
SQLAlchemy~=2.0.25
faiss-cpu~=1.7.4
# accelerate~=0.24.1
# spacy~=3.7.2
PyMuPDF~=1.23.16
rapidocr_onnxruntime~=1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
llama-index==0.9.35
# jq==1.6.0
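Note the shift from exact pins (==) to the compatible-release operator: per PEP 440, SQLAlchemy~=2.0.25 means >=2.0.25, <2.1.0, so bugfix releases are picked up automatically while minor-version bumps are still blocked.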

View File

@@ -1,6 +1,4 @@
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from configs import TOOL_CONFIG
import torch
from configs import TOOL_CONFIG, logger
class ModelContainer:
@@ -14,28 +12,38 @@ class ModelContainer:
self.audio_model = None
if TOOL_CONFIG["vqa_processor"]["use"]:
self.vision_tokenizer = LlamaTokenizer.from_pretrained(
TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
trust_remote_code=True)
self.vision_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
try:
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
import torch
self.vision_tokenizer = LlamaTokenizer.from_pretrained(
TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
trust_remote_code=True)
self.vision_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
except Exception as e:
logger.error(e, exc_info=True)
if TOOL_CONFIG["aqa_processor"]["use"]:
self.audio_tokenizer = AutoTokenizer.from_pretrained(
TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
trust_remote_code=True
)
self.audio_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).to(
TOOL_CONFIG["aqa_processor"]["device"]
).eval()
try:
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
import torch
self.audio_tokenizer = AutoTokenizer.from_pretrained(
TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
trust_remote_code=True
)
self.audio_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).to(
TOOL_CONFIG["aqa_processor"]["device"]
).eval()
except Exception as e:
logger.error(e, exc_info=True)
container = ModelContainer()
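Because loading now happens inside try/except, container is constructed even when torch or transformers is missing: the failure is logged and the model attributes keep their initial None. A hypothetical guard a caller might add (not code from this commit):

    from server.agent.container import container

    if container.vision_model is None:
        raise RuntimeError("VQA model not loaded; install torch/transformers and check TOOL_CONFIG")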

View File

@@ -1,5 +1,5 @@
# LangChain's ArxivQueryRun tool
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
def arxiv(query: str):
tool = ArxivQueryRun()
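langchain.pydantic_v1 is LangChain's compatibility layer: it re-exports the pydantic v1 API whether the installed pydantic is 1.x or 2.x, which is also why the hard pydantic==1.10.13 pin could be dropped from the requirements above. Tool argument schemas keep the same shape, e.g. this illustrative sketch in the style of these files:

    from langchain.pydantic_v1 import BaseModel, Field

    class ArxivInput(BaseModel):  # illustrative schema, not copied from the repo
        query: str = Field(description="the search query to send to arxiv")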

View File

@@ -1,6 +1,6 @@
import base64
import os
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
def save_base64_audio(base64_audio, file_path):
audio_data = base64.b64decode(base64_audio)

View File

@@ -1,4 +1,4 @@
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
def calculate(a: float, b: float, operator: str) -> float:
if operator == "+":

View File

@@ -1,4 +1,4 @@
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import TOOL_CONFIG

View File

@@ -1,5 +1,5 @@
from urllib.parse import urlencode
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG

View File

@@ -1,5 +1,5 @@
from langchain_community.tools import YouTubeSearchTool
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
def search_youtube(query: str):
tool = YouTubeSearchTool()
return tool.run(tool_input=query)

View File

@@ -1,5 +1,5 @@
# LangChain's Shell tool
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from langchain_community.tools import ShellTool
def shell(query: str):
tool = ShellTool()

View File

@@ -3,9 +3,8 @@ Method Use cogagent to generate response for a given image and query.
"""
import base64
from io import BytesIO
import torch
from PIL import Image, ImageDraw
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from configs import TOOL_CONFIG
import re
from server.agent.container import container
@@ -72,6 +71,8 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
temperature (float): temperature
top_k (int): top k
"""
import torch
image = Image.open(BytesIO(base64.b64decode(image_base_64)))
inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
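Deferring import torch into the body of vqa_run means the tool module can be imported, and the tool registered, without torch installed; the import cost is paid only on the first actual VQA call.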

View File

@@ -1,7 +1,7 @@
"""
A simple single-parameter input tool for querying the current weather.
"""
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
import requests
def weather(location: str, api_key: str):

View File

@@ -1,6 +1,6 @@
# Langchain's built-in Wolfram Alpha API wrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
wolfram_alpha_appid = "your key"
def wolfram(query: str):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)

View File

@@ -4,8 +4,7 @@ import shutil
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path, EmbeddingsFunAdapter
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
@@ -83,7 +82,6 @@ class FaissKBService(KBService):
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
return doc_infos
def do_delete_doc(self,
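With torch now optional, the torch_gc() call after inserting documents is dropped together with its import from server.utils, and EmbeddingsFunAdapter now comes from server.knowledge_base.utils, so this module no longer touches torch-related helpers.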

View File

@@ -22,7 +22,7 @@ from typing import (
)
import logging
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE
from server.minx_chat_openai import MinxChatOpenAI
@@ -101,7 +101,7 @@ def get_model_info(model_name: str, platform_name: str = None) -> Dict:
def get_ChatOpenAI(
model_name: str,
temperature: float,
temperature: float = TEMPERATURE,
max_tokens: int = None,
streaming: bool = True,
callbacks: List[Callable] = [],
@@ -109,18 +109,22 @@ def get_ChatOpenAI(
**kwargs: Any,
) -> ChatOpenAI:
model_info = get_model_info(model_name)
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
**kwargs
)
try:
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
**kwargs
)
except Exception as e:
logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True)
model = None
return model
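get_ChatOpenAI now defaults temperature to the configured TEMPERATURE and returns None instead of raising when the client cannot be built, so call sites need a guard. Hypothetical usage, not from this commit:

    from server.utils import get_ChatOpenAI

    model = get_ChatOpenAI("gpt-3.5-turbo")
    if model is None:
        raise RuntimeError("could not create ChatOpenAI; check the MODEL_PLATFORMS entry")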
@@ -238,26 +242,6 @@ class ChatMessage(BaseModel):
}
def torch_gc():
try:
import torch
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.backends.mps.is_available():
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
"以支持及时清理 torch 产生的内存占用。")
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
except Exception:
...
def run_async(cor):
'''
Run async code in a synchronous environment.

View File

@@ -129,9 +129,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms, 1)
llm_models = list(get_config_models(model_type="llm", platform_name=platform))
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms)
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = st.selectbox("选择LLM模型", llm_models)
# Content passed to the backend
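The platform dropdown gains an "所有" ("All") entry: choosing it passes platform_name=None to get_config_models, listing LLMs from every configured platform instead of defaulting to the second platform as the old index argument did.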