支持SQLAlchemy大部分数据库、新增read-only模式,提高安全性、增加text2sql使用建议 (#4155)

* 1、修改text2sql连接配置,支持SQLAlchemy大部分数据库;
2、新增read-only模式,若有数据库写保护需求,会从大模型判断、SQLAlchemy拦截器两个层面进行写拦截,提高安全性;
3、增加text2sql使用建议;
This commit is contained in:
srszzw 2024-06-09 12:53:54 +08:00 committed by GitHub
parent b1c5bf9c94
commit 3f1244d156
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 13 deletions

View File

@ -227,13 +227,23 @@ TOOL_CONFIG = {
"text2images": {
"use": False,
},
# text2sql使用建议
# 1、因大模型生成的sql可能与预期有偏差，请务必在测试环境中进行充分测试、评估；
# 2、生产环境中，对于查询操作，由于不确定查询效率，推荐数据库采用主从数据库架构，让text2sql连接从数据库，防止可能的慢查询影响主业务；
# 3、对于写操作应保持谨慎，如不需要写操作，设置read_only为True，最好再从数据库层面收回数据库用户的写权限，防止用户通过自然语言对数据库进行修改操作；
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关，可切换不同大模型进行测试；
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解，且应对数据库表名、字段进行详细的备注说明，帮助大模型更好理解数据库结构；
# 6、若现有数据库表名难于让大模型理解，可配置下面table_comments字段，补充说明某些表的作用。
"text2sql": {
"use": False,
#mysql连接信息
"db_host": "mysql_host",
"db_user": "mysql_user",
"db_password": "mysql_password",
"db_name": "mysql_database_name",
# SQLAlchemy连接字符串，支持的数据库有：
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
# 不同的数据库请查询SQLAlchemy，修改sqlalchemy_connect_str，配置对应的数据库连接，如sqlite为sqlite:///数据库文件路径，下面示例为mysql
# 如提示缺少对应数据库的驱动，请自行通过poetry安装
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
# 务必评估是否需要开启read_only，开启后会对sql语句进行检查，请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求，
# 优先推荐从数据库层面对用户权限进行限制
"read_only": False,
#限定返回的行数
"top_k":50,
#是否返回中间步骤
@ -243,7 +253,7 @@ TOOL_CONFIG = {
#对表名进行额外说明辅助大模型更好的判断应该使用哪些表尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
"table_comments":{
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
# "tableA":"用户表",
# "tableA":"这是一个用户表,存储了用户的基本信息",
# "tableB":"角色表",
}
},

View File

@ -3,16 +3,38 @@ from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialCha
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput
from sqlalchemy.exc import OperationalError
from sqlalchemy import event
from langchain_core.prompts.prompt import PromptTemplate
from langchain.chains import LLMChain
# Prompt used in read-only mode to ask the LLM whether the user's question can
# be answered without modifying the database.
# The wording is dialect-neutral because the connection string may point at any
# SQLAlchemy-supported database, not only MySQL.
# NOTE: the caller detects a refusal by searching the reply for the exact
# phrase "SQL cannot be executed normally", so keep that phrase unchanged.
READ_ONLY_PROMPT_TEMPLATE = """You are a SQL expert. The database is currently in read-only mode.
Given an input question, determine if the related SQL can be executed in read-only mode.
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
Use the following format:
Answer: Final answer here
Question: {query}
"""
# Interceptor that inspects every SQL statement to enforce read-only mode.
# Adjust `write_operations` below to match the write keywords of your database.
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
    """Reject write statements before they are sent to the database.

    Intended to be registered as a SQLAlchemy ``before_cursor_execute``
    listener, whose signature this function follows. Only ``statement`` is
    inspected; the other arguments are accepted to satisfy the hook signature.

    The check looks at the leading keyword of the statement after skipping
    whitespace and SQL comments. It is a best-effort guard: a write hidden
    later in the text (for example after a ``WITH`` clause or a ``;``) is not
    detected, so database-level permissions remain the primary protection.

    Args:
        conn: Connection the statement will run on (unused).
        cursor: DBAPI cursor (unused).
        statement: SQL text about to be executed.
        parameters: Bound parameters (unused).
        context: Execution context (unused).
        executemany: Whether this is an ``executemany`` call (unused).

    Raises:
        OperationalError: If the statement starts with a write keyword.
    """
    import re  # local import so the guard does not depend on file-level imports

    # SQL keywords that indicate a write (DML, DDL or privilege change).
    write_operations = (
        "insert", "update", "delete", "create", "drop", "alter", "truncate", "rename",
        "replace", "merge", "grant", "revoke",
    )
    # Skip leading whitespace and comments ("-- ...", "# ...", "/* ... */") so a
    # write statement cannot slip past the check by being prefixed with a comment.
    leading_noise = re.compile(r"(?:\s+|--[^\n]*|#[^\n]*|/\*.*?\*/)*", re.DOTALL)
    sql_head = statement[leading_noise.match(statement).end():].lower()
    # str.startswith accepts a tuple, testing every keyword in one call.
    if sql_head.startswith(write_operations):
        raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)
def query_database(query: str,
config: dict):
db_user = config["db_user"]
db_password = config["db_password"]
db_host = config["db_host"]
db_name = config["db_name"]
top_k = config["top_k"]
return_intermediate_steps = config["return_intermediate_steps"]
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
read_only = config["read_only"]
db = SQLDatabase.from_uri(sqlalchemy_connect_str)
from chatchat.server.api_server.chat_routes import global_model_name
from chatchat.server.utils import get_ChatOpenAI
llm = get_ChatOpenAI(
@ -30,9 +52,26 @@ def query_database(query: str,
#由于langchain固定了输入参数所以只能通过query传递额外的表说明
if table_comments:
TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"
if read_only:
# 在read_only下先让大模型判断只读模式是否能满足需求避免后续执行过程报错返回友好提示。
READ_ONLY_PROMPT = PromptTemplate(
input_variables=["query"],
template=READ_ONLY_PROMPT_TEMPLATE,
)
read_only_chain = LLMChain(
prompt=READ_ONLY_PROMPT,
llm=llm,
)
read_only_result = read_only_chain.invoke(query)
if "SQL cannot be executed normally" in read_only_result["text"]:
return "当前数据库为只读状态,无法满足您的需求!"
# 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作
event.listen(db._engine, "before_cursor_execute", intercept_sql)
#如果不指定table_names，优先走SQLDatabaseSequentialChain，这个链会先预测需要哪些表，然后再将相关表输入SQLDatabaseChain
#这是因为如果不指定table_names，直接走SQLDatabaseChain，Langchain会将全量表结构传递给大模型，可能会因token太长从而引发错误，也浪费资源
#如果指定了table_names，直接走SQLDatabaseChain，将特定表结构传递给大模型进行判断