mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
支持SQLAlchemy大部分数据库、新增read-only模式,提高安全性、增加text2sql使用建议 (#4155)
* 1、修改text2sql连接配置,支持SQLAlchemy大部分数据库; 2、新增read-only模式,若有数据库写保护需求,会从大模型判断、SQLAlchemy拦截器两个层面进行写拦截,提高安全性; 3、增加text2sql使用建议;
This commit is contained in:
parent
b1c5bf9c94
commit
3f1244d156
@ -227,13 +227,23 @@ TOOL_CONFIG = {
|
||||
"text2images": {
|
||||
"use": False,
|
||||
},
|
||||
# text2sql使用建议
|
||||
# 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估;
|
||||
# 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务;
|
||||
# 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作;
|
||||
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试;
|
||||
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构;
|
||||
# 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。
|
||||
"text2sql": {
|
||||
"use": False,
|
||||
#mysql连接信息
|
||||
"db_host": "mysql_host",
|
||||
"db_user": "mysql_user",
|
||||
"db_password": "mysql_password",
|
||||
"db_name": "mysql_database_name",
|
||||
# SQLAlchemy连接字符串,支持的数据库有:
|
||||
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
|
||||
# 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql
|
||||
# 如提示缺少对应数据库的驱动,请自行通过poetry安装
|
||||
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
|
||||
# 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求
|
||||
# 优先推荐从数据库层面对用户权限进行限制
|
||||
"read_only": False,
|
||||
#限定返回的行数
|
||||
"top_k":50,
|
||||
#是否返回中间步骤
|
||||
@ -243,7 +253,7 @@ TOOL_CONFIG = {
|
||||
#对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
|
||||
"table_comments":{
|
||||
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
|
||||
# "tableA":"用户表",
|
||||
# "tableA":"这是一个用户表,存储了用户的基本信息",
|
||||
# "tableB":"角色表",
|
||||
}
|
||||
},
|
||||
|
||||
@ -3,16 +3,38 @@ from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialCha
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy import event
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
# Prompt used in read-only mode to ask the LLM whether the user's question can
# be answered without writing to the database. The wording is database-agnostic
# because the connection is configured through a generic SQLAlchemy URL.
# The phrase 'SQL cannot be executed normally' is matched verbatim by the
# caller, so it must not be reworded.
READ_ONLY_PROMPT_TEMPLATE = """You are a SQL expert. The database is currently in read-only mode.
Given an input question, determine if the related SQL can be executed in read-only mode.
If the SQL can be executed normally, return Answer: 'SQL can be executed normally'.
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
Use the following format:

Answer: Final answer here

Question: {query}
"""
|
||||
|
||||
# Interceptor that inspects every SQL statement to support read-only mode.
# Adjust write_operations below to match the write keywords of your database.
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
    """Reject write statements before they reach the database cursor.

    The parameter list follows the ``before_cursor_execute`` event hook this
    function is registered on; only ``statement`` is inspected.

    Leading whitespace and leading SQL comments (``-- ...`` and ``/* ... */``)
    are skipped before the keyword check, so a write statement cannot bypass
    the guard simply by being prefixed with a comment.

    This is a best-effort keyword filter, not a SQL parser; revoking write
    privileges from the database user remains the stronger safeguard.

    Raises:
        OperationalError: if the statement begins with a write keyword.
    """
    # List of SQL keywords that indicate a write operation
    write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename")

    sql = statement.lstrip()
    # Peel off any run of leading comments. partition() yields an empty
    # remainder when the terminator is missing, which ends the loop.
    while sql.startswith(("--", "/*")):
        terminator = "\n" if sql.startswith("--") else "*/"
        sql = sql.partition(terminator)[2].lstrip()

    # str.startswith accepts a tuple, so one call checks every keyword
    if sql.lower().startswith(write_operations):
        raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)
|
||||
|
||||
def query_database(query: str,
|
||||
config: dict):
|
||||
db_user = config["db_user"]
|
||||
db_password = config["db_password"]
|
||||
db_host = config["db_host"]
|
||||
db_name = config["db_name"]
|
||||
top_k = config["top_k"]
|
||||
return_intermediate_steps = config["return_intermediate_steps"]
|
||||
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
|
||||
sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
|
||||
read_only = config["read_only"]
|
||||
db = SQLDatabase.from_uri(sqlalchemy_connect_str)
|
||||
|
||||
from chatchat.server.api_server.chat_routes import global_model_name
|
||||
from chatchat.server.utils import get_ChatOpenAI
|
||||
llm = get_ChatOpenAI(
|
||||
@ -30,9 +52,26 @@ def query_database(query: str,
|
||||
#由于langchain固定了输入参数,所以只能通过query传递额外的表说明
|
||||
if table_comments:
|
||||
TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
|
||||
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
|
||||
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
|
||||
query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"
|
||||
|
||||
|
||||
if read_only:
|
||||
# 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。
|
||||
READ_ONLY_PROMPT = PromptTemplate(
|
||||
input_variables=["query"],
|
||||
template=READ_ONLY_PROMPT_TEMPLATE,
|
||||
)
|
||||
read_only_chain = LLMChain(
|
||||
prompt=READ_ONLY_PROMPT,
|
||||
llm=llm,
|
||||
)
|
||||
read_only_result = read_only_chain.invoke(query)
|
||||
if "SQL cannot be executed normally" in read_only_result["text"]:
|
||||
return "当前数据库为只读状态,无法满足您的需求!"
|
||||
|
||||
# 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作
|
||||
event.listen(db._engine, "before_cursor_execute", intercept_sql)
|
||||
|
||||
#如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain
|
||||
#这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源
|
||||
#如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user