mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-01 11:53:24 +08:00
增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明 (#4154)
* 1、增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明
This commit is contained in:
parent
94524f8479
commit
b1c5bf9c94
@ -227,5 +227,24 @@ TOOL_CONFIG = {
|
||||
"text2images": {
|
||||
"use": False,
|
||||
},
|
||||
|
||||
"text2sql": {
|
||||
"use": False,
|
||||
#mysql连接信息
|
||||
"db_host": "mysql_host",
|
||||
"db_user": "mysql_user",
|
||||
"db_password": "mysql_password",
|
||||
"db_name": "mysql_database_name",
|
||||
#限定返回的行数
|
||||
"top_k":50,
|
||||
#是否返回中间步骤
|
||||
"return_intermediate_steps": True,
|
||||
#如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
|
||||
"table_names":[],
|
||||
#对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
|
||||
"table_comments":{
|
||||
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
|
||||
# "tableA":"用户表",
|
||||
# "tableB":"角色表",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -10,3 +10,4 @@ from .text2image import text2images
|
||||
|
||||
from .vqa_processor import vqa_processor
|
||||
from .aqa_processor import aqa_processor
|
||||
from .text2sql import text2sql
|
||||
@ -0,0 +1,64 @@
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
|
||||
def _append_table_comments(query: str, table_comments: dict) -> str:
    """Return *query* with the configured table notes appended (unchanged if none)."""
    if not table_comments:
        return query
    notes = "\n".join(f"{name}:{note}" for name, note in table_comments.items())
    return (
        query
        + "\n\nI will provide some special notes for a few tables:\n\n"
        + notes
        + "\n\n"
    )


def _extract_sql(step_input: str) -> str:
    """Return the text between the ``SQLQuery:`` and ``Answer:`` markers, or ""."""
    marker = "SQLQuery:"
    start = step_input.find(marker)
    if start == -1:
        # Without the marker there is nothing trustworthy to show.
        return ""
    start += len(marker)
    end = step_input.find("Answer:", start)
    if end == -1:
        # No trailing marker: keep everything after "SQLQuery:".
        end = len(step_input)
    return step_input[start:end]


def query_database(query: str,
                   config: dict):
    """Answer a natural-language question by running a LangChain SQL chain on MySQL.

    Args:
        query: The user's question in natural language (not SQL).
        config: The ``text2sql`` tool settings. Required keys: ``db_user``,
            ``db_password``, ``db_host``, ``db_name``, ``top_k`` and
            ``return_intermediate_steps``. Optional keys: ``table_names``
            (list of tables to restrict the chain to) and ``table_comments``
            (mapping of table name to a short description). Credentials are
            expected in raw form; they are URL-escaped here.

    Returns:
        A string containing the chain's result and, when intermediate steps
        are available, the SQL text extracted from them.
    """
    from urllib.parse import quote_plus

    # Imported at call time, as in the original; global_model_name is a
    # module-level variable assigned by the chat route, so this reads the
    # value present when the tool runs.
    from chatchat.server.api_server.chat_routes import global_model_name
    from chatchat.server.utils import get_ChatOpenAI

    db_host = config["db_host"]
    db_name = config["db_name"]
    top_k = config["top_k"]
    return_intermediate_steps = config["return_intermediate_steps"]
    # Escape credentials so characters such as '@', ':' or '/' cannot corrupt
    # the connection URL.
    db_user = quote_plus(str(config["db_user"]))
    db_password = quote_plus(str(config["db_password"]))
    table_names = config.get("table_names") or []
    table_comments = config.get("table_comments") or {}

    db = SQLDatabase.from_uri(
        f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
    )
    llm = get_ChatOpenAI(
        model_name=global_model_name,
        temperature=0,
        streaming=True,
        local_wrap=True,
        verbose=True,
    )

    # The chains accept a fixed set of inputs, so extra table descriptions can
    # only be passed to the model by appending them to the question itself.
    query = _append_table_comments(query, table_comments)

    if table_names:
        # Explicit tables: hand only those tables to SQLDatabaseChain.
        db_chain = SQLDatabaseChain.from_llm(
            llm,
            db,
            verbose=True,
            top_k=top_k,
            return_intermediate_steps=return_intermediate_steps,
        )
        result = db_chain.invoke(
            {"query": query, "table_names_to_use": table_names}
        )
    else:
        # No explicit tables: SQLDatabaseSequentialChain first predicts which
        # tables are needed, which avoids sending every table schema to the
        # model (long prompts may exceed the token limit and waste resources).
        db_chain = SQLDatabaseSequentialChain.from_llm(
            llm,
            db,
            verbose=True,
            top_k=top_k,
            return_intermediate_steps=return_intermediate_steps,
        )
        result = db_chain.invoke(query)

    context = f"""查询结果:{result['result']}\n\n"""

    # The key is absent when return_intermediate_steps is disabled.
    intermediate_steps = result.get("intermediate_steps")
    # Only the second-to-last step is used; earlier steps contain sample data
    # that can be misleading.
    if intermediate_steps and len(intermediate_steps) > 2:
        step = intermediate_steps[-2]
        if isinstance(step, dict) and "input" in step:
            sql_detail = _extract_sql(step["input"])
            if sql_detail:
                context = context + "执行的sql:'" + sql_detail + "'\n\n"
    return context
|
||||
|
||||
|
||||
@regist_tool(title="Text2Sql")
def text2sql(
    query: str = Field(
        description="No need for SQL statements,just input the natural language that you want to chat with database"
    ),
):
    """Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result."""
    # NOTE(review): the docstring text is kept verbatim; presumably the tool
    # registry exposes it as the tool description — confirm before rewording.
    sql_tool_settings = get_tool_config("text2sql")
    answer_text = query_database(query=query, config=sql_tool_settings)
    return BaseToolOutput(answer_text)
|
||||
@ -28,6 +28,8 @@ chat_router.post("/file_chat",
|
||||
summary="文件对话"
|
||||
)(file_chat)
|
||||
|
||||
#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name
|
||||
global_model_name=None
|
||||
|
||||
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
|
||||
async def chat_completions(
|
||||
@ -51,6 +53,8 @@ async def chat_completions(
|
||||
for key in list(extra):
|
||||
delattr(body, key)
|
||||
|
||||
global global_model_name
|
||||
global_model_name=body.model
|
||||
# check tools & tool_choice in request body
|
||||
if isinstance(body.tool_choice, str):
|
||||
if t := get_tool(body.tool_choice):
|
||||
|
||||
@ -15,11 +15,11 @@ chatchat-kb = 'chatchat.init_database:main'
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
model-providers = "^0.3.0"
|
||||
langchain = "0.1.5"
|
||||
langchain = "0.1.17"
|
||||
langchainhub = "0.1.14"
|
||||
langchain-community = "0.0.17"
|
||||
langchain-community = "0.0.36"
|
||||
langchain-openai = "0.0.5"
|
||||
langchain-experimental = "0.0.50"
|
||||
langchain-experimental = "0.0.58"
|
||||
fastapi = "~0.109.2"
|
||||
sse_starlette = "~1.8.2"
|
||||
nltk = "~3.8.1"
|
||||
@ -51,12 +51,13 @@ python-multipart = "0.0.9"
|
||||
streamlit = "1.34.0"
|
||||
streamlit-option-menu = "0.3.12"
|
||||
streamlit-antd-components = "0.3.1"
|
||||
streamlit-chatbox = "1.1.12"
|
||||
streamlit-chatbox = "1.1.12.post2"
|
||||
streamlit-modal = "0.1.0"
|
||||
streamlit-aggrid = "0.3.4.post3"
|
||||
streamlit-extras = "0.4.2"
|
||||
xinference_client = { version = "^0.11.1", optional = true }
|
||||
zhipuai = { version = "^2.1.0", optional = true }
|
||||
pymysql = "^1.1.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
xinference = ["xinference_client"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user