From 3f1244d1561573140041bc0dae0242726da021c6 Mon Sep 17 00:00:00 2001 From: srszzw <741992282@qq.com> Date: Sun, 9 Jun 2024 12:53:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81SQLAlchemy=E5=A4=A7=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=95=B0=E6=8D=AE=E5=BA=93=E3=80=81=E6=96=B0=E5=A2=9E?= =?UTF-8?q?read-only=E6=A8=A1=E5=BC=8F=EF=BC=8C=E6=8F=90=E9=AB=98=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=80=A7=E3=80=81=E5=A2=9E=E5=8A=A0text2sql=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E5=BB=BA=E8=AE=AE=20(#4155)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 1、修改text2sql连接配置,支持SQLAlchemy大部分数据库; 2、新增read-only模式,若有数据库写保护需求,会从大模型判断、SQLAlchemy拦截器两个层面进行写拦截,提高安全性; 3、增加text2sql使用建议; --- .../chatchat/configs/_model_config.py | 22 +++++--- .../server/agent/tools_factory/text2sql.py | 53 ++++++++++++++++--- 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py index 56809b49..35625d56 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -227,13 +227,23 @@ TOOL_CONFIG = { "text2images": { "use": False, }, + # text2sql使用建议 + # 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估; + # 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务; + # 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作; + # 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试; + # 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构; + # 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。 "text2sql": { "use": False, - #mysql连接信息 - "db_host": "mysql_host", - "db_user": "mysql_user", - "db_password": "mysql_password", - "db_name": "mysql_database_name", + # SQLAlchemy连接字符串,支持的数据库有: + # crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb + # 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql + # 如提示缺少对应数据库的驱动,请自行通过poetry安装 + "sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e", + # 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求 + # 优先推荐从数据库层面对用户权限进行限制 + "read_only": False, #限定返回的行数 "top_k":50, #是否返回中间步骤 @@ -243,7 +253,7 @@ TOOL_CONFIG = { #对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。 "table_comments":{ # 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明 - # "tableA":"用户表", + # "tableA":"这是一个用户表,存储了用户的基本信息", # "tanleB":"角色表", } }, diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py index 1cb37716..ef8ec878 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py @@ -3,16 +3,38 @@ from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialCha from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field from .tools_registry import regist_tool, BaseToolOutput +from sqlalchemy.exc import OperationalError +from sqlalchemy import event +from langchain_core.prompts.prompt import PromptTemplate +from langchain.chains import LLMChain + +READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode. +Given an input question, determine if the related SQL can be executed in read-only mode. +If the SQL can be executed normally, return Answer:'SQL can be executed normally'. +If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'. +Use the following format: + +Answer: Final answer here + +Question: {query} +""" + +# 定义一个拦截器函数来检查SQL语句,以支持read-only,可修改下面的write_operations,以匹配你使用的数据库写操作关键字 +def intercept_sql(conn, cursor, statement, parameters, context, executemany): + # List of SQL keywords that indicate a write operation + write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename") + # Check if the statement starts with any of the write operation keywords + if any(statement.strip().lower().startswith(op) for op in write_operations): + raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None) def query_database(query: str, config: dict): - db_user = config["db_user"] - db_password = config["db_password"] - db_host = config["db_host"] - db_name = config["db_name"] top_k = config["top_k"] return_intermediate_steps = config["return_intermediate_steps"] - db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}") + sqlalchemy_connect_str = config["sqlalchemy_connect_str"] + read_only = config["read_only"] + db = SQLDatabase.from_uri(sqlalchemy_connect_str) + from chatchat.server.api_server.chat_routes import global_model_name from chatchat.server.utils import get_ChatOpenAI llm = get_ChatOpenAI( @@ -30,9 +52,26 @@ def query_database(query: str, #由于langchain固定了输入参数,所以只能通过query传递额外的表说明 if table_comments: TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n" - table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()]) + table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()]) query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n" - + + if read_only: + # 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。 + READ_ONLY_PROMPT = PromptTemplate( + input_variables=["query"], + template=READ_ONLY_PROMPT_TEMPLATE, + ) + read_only_chain = LLMChain( + prompt=READ_ONLY_PROMPT, + llm=llm, + ) + read_only_result = read_only_chain.invoke(query) + if "SQL cannot be executed normally" in read_only_result["text"]: + return "当前数据库为只读状态,无法满足您的需求!" + + # 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作 + event.listen(db._engine, "before_cursor_execute", intercept_sql) + #如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain #这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源 #如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断