diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py
index 2762628a..56809b49 100644
--- a/libs/chatchat-server/chatchat/configs/_model_config.py
+++ b/libs/chatchat-server/chatchat/configs/_model_config.py
@@ -227,5 +227,24 @@ TOOL_CONFIG = {
     "text2images": {
         "use": False,
     },
-
+    "text2sql": {
+        "use": False,
+        # MySQL connection info
+        "db_host": "mysql_host",
+        "db_user": "mysql_user",
+        "db_password": "mysql_password",
+        "db_name": "mysql_database_name",
+        # Limit on the number of rows returned
+        "top_k": 50,
+        # Whether to return intermediate steps
+        "return_intermediate_steps": True,
+        # To restrict queries to specific tables, list them here, e.g. ["sys_user", "sys_dept"]; leave empty to let the model decide which tables to use
+        "table_names": [],
+        # Extra descriptions of the tables, to help the LLM pick the right ones; this matters especially in SQLDatabaseSequentialChain mode, where table selection is predicted from table names alone and is easy to get wrong
+        "table_comments": {
+            # If the model keeps choosing the wrong tables, try filling in table names and descriptions here, e.g.:
+            # "tableA": "user table",
+            # "tableB": "role table",
+        }
+    },
 }
diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py
index 8faaedb7..2242434f 100644
--- a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py
+++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py
@@ -10,3 +10,4 @@
 from .text2image import text2images
 from .vqa_processor import vqa_processor
 from .aqa_processor import aqa_processor
+from .text2sql import text2sql
\ No newline at end of file
diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py
new file mode 100644
index 00000000..1cb37716
--- /dev/null
+++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py
@@ -0,0 +1,64 @@
+from langchain.utilities import SQLDatabase
+from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain
+from chatchat.server.utils import get_tool_config
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool, BaseToolOutput
+
+def query_database(query: str,
+                   config: dict):
+    db_user = config["db_user"]
+    db_password = config["db_password"]
+    db_host = config["db_host"]
+    db_name = config["db_name"]
+    top_k = config["top_k"]
+    return_intermediate_steps = config["return_intermediate_steps"]
+    db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
+    from chatchat.server.api_server.chat_routes import global_model_name
+    from chatchat.server.utils import get_ChatOpenAI
+    llm = get_ChatOpenAI(
+        model_name=global_model_name,
+        temperature=0,
+        streaming=True,
+        local_wrap=True,
+        verbose=True
+    )
+    table_names = config["table_names"]
+    table_comments = config["table_comments"]
+    result = None
+
+    # If the model has trouble deciding which tables to use, give langchain extra table descriptions to help it choose; this matters especially in SQLDatabaseSequentialChain mode, where table selection is predicted from table names alone and is easy to get wrong.
+    # Since langchain fixes the chain's input parameters, the extra table descriptions can only be passed in through the query itself.
+    if table_comments:
+        TABLE_COMMENT_PROMPT = "\n\nI will provide some special notes for a few tables:\n\n"
+        table_comments_str = "\n".join([f"{k}:{v}" for k, v in table_comments.items()])
+        query = query + TABLE_COMMENT_PROMPT + table_comments_str + "\n\n"
+
+    # If table_names is not specified, prefer SQLDatabaseSequentialChain, which first predicts the relevant tables and then feeds them to SQLDatabaseChain.
+    # Going straight to SQLDatabaseChain in that case would make langchain pass every table schema to the LLM, which can exceed the token limit and wastes resources.
+    # If table_names is specified, use SQLDatabaseChain directly and pass only those table schemas to the LLM.
+    if len(table_names) > 0:
+        db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, top_k=top_k, return_intermediate_steps=return_intermediate_steps)
+        result = db_chain.invoke({"query": query, "table_names_to_use": table_names})
+    else:
+        # First predict which tables will be used, then pass the question and the predicted tables to the LLM.
+        db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True, top_k=top_k, return_intermediate_steps=return_intermediate_steps)
+        result = db_chain.invoke(query)
+
+    context = f"""查询结果:{result['result']}\n\n"""
+
+    intermediate_steps = result["intermediate_steps"]
+    # If intermediate_steps exists and has more than two entries, keep only the last two, because the earlier steps contain sample data and can be misleading.
+    if intermediate_steps:
+        if len(intermediate_steps) > 2:
+            sql_detail = intermediate_steps[-2]["input"]
+            # Trim sql_detail to the content between "SQLQuery:" and "Answer:"
+            sql_detail = sql_detail[sql_detail.find("SQLQuery:") + 9:sql_detail.find("Answer:")]
+            context = context + "执行的sql:'" + sql_detail + "'\n\n"
+    return context
+
+
+@regist_tool(title="Text2Sql")
+def text2sql(query: str = Field(description="No SQL statement is needed; just input the natural-language question you want to ask the database")):
+    '''Use this tool to chat with a database: input natural language, and it will be converted into SQL, executed against the database, and the execution result returned.'''
+    tool_config = get_tool_config("text2sql")
+    return BaseToolOutput(query_database(query=query, config=tool_config))
diff --git a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
index f935ab9c..eb2c74ca 100644
--- a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
+++ b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
@@ -28,6 +28,8 @@ chat_router.post("/file_chat",
                  summary="文件对话"
                  )(file_chat)
 
+# Global model name, used to supply model_name to get_ChatOpenAI in the Text2Sql tool
+global_model_name = None
 
 @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
 async def chat_completions(
@@ -51,6 +53,8 @@ async def chat_completions(
     for key in list(extra):
         delattr(body, key)
 
+    global global_model_name
+    global_model_name = body.model
     # check tools & tool_choice in request body
     if isinstance(body.tool_choice, str):
         if t := get_tool(body.tool_choice):
diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml
index 8db5fb89..d667add1 100644
--- a/libs/chatchat-server/pyproject.toml
+++ b/libs/chatchat-server/pyproject.toml
@@ -15,11 +15,11 @@ chatchat-kb = 'chatchat.init_database:main'
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12,!=3.9.7"
 model-providers = "^0.3.0"
-langchain = "0.1.5"
+langchain = "0.1.17"
 langchainhub = "0.1.14"
-langchain-community = "0.0.17"
+langchain-community = "0.0.36"
 langchain-openai = "0.0.5"
-langchain-experimental = "0.0.50"
+langchain-experimental = "0.0.58"
 fastapi = "~0.109.2"
 sse_starlette = "~1.8.2"
 nltk = "~3.8.1"
@@ -51,12 +51,13 @@ python-multipart = "0.0.9"
 streamlit = "1.34.0"
 streamlit-option-menu = "0.3.12"
 streamlit-antd-components = "0.3.1"
-streamlit-chatbox = "1.1.12"
+streamlit-chatbox = "1.1.12.post2"
 streamlit-modal = "0.1.0"
 streamlit-aggrid = "0.3.4.post3"
 streamlit-extras = "0.4.2"
 xinference_client = { version = "^0.11.1", optional = true }
 zhipuai = { version = "^2.1.0", optional = true }
+pymysql = "^1.1.0"
 
 [tool.poetry.extras]
 xinference = ["xinference_client"]
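
For reference, a minimal sketch of a filled-in "text2sql" entry, assuming a reachable MySQL instance; all connection values, table names, and comments below are invented for illustration, and only the keys come from the TOOL_CONFIG change above. With "table_names" left empty, query_database() routes through SQLDatabaseSequentialChain (table prediction first); with one or more names given, it calls SQLDatabaseChain directly on just those tables.

# Hypothetical "text2sql" entry for TOOL_CONFIG in chatchat/configs/_model_config.py.
# Every value below is an example, not a shipped default.
text2sql_config = {
    "use": True,
    "db_host": "127.0.0.1",
    "db_user": "chatchat",
    "db_password": "secret",
    "db_name": "demo",
    # Cap the number of rows returned by generated queries.
    "top_k": 50,
    # Keep intermediate steps so the executed SQL can be echoed back to the user.
    "return_intermediate_steps": True,
    # Non-empty: SQLDatabaseChain is used with exactly these tables.
    # Empty: SQLDatabaseSequentialChain predicts the tables from the question.
    "table_names": ["sys_user", "sys_dept"],
    # Optional hints appended to the question to help the model pick tables.
    "table_comments": {
        "sys_user": "user accounts",
        "sys_dept": "departments",
    },
}

# Inside a running chatchat server (where global_model_name has been set by a
# /chat/completions request) and with the database above actually reachable,
# the helper added in this diff could then be exercised directly:
#   from chatchat.server.agent.tools_factory.text2sql import query_database
#   print(query_database("How many users are there?", text2sql_config))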