From 65466007aeb779416ef5dd75c42a73524b3ecaa5 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Tue, 13 Feb 2024 21:08:15 +0800
Subject: [PATCH] make torch & transformers optional

import pydantic Model & Field from langchain.pydantic_v1 instead of
pydantic.v1
---
 configs/model_config.py.example                  | 19 +++++-
 requirements.txt                                 | 58 ++++++-------------
 server/agent/container.py                        | 54 +++++++++--------
 server/agent/tools_factory/arxiv.py              |  2 +-
 .../agent/tools_factory/audio_factory/aqa.py     |  2 +-
 server/agent/tools_factory/calculate.py          |  2 +-
 server/agent/tools_factory/search_internet.py    |  2 +-
 .../search_local_knowledgebase.py                |  2 +-
 server/agent/tools_factory/search_youtube.py     |  2 +-
 server/agent/tools_factory/shell.py              |  2 +-
 .../agent/tools_factory/vision_factory/vqa.py    |  5 +-
 server/agent/tools_factory/weather_check.py      |  2 +-
 server/agent/tools_factory/wolfram.py            |  2 +-
 .../kb_service/faiss_kb_service.py               |  4 +-
 server/utils.py                                  | 52 ++++++-----------
 webui_pages/dialogue/dialogue.py                 |  6 +-
 16 files changed, 101 insertions(+), 115 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 19e7db7a..433ec60d 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -26,6 +26,7 @@ SUPPORT_AGENT_MODELS = [
     "openai-api",
     "Qwen-14B-Chat",
     "Qwen-7B-Chat",
+    "qwen-turbo",
 ]


@@ -83,6 +84,9 @@ MODEL_PLATFORMS = [
         "llm_models": [
             "gpt-3.5-turbo",
         ],
+        "embed_models": [],
+        "image_models": [],
+        "multimodal_models": [],
         "api_base_url": "https://api.openai.com/v1",
         "api_key": "sk-",
         "api_proxy": "",
@@ -112,8 +116,16 @@ MODEL_PLATFORMS = [
         "platform_type": "oneapi",
         "api_key": "",
         "llm_models": [
-            "chatglm3-6b",
+            "qwen-turbo",
+            "qwen-plus",
+            "chatglm_turbo",
+            "chatglm_std",
         ],
+        "embed_models": [],
+        "image_models": [],
+        "multimodal_models": [],
+        "api_base_url": "http://127.0.0.1:3000/v1",
+        "api_key": "sk-xxx",
     },

     {
@@ -123,6 +135,11 @@ MODEL_PLATFORMS = [
         "llm_models": [
             "chatglm3-6b",
         ],
+        "embed_models": [],
+        "image_models": [],
+        "multimodal_models": [],
+        "api_base_url": "http://127.0.0.1:7860/v1",
+        "api_key": "EMPTY",
     },
 ]

diff --git a/requirements.txt b/requirements.txt
index 9a895b7c..97fa093b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,46 +1,24 @@
 # API requirements
-
-# Torch requiremnts, install the cuda version manually from https://pytorch.org/
-torch>=2.1.2
-torchvision>=0.16.2
-torchaudio>=2.1.2
-
-# Langchain 0.1.x requirements
-langchain>=0.1.0
-langchain_openai>=0.0.2
-langchain-community>=0.0.11
-langchainhub>=0.1.14
-
-pydantic==1.10.13
-fschat==0.2.35
-openai==1.9.0
-fastapi==0.109.0
-sse_starlette==1.8.2
-nltk==3.8.1
+langchain==0.1.5
+langchainhub==0.1.14
+langchain-community==0.0.17
+langchain-openai==0.0.5
+langchain-experimental==0.0.50
+fastapi==0.109.2
+sse_starlette~=1.8.2
+nltk~=3.8.1
 uvicorn>=0.27.0.post1
-starlette==0.35.0
-unstructured[all-docs] # ==0.11.8
+unstructured[]~=0.12.0
 python-magic-bin; sys_platform == 'win32'
-SQLAlchemy==2.0.25
-faiss-cpu==1.7.4
-accelerate==0.24.1
-spacy==3.7.2
-PyMuPDF==1.23.16
-rapidocr_onnxruntime==1.3.8
-requests==2.31.0
-pathlib==1.0.1
-pytest==7.4.3
-numexpr==2.8.6
-strsimpy==0.2.1
-markdownify==0.11.6
-tiktoken==0.5.2
-tqdm==4.66.1
-websockets==12.0
-numpy==1.24.4
-pandas==2.0.3
-einops==0.7.0
-transformers_stream_generator==0.0.4
-vllm==0.2.7; sys_platform == "linux"
+SQLAlchemy~=2.0.25
+faiss-cpu~=1.7.4
+# accelerate~=0.24.1
+# spacy~=3.7.2
+PyMuPDF~=1.23.16
+rapidocr_onnxruntime~=1.3.8
+requests~=2.31.0
+pathlib~=1.0.1
+pytest~=7.4.3
 llama-index==0.9.35
 # jq==1.6.0
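Note that torch, torchvision, torchaudio, transformers_stream_generator and vllm are gone from the hard requirements above (accelerate and spacy are commented out), so GPU support becomes opt-in and code has to probe for torch at runtime instead of assuming it. A minimal sketch of such a probe, assuming nothing beyond the standard library (the helper name is_torch_available is illustrative, not part of this patch):

    import importlib.util

    def is_torch_available() -> bool:
        # Look the module up without importing it; a failed import is slow
        # and can raise for reasons other than absence.
        return importlib.util.find_spec("torch") is not None

    device = "cpu"
    if is_torch_available():
        import torch  # safe: we know the package is present
        if torch.cuda.is_available():
            device = "cuda"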
diff --git a/server/agent/container.py b/server/agent/container.py
index 7623f1a2..9dfe7f70 100644
--- a/server/agent/container.py
+++ b/server/agent/container.py
@@ -1,6 +1,4 @@
-from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
-from configs import TOOL_CONFIG
-import torch
+from configs import TOOL_CONFIG, logger


 class ModelContainer:
@@ -14,28 +12,38 @@ class ModelContainer:
         self.audio_model = None

         if TOOL_CONFIG["vqa_processor"]["use"]:
-            self.vision_tokenizer = LlamaTokenizer.from_pretrained(
-                TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
-                trust_remote_code=True)
-            self.vision_model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
-                torch_dtype=torch.bfloat16,
-                low_cpu_mem_usage=True,
-                trust_remote_code=True
-            ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
+            try:
+                from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
+                import torch
+                self.vision_tokenizer = LlamaTokenizer.from_pretrained(
+                    TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
+                    trust_remote_code=True)
+                self.vision_model = AutoModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
+                    torch_dtype=torch.bfloat16,
+                    low_cpu_mem_usage=True,
+                    trust_remote_code=True
+                ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
+            except Exception as e:
+                logger.error(e, exc_info=True)

         if TOOL_CONFIG["aqa_processor"]["use"]:
-            self.audio_tokenizer = AutoTokenizer.from_pretrained(
-                TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
-                trust_remote_code=True
-            )
-            self.audio_model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
-                torch_dtype=torch.bfloat16,
-                low_cpu_mem_usage=True,
-                trust_remote_code=True).to(
-                TOOL_CONFIG["aqa_processor"]["device"]
-            ).eval()
+            try:
+                from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
+                import torch
+                self.audio_tokenizer = AutoTokenizer.from_pretrained(
+                    TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
+                    trust_remote_code=True
+                )
+                self.audio_model = AutoModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
+                    torch_dtype=torch.bfloat16,
+                    low_cpu_mem_usage=True,
+                    trust_remote_code=True).to(
+                    TOOL_CONFIG["aqa_processor"]["device"]
+                ).eval()
+            except Exception as e:
+                logger.error(e, exc_info=True)


 container = ModelContainer()
diff --git a/server/agent/tools_factory/arxiv.py b/server/agent/tools_factory/arxiv.py
index c8ee5ffc..43129463 100644
--- a/server/agent/tools_factory/arxiv.py
+++ b/server/agent/tools_factory/arxiv.py
@@ -1,5 +1,5 @@
 # LangChain 的 ArxivQueryRun 工具
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field
 from langchain.tools.arxiv.tool import ArxivQueryRun

 def arxiv(query: str):
     tool = ArxivQueryRun()
diff --git a/server/agent/tools_factory/audio_factory/aqa.py b/server/agent/tools_factory/audio_factory/aqa.py
index 795690c6..4d053ecc 100644
--- a/server/agent/tools_factory/audio_factory/aqa.py
+++ b/server/agent/tools_factory/audio_factory/aqa.py
@@ -1,6 +1,6 @@
 import base64
 import os
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field

 def save_base64_audio(base64_audio, file_path):
     audio_data = base64.b64decode(base64_audio)
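The container.py hunk above is the heart of the commit: transformers and torch are now imported inside __init__, only on the branches that actually need them, and a failed import or model load is logged instead of crashing server startup. The flip side is that vision_model and audio_model can silently stay None, so callers should guard before use. A hedged sketch of such a guard (this helper is not part of the patch):

    from server.agent.container import container

    def require_vision_model():
        # The model is None when torch/transformers are missing or loading
        # failed; the original error is already in the server log.
        if container.vision_model is None:
            raise RuntimeError("vqa tool unavailable: vision model not loaded")
        return container.vision_model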
diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py
index a308cb43..c893e548 100644
--- a/server/agent/tools_factory/calculate.py
+++ b/server/agent/tools_factory/calculate.py
@@ -1,4 +1,4 @@
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field

 def calculate(a: float, b: float, operator: str) -> float:
     if operator == "+":
diff --git a/server/agent/tools_factory/search_internet.py b/server/agent/tools_factory/search_internet.py
index 46ef28c3..870156db 100644
--- a/server/agent/tools_factory/search_internet.py
+++ b/server/agent/tools_factory/search_internet.py
@@ -1,4 +1,4 @@
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field
 from langchain.utilities.bing_search import BingSearchAPIWrapper
 from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import TOOL_CONFIG
diff --git a/server/agent/tools_factory/search_local_knowledgebase.py b/server/agent/tools_factory/search_local_knowledgebase.py
index eca3709c..ff9614e1 100644
--- a/server/agent/tools_factory/search_local_knowledgebase.py
+++ b/server/agent/tools_factory/search_local_knowledgebase.py
@@ -1,5 +1,5 @@
 from urllib.parse import urlencode
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field
 from server.knowledge_base.kb_doc_api import search_docs
 from configs import TOOL_CONFIG
diff --git a/server/agent/tools_factory/search_youtube.py b/server/agent/tools_factory/search_youtube.py
index 34a735ec..e7737eb4 100644
--- a/server/agent/tools_factory/search_youtube.py
+++ b/server/agent/tools_factory/search_youtube.py
@@ -1,5 +1,5 @@
 from langchain_community.tools import YouTubeSearchTool
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field

 def search_youtube(query: str):
     tool = YouTubeSearchTool()
     return tool.run(tool_input=query)
diff --git a/server/agent/tools_factory/shell.py b/server/agent/tools_factory/shell.py
index e7c41a4a..c8f7ddfe 100644
--- a/server/agent/tools_factory/shell.py
+++ b/server/agent/tools_factory/shell.py
@@ -1,5 +1,5 @@
 # LangChain 的 Shell 工具
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import BaseModel, Field
 from langchain_community.tools import ShellTool

 def shell(query: str):
     tool = ShellTool()
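All of the tool modules above swap pydantic.v1 for langchain.pydantic_v1, LangChain's compatibility shim: it resolves to pydantic.v1 under pydantic 2.x and to plain pydantic under 1.x, where the pydantic.v1 module does not exist (the repository previously pinned pydantic==1.10.13, so the old import failed). Typical use in a tool argument schema, mirroring these files (the WeatherInput class is an illustrative example, not code from this patch):

    from langchain.pydantic_v1 import BaseModel, Field

    class WeatherInput(BaseModel):
        # Field descriptions become the argument docs the agent sees.
        location: str = Field(description="city name, e.g. Beijing")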
""" import base64 from io import BytesIO -import torch from PIL import Image, ImageDraw -from pydantic.v1 import BaseModel, Field +from langchain.pydantic_v1 import BaseModel, Field from configs import TOOL_CONFIG import re from server.agent.container import container @@ -72,6 +71,8 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m temperature (float): temperature top_k (int): top k """ + import torch + image = Image.open(BytesIO(base64.b64decode(image_base_64))) inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image]) diff --git a/server/agent/tools_factory/weather_check.py b/server/agent/tools_factory/weather_check.py index c88d2786..db52860f 100644 --- a/server/agent/tools_factory/weather_check.py +++ b/server/agent/tools_factory/weather_check.py @@ -1,7 +1,7 @@ """ 简单的单参数输入工具实现,用于查询现在天气的情况 """ -from pydantic.v1 import BaseModel, Field +from langchain.pydantic_v1 import BaseModel, Field import requests def weather(location: str, api_key: str): diff --git a/server/agent/tools_factory/wolfram.py b/server/agent/tools_factory/wolfram.py index 784cfee4..45ef0f0a 100644 --- a/server/agent/tools_factory/wolfram.py +++ b/server/agent/tools_factory/wolfram.py @@ -1,6 +1,6 @@ # Langchain 自带的 Wolfram Alpha API 封装 from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper -from pydantic.v1 import BaseModel, Field +from langchain.pydantic_v1 import BaseModel, Field wolfram_alpha_appid = "your key" def wolfram(query: str): wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 65cccbed..0964fdfe 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -4,8 +4,7 @@ import shutil from configs import SCORE_THRESHOLD from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss -from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path -from server.utils import torch_gc +from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path, EmbeddingsFunAdapter from langchain.docstore.document import Document from typing import List, Dict, Optional, Tuple @@ -83,7 +82,6 @@ class FaissKBService(KBService): if not kwargs.get("not_refresh_vs_cache"): vs.save_local(self.vs_path) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] - torch_gc() return doc_infos def do_delete_doc(self, diff --git a/server/utils.py b/server/utils.py index 05b68bbb..61723f08 100644 --- a/server/utils.py +++ b/server/utils.py @@ -22,7 +22,7 @@ from typing import ( ) import logging -from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL +from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE from server.minx_chat_openai import MinxChatOpenAI @@ -101,7 +101,7 @@ def get_model_info(model_name: str, platform_name: str = None) -> Dict: def get_ChatOpenAI( model_name: str, - temperature: float, + temperature: float = TEMPERATURE, max_tokens: int = None, streaming: bool = True, callbacks: List[Callable] = [], @@ -109,18 +109,22 @@ def get_ChatOpenAI( **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) - model = ChatOpenAI( - streaming=streaming, - verbose=verbose, - 
diff --git a/server/utils.py b/server/utils.py
index 05b68bbb..61723f08 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -22,7 +22,7 @@ from typing import (
 )

 import logging
-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL
+from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL, TEMPERATURE
 from server.minx_chat_openai import MinxChatOpenAI


@@ -101,7 +101,7 @@ def get_model_info(model_name: str, platform_name: str = None) -> Dict:

 def get_ChatOpenAI(
     model_name: str,
-    temperature: float,
+    temperature: float = TEMPERATURE,
     max_tokens: int = None,
     streaming: bool = True,
     callbacks: List[Callable] = [],
@@ -109,18 +109,22 @@ def get_ChatOpenAI(
     **kwargs: Any,
 ) -> ChatOpenAI:
     model_info = get_model_info(model_name)
-    model = ChatOpenAI(
-        streaming=streaming,
-        verbose=verbose,
-        callbacks=callbacks,
-        model_name=model_name,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        openai_api_key=model_info.get("api_key"),
-        openai_api_base=model_info.get("api_base_url"),
-        openai_proxy=model_info.get("api_proxy"),
-        **kwargs
-    )
+    try:
+        model = ChatOpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            openai_api_key=model_info.get("api_key"),
+            openai_api_base=model_info.get("api_base_url"),
+            openai_proxy=model_info.get("api_proxy"),
+            **kwargs
+        )
+    except Exception as e:
+        logger.error(f"failed to create ChatOpenAI for model: {model_name}.", exc_info=True)
+        model = None
     return model


@@ -238,26 +242,6 @@ class ChatMessage(BaseModel):
         }


-def torch_gc():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            # with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-        elif torch.backends.mps.is_available():
-            try:
-                from torch.mps import empty_cache
-                empty_cache()
-            except Exception as e:
-                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
-                       "以支持及时清理 torch 产生的内存占用。")
-                logger.error(f'{e.__class__.__name__}: {msg}',
-                             exc_info=e if log_verbose else None)
-    except Exception:
-        ...
-
-
 def run_async(cor):
     '''
     在同步环境中运行异步代码.
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 9aedb5ac..134d6e14 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -129,9 +129,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         chat_box.use_chat_name(conversation_name)
         conversation_id = st.session_state["conversation_ids"][conversation_name]

-        platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
-        platform = st.selectbox("选择模型平台", platforms, 1)
-        llm_models = list(get_config_models(model_type="llm", platform_name=platform))
+        platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
+        platform = st.selectbox("选择模型平台", platforms)
+        llm_models = list(get_config_models(model_type="llm", platform_name=None if platform == "所有" else platform))
         llm_model = st.selectbox("选择LLM模型", llm_models)

         # 传入后端的内容
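Two behavioral notes on the last two files: get_ChatOpenAI now swallows construction errors and returns None, so call sites must check the result before use, and the dialogue page gains an "所有" ("All") entry that lists models across every configured platform instead of defaulting to the second platform in the list. A hedged sketch of the caller-side check (the error message is illustrative):

    from server.utils import get_ChatOpenAI

    model = get_ChatOpenAI(model_name="gpt-3.5-turbo")
    if model is None:
        # Construction failed (bad platform config, missing key, ...);
        # details were logged inside get_ChatOpenAI.
        raise RuntimeError("LLM init failed; check MODEL_PLATFORMS in configs")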