From 5df19d907b2e32afa0c792892f4c449a3bb4fa8f Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sun, 7 Jan 2024 17:31:58 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B7=9F=E6=96=B0=E4=BA=86langchain=200.1.x?= =?UTF-8?q?=E9=9C=80=E8=A6=81=E7=9A=84=E4=BE=9D=E8=B5=96=E5=92=8C=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 6 ++-- server/agent/agent_factory/glm3_agent.py | 7 +---- server/agent/agent_instruct.md | 32 ++++++++++++++++++++ server/agent/tools_factory/calculate.py | 2 +- server/agent/tools_factory/search_youtube.py | 3 +- server/agent/tools_factory/shell.py | 2 +- server/chat/chat.py | 14 +++------ server/knowledge_base/utils.py | 9 +++--- server/utils.py | 4 +-- update_requirements.sh | 7 +++++ 10 files changed, 57 insertions(+), 29 deletions(-) create mode 100644 server/agent/agent_instruct.md create mode 100644 update_requirements.sh diff --git a/requirements.txt b/requirements.txt index 7bb097d8..550c2b56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,10 @@ torch>=2.1.2 torchvision>=0.16.2 torchaudio>=2.1.2 -langchain>=0.0.352 -langchain-experimental>=0.0.47 +langchain>=0.1.0 +langchain_openai>=0.0.2 +langchain-community>=1.0.0 + pydantic==1.10.13 fschat==0.2.35 openai==1.9.0 diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py index 4be21c1b..6a9a2e89 100644 --- a/server/agent/agent_factory/glm3_agent.py +++ b/server/agent/agent_factory/glm3_agent.py @@ -1,8 +1,8 @@ """ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo. """ -from __future__ import annotations +from __future__ import annotations import json import logging from typing import Any, List, Sequence, Tuple, Optional, Union @@ -21,9 +21,7 @@ from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackHandler from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool -from langchain_core.callbacks import Callbacks -HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" logger = logging.getLogger(__name__) @@ -148,7 +146,6 @@ class StructuredGLM3ChatAgent(Agent): formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}") template = prompt.format(tool_names=tool_names, tools=formatted_tools, - history="None", input="{input}", agent_scratchpad="{agent_scratchpad}") @@ -169,7 +166,6 @@ class StructuredGLM3ChatAgent(Agent): prompt: str = None, callbacks: List[BaseCallbackHandler] = [], output_parser: Optional[AgentOutputParser] = None, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, input_variables: Optional[List[str]] = None, memory_prompts: Optional[List[BasePromptTemplate]] = None, **kwargs: Any, @@ -229,6 +225,5 @@ def initialize_glm3_agent( callbacks=callbacks, memory=memory, tags=tags_, - intermediate_steps=[], **kwargs, ) diff --git a/server/agent/agent_instruct.md b/server/agent/agent_instruct.md new file mode 100644 index 00000000..170c25e2 --- /dev/null +++ b/server/agent/agent_instruct.md @@ -0,0 +1,32 @@ +# What tools should be used + +# search_internet + +使用这个工具是因为用户需要在联网进行搜索。这些问题通常是你不知道的,这些问题具有特点, + +例如: + ++ 联网帮我查询 xxx ++ 我想知道最新的新闻 + +或者,用户有明显的意图,需要获取事实的信息。 + +返回字段如下 + +``` +search_internet +``` + +# search_local_knowledge + +使用这个工具是希望用户能够获取本地的知识,这些知识通常是你自身能力不具备的专业问题,或者用户指定了某个任务的。 + +例如: + ++ 告诉我 关于 xxx 的 xxx 信息 ++ xxx 中 xxx 的 xxx 是什么 + +返回字段如下 +``` +search_local_knowledge +``` \ No newline at end of file diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py index a47d65ca..476784b5 100644 --- a/server/agent/tools_factory/calculate.py +++ b/server/agent/tools_factory/calculate.py @@ -11,7 +11,7 @@ def calculate(a: float, b: float, operator: str) -> float: if b != 0: return a / b else: - return float('inf') # 防止除以零 + return float('inf') elif operator == "^": return a ** b else: diff --git a/server/agent/tools_factory/search_youtube.py b/server/agent/tools_factory/search_youtube.py index 57049897..28b436b6 100644 --- a/server/agent/tools_factory/search_youtube.py +++ b/server/agent/tools_factory/search_youtube.py @@ -1,5 +1,4 @@ -# Langchain 自带的 YouTube 搜索工具封装 -from langchain.tools import YouTubeSearchTool +from langchain_community.tools import YouTubeSearchTool from pydantic import BaseModel, Field def search_youtube(query: str): tool = YouTubeSearchTool() diff --git a/server/agent/tools_factory/shell.py b/server/agent/tools_factory/shell.py index 01046559..7cf0ad2e 100644 --- a/server/agent/tools_factory/shell.py +++ b/server/agent/tools_factory/shell.py @@ -1,6 +1,6 @@ # LangChain 的 Shell 工具 from pydantic import BaseModel, Field -from langchain.tools import ShellTool +from langchain_community.tools import ShellTool def shell(query: str): tool = ShellTool() return tool.run(tool_input=query) diff --git a/server/chat/chat.py b/server/chat/chat.py index e2b6ae44..a9f6c60b 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -5,7 +5,8 @@ from typing import List, Union, AsyncIterable, Dict from fastapi import Body from fastapi.responses import StreamingResponse -from langchain.agents import initialize_agent, AgentType +from langchain.agents import initialize_agent, AgentType, create_structured_chat_agent, AgentExecutor +from langchain_core.messages import HumanMessage, AIMessage from langchain_core.output_parsers import StrOutputParser from langchain.chains import LLMChain from langchain.prompts.chat import ChatPromptTemplate @@ -83,7 +84,6 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks llm=models["action_model"], tools=tools, prompt=prompts["action_model"], - memory=memory, callbacks=callbacks, verbose=True, ) @@ -92,7 +92,6 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks llm=models["action_model"], tools=tools, prompt=prompts["action_model"], - memory=memory, callbacks=callbacks, verbose=True, ) @@ -102,7 +101,6 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks tools=tools, callbacks=callbacks, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - memory=memory, verbose=True, ) @@ -111,7 +109,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks # chain # ) # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch) - full_chain = ({"input": lambda x: x["input"]} | agent_executor) + full_chain = ({"input": lambda x: x["input"], } | agent_executor) else: chain.llm.callbacks = callbacks full_chain = ({"input": lambda x: x["input"]} | chain) @@ -146,14 +144,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callback = AgentExecutorAsyncIteratorCallbackHandler() callbacks = [callback] - - # 从配置中选择模型 models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream) - - # 从配置中选择工具 tools = [tool for tool in all_tools if tool.name in tool_config] tools = [t.copy(update={"callbacks": callbacks}) for t in tools] - # 构建完整的Chain full_chain = create_models_chains(prompts=prompts, models=models, conversation_id=conversation_id, @@ -163,6 +156,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 history_len=history_len, metadata=metadata) task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done)) + async for chunk in callback.aiter(): data = json.loads(chunk) data["message_id"] = message_id diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 61b3625c..3c077008 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -17,8 +17,9 @@ from langchain.text_splitter import TextSplitter from pathlib import Path from server.utils import run_in_thread_pool, get_model_worker_config import json -from typing import List, Union,Dict, Tuple, Generator +from typing import List, Union, Dict, Tuple, Generator import chardet +from langchain_community.document_loaders import JSONLoader def validate_kb_name(knowledge_base_id: str) -> bool: @@ -122,15 +123,13 @@ def _new_json_dumps(obj, **kwargs): kwargs["ensure_ascii"] = False return _origin_json_dumps(obj, **kwargs) + if json.dumps is not _new_json_dumps: _origin_json_dumps = json.dumps json.dumps = _new_json_dumps -class JSONLinesLoader(langchain.document_loaders.JSONLoader): - ''' - 行式 Json 加载器,要求文件扩展名为 .jsonl - ''' +class JSONLinesLoader(JSONLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._json_lines = True diff --git a/server/utils.py b/server/utils.py index fa3ec0ae..990cd13a 100644 --- a/server/utils.py +++ b/server/utils.py @@ -9,8 +9,8 @@ from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE, FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT) import os from concurrent.futures import ThreadPoolExecutor, as_completed -from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI +from langchain_openai.chat_models import ChatOpenAI +from langchain_community.llms import OpenAI import httpx from typing import ( TYPE_CHECKING, diff --git a/update_requirements.sh b/update_requirements.sh new file mode 100644 index 00000000..fcb2bfd4 --- /dev/null +++ b/update_requirements.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +python -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple + +while read requirement; do + python -m pip install --upgrade "$requirement" -i https://pypi.tuna.tsinghua.edu.cn/simple +done < requirements.txt