diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index af68a74d..a857e80b 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -1,7 +1,7 @@ import os -# 默认向量库类型。可选:faiss, milvus, pg. +# 默认向量库类型。可选:faiss, milvus(离线) & zilliz(在线), pg. DEFAULT_VS_TYPE = "faiss" # 缓存向量库数量(针对FAISS) @@ -42,13 +42,17 @@ BING_SUBSCRIPTION_KEY = "" ZH_TITLE_ENHANCE = False -# 通常情况下不需要更改以下内容 +# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。 +KB_INFO = { + "知识库名称": "知识库介绍", + "samples": "关于本项目issue的解答", +} +# 通常情况下不需要更改以下内容 # 知识库默认存储路径 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") if not os.path.exists(KB_ROOT_PATH): os.mkdir(KB_ROOT_PATH) - # 数据库默认存储路径。 # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") @@ -65,6 +69,13 @@ kbs_config = { "password": "", "secure": False, }, + "zilliz": { + "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn", + "port": "19530", + "user": "", + "password": "", + "secure": True, + }, "pg": { "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat", } @@ -74,11 +85,11 @@ kbs_config = { text_splitter_dict = { "ChineseRecursiveTextSplitter": { "source": "huggingface", ## 选择tiktoken则使用openai的方法 - "tokenizer_name_or_path": "gpt2", + "tokenizer_name_or_path": "", }, "SpacyTextSplitter": { "source": "huggingface", - "tokenizer_name_or_path": "", + "tokenizer_name_or_path": "gpt2", }, "RecursiveCharacterTextSplitter": { "source": "tiktoken", diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index 6446ce2e..a52b1f70 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -9,98 +9,106 @@ # - context: 从检索结果拼接的知识文本 # - question: 用户提出的问题 +# Agent对话支持的变量: -PROMPT_TEMPLATES = { - # LLM对话模板 - "llm_chat": "{{ input }}", +# - tools: 可用的工具列表 +# - tool_names: 可用的工具名称列表 +# - history: 用户和Agent的对话历史 +# - input: 用户输入内容 +# - agent_scratchpad: Agent的思维记录 - # 基于本地知识问答的提示词模板 - "knowledge_base_chat": +PROMPT_TEMPLATES = {} + +PROMPT_TEMPLATES["llm_chat"] = { + "default": "{{ input }}", + + "py": """ - <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 - <已知信息>{{ context }}、 - <问题>{{ question }}""", - - # 基于agent的提示词模板 - "agent_chat": + 你是一个聪明的代码助手,请你给我写出简单的py代码。 \n + {{ input }} """ - Answer the following questions as best you can. You have access to the following tools: - - {tools} - Use the following format: - - Question: the input question you must answer - Thought: you should always think about what to do - Action: the action to take, should be one of [{tool_names}] - Action Input: the input to the action - Observation: the result of the action - ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) - Thought: I now know the final answer - Final Answer: the final answer to the original input question - - Begin! - - history: - {history} - - Question: {input} - Thought: {agent_scratchpad} -""" + , } +PROMPT_TEMPLATES["knowledge_base_chat"] = { + "default": + """ + <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 + <已知信息>{{ context }}、 + <问题>{{ question }} + """, + "text": + """ + <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 + <已知信息>{{ context }}、 + <问题>{{ question }} + """, +} +PROMPT_TEMPLATES["search_engine_chat"] = { + "default": + """ + <指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 + <已知信息>{{ context }}、 + <问题>{{ question }} + """, -## GPT或Qwen 的Prompt -# """ -# Answer the following questions as best you can. You have access to the following tools: -# -# {tools} -# -# Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base -# -# Use the following format: -# -# Question: the input question you must answer -# Thought: you should always think about what to do -# Action: the action to take, should be one of [{tool_names}] -# Action Input: the input to the action -# Observation: the result of the action -# ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) -# Thought: I now know the final answer -# Final Answer: the final answer to the original input question -# -# Begin! -# -# history: -# {history} -# -# Question: {input} -# Thought: {agent_scratchpad} -# """ + "search": + """ + <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 + <已知信息>{{ context }}、 + <问题>{{ question }} + """, +} +PROMPT_TEMPLATES["agent_chat"] = { + "default": + """ + Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools: + + {tools} + + Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base, + Please note that the "天气查询工具" can only be used once since Question begin. + + Use the following format: + Question: the input question you must answer1 + Thought: you should always think about what to do and what tools to use. + Action: the action to take, should be one of [{tool_names}] + Action Input: the input to the action + Observation: the result of the action + ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) + Thought: I now know the final answer + Final Answer: the final answer to the original input question -## ChatGLM-Pro的Prompt + Begin! + history: + {history} + Question: {input} + Thought: {agent_scratchpad} + """, + "ChatGLM": + """ + 请请严格按照提供的思维方式来思考。你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。 + 你有以下工具可以使用: + {tools} + ``` + Question: 用户的提问或者观察到的信息, + Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。 + Action: 需要使用的工具,应该是在[{tool_names}]中的一个。 + Action Input: 传入工具的内容 + Observation: 工具给出的答案(不是你生成的) + ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) + Thought: 通过工具给出的答案,你是否能回答Question。 + Final Answer是你的答案 -# """ -# 请请严格按照提供的思维方式来思考。你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。 -# 你有以下工具可以使用: -# {tools} -# ``` -# Question: 用户的提问或者观察到的信息, -# Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。 -# Action: 需要使用的工具,应该是在[{tool_names}]中的一个。 -# Action Input: 传入工具的内容 -# Observation: 工具给出的答案(不是你生成的) -# ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) -# Thought: 通过工具给出的答案,你是否能回答Question。 -# Final Answer是你的答案 -# -# 现在,我们开始! -# 你和用户的历史记录: -# History: -# {history} -# -# 用户开始以提问: -# Question: {input} -# Thought: {agent_scratchpad} -# -# """ + 现在,我们开始! + 你和用户的历史记录: + History: + {history} + + 用户开始以提问: + Question: {input} + Thought: {agent_scratchpad} + + """, +} diff --git a/knowledge_base/samples/vector_store/index.faiss b/knowledge_base/samples/vector_store/index.faiss deleted file mode 100644 index 2404c993..00000000 Binary files a/knowledge_base/samples/vector_store/index.faiss and /dev/null differ diff --git a/knowledge_base/samples/vector_store/index.pkl b/knowledge_base/samples/vector_store/index.pkl deleted file mode 100644 index 709f9ee7..00000000 Binary files a/knowledge_base/samples/vector_store/index.pkl and /dev/null differ diff --git a/requirements.txt b/requirements.txt index 02bd05de..76312473 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -langchain==0.0.313 -langchain-experimental==0.0.30 +langchain>=0.0.314 +langchain-experimental>=0.0.30 fschat[model_worker]==0.2.30 openai sentence_transformers diff --git a/server/agent/__init__.py b/server/agent/__init__.py new file mode 100644 index 00000000..0de21612 --- /dev/null +++ b/server/agent/__init__.py @@ -0,0 +1,4 @@ +from .model_contain import * +from .callbacks import * +from .custom_template import * +from .tools import * \ No newline at end of file diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py index 3d143605..3a82b9c7 100644 --- a/server/agent/callbacks.py +++ b/server/agent/callbacks.py @@ -29,6 +29,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue = asyncio.Queue() self.done = asyncio.Event() self.cur_tool = {} + self.out = True async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, @@ -57,6 +58,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None: + self.out = True ## 重置输出 self.cur_tool.update( status=Status.tool_finish, output_str=output.replace("Answer:", ""), @@ -72,7 +74,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(self.cur_tool)) async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - if token: + if "Action" in token: ## 减少重复输出 + before_action = token.split("Action")[0] + self.cur_tool.update( + status=Status.running, + llm_token=before_action + "\n", + ) + self.queue.put_nowait(dumps(self.cur_tool)) + + self.out = False + + if token and self.out: self.cur_tool.update( status=Status.running, llm_token=token, @@ -86,6 +98,14 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): ) self.queue.put_nowait(dumps(self.cur_tool)) + async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any, + ) -> None: + self.cur_tool.update( + status=Status.start, + llm_token="", + ) + self.queue.put_nowait(dumps(self.cur_tool)) + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.cur_tool.update( status=Status.complete, diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py index aaa2bfe4..22469c6b 100644 --- a/server/agent/custom_template.py +++ b/server/agent/custom_template.py @@ -1,11 +1,9 @@ from __future__ import annotations from langchain.agents import Tool, AgentOutputParser from langchain.prompts import StringPromptTemplate -from typing import List, Union, Tuple, Dict +from typing import List from langchain.schema import AgentAction, AgentFinish -import re -from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN - +from server.agent import model_container begin = False class CustomPromptTemplate(StringPromptTemplate): # The template to use @@ -41,7 +39,7 @@ class CustomOutputParser(AgentOutputParser): def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction: # Check if agent should finish support_agent = ["gpt","Qwen","qwen-api","baichuan-api"] - if not any(agent in LLM_MODEL for agent in support_agent) and self.begin: + if not any(agent in model_container.MODEL for agent in support_agent) and self.begin: self.begin = False stop_words = ["Observation:"] min_index = len(llm_output) diff --git a/server/agent/model_contain.py b/server/agent/model_contain.py new file mode 100644 index 00000000..1927c88f --- /dev/null +++ b/server/agent/model_contain.py @@ -0,0 +1,8 @@ + +## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍 +class ModelContainer: + def __init__(self): + self.MODEL = None + self.DATABASE = None + +model_container = ModelContainer() diff --git a/server/agent/search_knowledge.py b/server/agent/search_knowledge.py deleted file mode 100644 index bdf7f4c0..00000000 --- a/server/agent/search_knowledge.py +++ /dev/null @@ -1,37 +0,0 @@ -## 单独运行的时候需要添加 -import sys -import os -import json - -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import LLM_MODEL, TEMPERATURE, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD - -import asyncio - - -async def search_knowledge_base_iter(query: str): - response = await knowledge_base_chat(query=query, - knowledge_base_name="tcqa", - model_name=LLM_MODEL, - temperature=TEMPERATURE, - history=[], - top_k = VECTOR_SEARCH_TOP_K, - prompt_name = "knowledge_base_chat", - score_threshold = SCORE_THRESHOLD, - stream=False) - - contents = "" - async for data in response.body_iterator: # 这里的data是一个json字符串 - data = json.loads(data) - contents = data["answer"] - docs = data["docs"] - return contents - -def search_knowledge(query: str): - return asyncio.run(search_knowledge_base_iter(query)) - - -if __name__ == "__main__": - result = search_knowledge("大数据男女比例") - print("答案:",result) diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py new file mode 100644 index 00000000..8bb5cac6 --- /dev/null +++ b/server/agent/tools/__init__.py @@ -0,0 +1,11 @@ +## 导入所有的工具类 +from .search_knowledge_simple import knowledge_search_simple +from .search_all_knowledge_once import knowledge_search_once +from .search_all_knowledge_more import knowledge_search_more +from .travel_assistant import travel_assistant +from .calculate import calculate +from .translator import translate +from .weather import weathercheck +from .shell import shell +from .search_internet import search_internet + diff --git a/server/agent/math.py b/server/agent/tools/calculate.py similarity index 84% rename from server/agent/math.py rename to server/agent/tools/calculate.py index b4056d07..2d963ae8 100644 --- a/server/agent/math.py +++ b/server/agent/tools/calculate.py @@ -1,12 +1,12 @@ ## 单独运行的时候需要添加 import sys import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) from langchain.prompts import PromptTemplate from langchain.chains import LLMMathChain -from server.utils import get_ChatOpenAI -from configs.model_config import LLM_MODEL, TEMPERATURE +from server.agent import model_container + _PROMPT_TEMPLATE = """ 将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。 问题: ${{包含数学问题的问题。}} @@ -63,11 +63,7 @@ PROMPT = PromptTemplate( def calculate(query: str): - model = get_ChatOpenAI( - streaming=False, - model_name=LLM_MODEL, - temperature=TEMPERATURE, - ) + model = model_container.MODEL llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT) ans = llm_math.run(query) return ans diff --git a/server/agent/tools/search_all_knowledge_more.py b/server/agent/tools/search_all_knowledge_more.py new file mode 100644 index 00000000..fe70171c --- /dev/null +++ b/server/agent/tools/search_all_knowledge_more.py @@ -0,0 +1,296 @@ +## 单独运行的时候需要添加 +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + +import json +import re +import warnings +from typing import Dict +from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun +from langchain.chains.llm import LLMChain +from langchain.pydantic_v1 import Extra, root_validator +from langchain.schema import BasePromptTemplate +from langchain.schema.language_model import BaseLanguageModel +from typing import List, Any, Optional +from langchain.prompts import PromptTemplate +from server.chat.knowledge_base_chat import knowledge_base_chat +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD +import asyncio +from server.agent import model_container + + +async def search_knowledge_base_iter(database: str, query: str) -> str: + response = await knowledge_base_chat(query=query, + knowledge_base_name=database, + model_name=model_container.MODEL.model_name, + temperature=0.01, + history=[], + top_k=VECTOR_SEARCH_TOP_K, + max_tokens=None, + prompt_name="default", + score_threshold=SCORE_THRESHOLD, + stream=False) + + contents = "" + async for data in response.body_iterator: # 这里的data是一个json字符串 + data = json.loads(data) + contents += data["answer"] + docs = data["docs"] + return contents + + +async def search_knowledge_multiple(queries) -> List[str]: + # queries 应该是一个包含多个 (database, query) 元组的列表 + tasks = [search_knowledge_base_iter(database, query) for database, query in queries] + results = await asyncio.gather(*tasks) + # 结合每个查询结果,并在每个查询结果前添加一个自定义的消息 + combined_results = [] + for (database, _), result in zip(queries, results): + message = f"\n查询到 {database} 知识库的相关信息:\n{result}" + combined_results.append(message) + + return combined_results + + +def search_knowledge(queries) -> str: + responses = asyncio.run(search_knowledge_multiple(queries)) + # 输出每个整合的查询结果 + contents = "" + for response in responses: + contents += response + "\n\n" + return contents + + +_PROMPT_TEMPLATE = """ +用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。 + +对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。 + +例子: + +robotic,机器人男女比例是多少 +bigdata,大数据的就业情况如何 + + +这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考 + +{database_names} + +你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。 + + +Question: ${{用户的问题}} + +```text +${{知识库名称,查询问题,不要带有任何除了,之外的符号}} + +```output +数据库查询的结果 + + + +这是一个完整的问题拆分和提问的例子: + + +问题: 分别对比机器人和大数据专业的就业情况并告诉我哪儿专业的就业情况更好? + +```text +robotic,机器人专业的就业情况 +bigdata,大数据专业的就业情况 + + + +现在,我们开始作答 +问题: {question} +""" + +PROMPT = PromptTemplate( + input_variables=["question", "database_names"], + template=_PROMPT_TEMPLATE, +) + + +class LLMKnowledgeChain(LLMChain): + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" + prompt: BasePromptTemplate = PROMPT + """[Deprecated] Prompt to use to translate to python if necessary.""" + database_names: Dict[str, str] = None + input_key: str = "question" #: :meta private: + output_key: str = "answer" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key. + + :meta private: + """ + return [self.output_key] + + def _evaluate_expression(self, queries) -> str: + try: + output = search_knowledge(queries) + except Exception as e: + output = "输入的信息有误或不存在知识库,错误信息如下:\n" + return output + str(e) + return output + + def _process_llm_result( + self, + llm_output: str, + run_manager: CallbackManagerForChainRun + ) -> Dict[str, str]: + + run_manager.on_text(llm_output, color="green", verbose=self.verbose) + + llm_output = llm_output.strip() + # text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + text_match = re.search(r"```text(.*)", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1).strip() + cleaned_input_str = (expression.replace("\"", "").replace("“", ""). + replace("”", "").replace("```", "").strip()) + lines = cleaned_input_str.split("\n") + # 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表 + + queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] + run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose) + output = self._evaluate_expression(queries) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = llm_output.split("Answer:")[-1] + else: + return {self.output_key: f"输入的格式不对:\n {llm_output}"} + return {self.output_key: answer} + + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: AsyncCallbackManagerForChainRun, + ) -> Dict[str, str]: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"```text(.*)", llm_output, re.DOTALL) + if text_match: + + expression = text_match.group(1).strip() + cleaned_input_str = ( + expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip()) + lines = cleaned_input_str.split("\n") + queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] + await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", + verbose=self.verbose) + + output = self._evaluate_expression(queries) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + self.database_names = model_container.DATABASE + data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) + llm_output = self.llm_chain.predict( + database_names=data_formatted_str, + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return self._process_llm_result(llm_output, _run_manager) + + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + self.database_names = model_container.DATABASE + data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) + llm_output = await self.llm_chain.apredict( + database_names=data_formatted_str, + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager) + + @property + def _chain_type(self) -> str: + return "llm_knowledge_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ): + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) + + +def knowledge_search_more(query: str): + model = model_container.MODEL + llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT) + ans = llm_knowledge.run(query) + return ans + + +if __name__ == "__main__": + result = knowledge_search_more("机器人和大数据在代码教学上有什么区别") + print(result) + +# 这是一个正常的切割 +# queries = [ +# ("bigdata", "大数据专业的男女比例"), +# ("robotic", "机器人专业的优势") +# ] +# result = search_knowledge(queries) +# print(result) diff --git a/server/agent/tools/search_all_knowledge_once.py b/server/agent/tools/search_all_knowledge_once.py new file mode 100644 index 00000000..7a10536f --- /dev/null +++ b/server/agent/tools/search_all_knowledge_once.py @@ -0,0 +1,234 @@ +## 单独运行的时候需要添加 +# import sys +# import os +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + +import re +import warnings +from typing import Dict + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) +from langchain.chains.llm import LLMChain +from langchain.pydantic_v1 import Extra, root_validator +from langchain.schema import BasePromptTemplate +from langchain.schema.language_model import BaseLanguageModel +from typing import List, Any, Optional +from langchain.prompts import PromptTemplate +import sys +import os +import json + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from server.chat.knowledge_base_chat import knowledge_base_chat +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD + +import asyncio +from server.agent import model_container + + +async def search_knowledge_base_iter(database: str, query: str): + response = await knowledge_base_chat(query=query, + knowledge_base_name=database, + model_name=model_container.MODEL.model_name, + temperature=0.01, + history=[], + top_k=VECTOR_SEARCH_TOP_K, + max_tokens=None, + prompt_name="knowledge_base_chat", + score_threshold=SCORE_THRESHOLD, + stream=False) + + contents = "" + async for data in response.body_iterator: # 这里的data是一个json字符串 + data = json.loads(data) + contents += data["answer"] + docs = data["docs"] + return contents + + +_PROMPT_TEMPLATE = """ +用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考 +Question: ${{用户的问题}} +这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能: + +{database_names} + +你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。 +```text +${{知识库的名称}} +``` +```output +数据库查询的结果 +``` +答案: ${{答案}} + +现在,这是我的问题: +问题: {question} + +""" +PROMPT = PromptTemplate( + input_variables=["question", "database_names"], + template=_PROMPT_TEMPLATE, +) + + +class LLMKnowledgeChain(LLMChain): + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" + prompt: BasePromptTemplate = PROMPT + """[Deprecated] Prompt to use to translate to python if necessary.""" + database_names: Dict[str, str] = model_container.DATABASE + input_key: str = "question" #: :meta private: + output_key: str = "answer" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key. + + :meta private: + """ + return [self.output_key] + + def _evaluate_expression(self, dataset, query) -> str: + try: + output = asyncio.run(search_knowledge_base_iter(dataset, query)) + except Exception as e: + output = "输入的信息有误或不存在知识库" + return output + return output + + def _process_llm_result( + self, + llm_output: str, + llm_input: str, + run_manager: CallbackManagerForChainRun + ) -> Dict[str, str]: + + run_manager.on_text(llm_output, color="green", verbose=self.verbose) + + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + database = text_match.group(1).strip() + output = self._evaluate_expression(database, llm_input) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + return {self.output_key: f"输入的格式不对: {llm_output}"} + return {self.output_key: answer} + + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: AsyncCallbackManagerForChainRun, + ) -> Dict[str, str]: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1) + output = self._evaluate_expression(expression) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) + llm_output = self.llm_chain.predict( + database_names=data_formatted_str, + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager) + + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) + llm_output = await self.llm_chain.apredict( + database_names=data_formatted_str, + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager) + + @property + def _chain_type(self) -> str: + return "llm_knowledge_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ): + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) + + +def knowledge_search_once(query: str): + model = model_container.MODEL + llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT) + ans = llm_knowledge.run(query) + return ans + + +if __name__ == "__main__": + result = knowledge_search_once("大数据的男女比例") + print(result) diff --git a/server/agent/search_internet.py b/server/agent/tools/search_internet.py similarity index 54% rename from server/agent/search_internet.py rename to server/agent/tools/search_internet.py index 1ed39faf..6eec93e5 100644 --- a/server/agent/search_internet.py +++ b/server/agent/tools/search_internet.py @@ -1,35 +1,39 @@ ## 单独运行的时候需要添加 -import sys -import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +# import sys +# import os +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + import json from server.chat import search_engine_chat -from configs import LLM_MODEL, TEMPERATURE, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD - +from configs import VECTOR_SEARCH_TOP_K import asyncio - +from server.agent import model_container async def search_engine_iter(query: str): response = await search_engine_chat(query=query, - search_engine_name="bing", - model_name=LLM_MODEL, - temperature=TEMPERATURE, + search_engine_name="bing", # 这里切换搜索引擎 + model_name=model_container.MODEL.model_name, + temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01 history=[], top_k = VECTOR_SEARCH_TOP_K, - prompt_name = "knowledge_base_chat", + max_tokens= None, # Agent 搜索互联网的时候,max_tokens设置为None + prompt_name = "default", stream=False) contents = "" + async for data in response.body_iterator: # 这里的data是一个json字符串 data = json.loads(data) contents = data["answer"] docs = data["docs"] + return contents def search_internet(query: str): + return asyncio.run(search_engine_iter(query)) if __name__ == "__main__": - result = search_internet("大数据男女比例") + result = search_internet("今天星期几") print("答案:",result) diff --git a/server/agent/tools/search_knowledge_simple.py b/server/agent/tools/search_knowledge_simple.py new file mode 100644 index 00000000..03f4da10 --- /dev/null +++ b/server/agent/tools/search_knowledge_simple.py @@ -0,0 +1,38 @@ +## 最简单的版本,只支持固定的知识库 + +# ## 单独运行的时候需要添加 +# import sys +# import os +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + +from server.chat.knowledge_base_chat import knowledge_base_chat +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD +import json +import asyncio +from server.agent import model_container + +async def search_knowledge_base_iter(database: str, query: str) -> str: + response = await knowledge_base_chat(query=query, + knowledge_base_name=database, + model_name=model_container.MODEL.model_name, + temperature=0.01, + history=[], + top_k=VECTOR_SEARCH_TOP_K, + prompt_name="knowledge_base_chat", + score_threshold=SCORE_THRESHOLD, + stream=False) + + contents = "" + async for data in response.body_iterator: # 这里的data是一个json字符串 + data = json.loads(data) + contents = data["answer"] + docs = data["docs"] + return contents + +def knowledge_search_simple(query: str): + return asyncio.run(search_knowledge_base_iter(query)) + + +if __name__ == "__main__": + result = knowledge_search_simple("大数据男女比例") + print("答案:",result) \ No newline at end of file diff --git a/server/agent/shell.py b/server/agent/tools/shell.py similarity index 100% rename from server/agent/shell.py rename to server/agent/tools/shell.py diff --git a/server/agent/translator.py b/server/agent/tools/translator.py similarity index 71% rename from server/agent/translator.py rename to server/agent/tools/translator.py index 96b4c8f8..62ffa33b 100644 --- a/server/agent/translator.py +++ b/server/agent/tools/translator.py @@ -1,12 +1,11 @@ ## 单独运行的时候需要添加 -import sys -import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +# import sys +# import os +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + from langchain.prompts import PromptTemplate from langchain.chains import LLMChain -from server.utils import get_ChatOpenAI -from langchain.chains.llm_math.prompt import PROMPT -from configs.model_config import LLM_MODEL,TEMPERATURE +from server.agent import model_container _PROMPT_TEMPLATE = ''' # 指令 @@ -30,11 +29,7 @@ PROMPT = PromptTemplate( def translate(query: str): - model = get_ChatOpenAI( - streaming=False, - model_name=LLM_MODEL, - temperature=TEMPERATURE, - ) + model = model_container.MODEL llm_translate = LLMChain(llm=model, prompt=PROMPT) ans = llm_translate.run(query) return ans diff --git a/server/agent/weather.py b/server/agent/tools/weather.py similarity index 96% rename from server/agent/weather.py rename to server/agent/tools/weather.py index ad06d28c..d0dd58dc 100644 --- a/server/agent/weather.py +++ b/server/agent/tools/weather.py @@ -1,6 +1,5 @@ ## 使用和风天气API查询天气,这个模型仅仅对免费的API进行了适配 ## 这个模型的提示词非常复杂,我们推荐使用GPT4模型进行运行 - from __future__ import annotations ## 单独运行的时候需要添加 @@ -8,10 +7,6 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - -from server.utils import get_ChatOpenAI - - import re import warnings from typing import Dict @@ -27,10 +22,9 @@ from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel import requests from typing import List, Any, Optional -from configs.model_config import LLM_MODEL, TEMPERATURE from datetime import datetime from langchain.prompts import PromptTemplate - +from server.agent import model_container ## 使用和风天气API查询天气 KEY = "ac880e5a877042809ac7ffdd19d95b0d" @@ -237,9 +231,6 @@ class LLMWeatherChain(Chain): output = weather(expression) except Exception as e: output = "输入的信息有误,请再次尝试" - return {self.output_key: output} - raise ValueError(f"错误: {expression},输入的信息不对") - return output def _process_llm_result( @@ -262,7 +253,6 @@ class LLMWeatherChain(Chain): answer = "Answer: " + llm_output.split("Answer:")[-1] else: return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"} - # raise ValueError(f"unknown format from LLM: {llm_output}") return {self.output_key: answer} async def _aprocess_llm_result( @@ -273,6 +263,7 @@ class LLMWeatherChain(Chain): await run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) @@ -332,14 +323,10 @@ class LLMWeatherChain(Chain): def weathercheck(query: str): - model = get_ChatOpenAI( - streaming=False, - model_name=LLM_MODEL, - temperature=TEMPERATURE, - ) + model = model_container.MODEL llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT) ans = llm_weather.run(query) return ans if __name__ == '__main__': - result = weathercheck("苏州工姑苏区今晚热不热?") \ No newline at end of file + result = weathercheck("苏州姑苏区今晚热不热?") \ No newline at end of file diff --git a/server/agent/tools.py b/server/agent/tools_select.py similarity index 56% rename from server/agent/tools.py rename to server/agent/tools_select.py index 0f94529f..40bb63fc 100644 --- a/server/agent/tools.py +++ b/server/agent/tools_select.py @@ -1,16 +1,5 @@ -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - -from server.agent.math import calculate -from server.agent.translator import translate -from server.agent.weather import weathercheck -from server.agent.shell import shell -from langchain.agents import Tool -from server.agent.search_knowledge import search_knowledge -from server.agent.search_internet import search_internet - +from langchain.tools import Tool +from server.agent.tools import * tools = [ Tool.from_function( func=calculate, @@ -25,7 +14,7 @@ tools = [ Tool.from_function( func=weathercheck, name="天气查询工具", - description="如果你无法访问互联网,并需要查询中国各地未来24小时的天气,你应该使用这个工具,每轮对话仅能使用一次", + description="无需访问互联网,使用这个工具查询中国各地未来24小时的天气", ), Tool.from_function( func=shell, @@ -33,15 +22,15 @@ tools = [ description="使用命令行工具输出", ), Tool.from_function( - func=search_knowledge, + func=knowledge_search_more, name="知识库查询工具", - description="访问知识库来获取答案", + description="优先访问知识库来获取答案", ), Tool.from_function( func=search_internet, name="互联网查询工具", description="如果你无法访问互联网,这个工具可以帮助你访问Bing互联网来解答问题", ), - ] + tool_names = [tool.name for tool in tools] diff --git a/server/api.py b/server/api.py index 0b0692a0..370f7edf 100644 --- a/server/api.py +++ b/server/api.py @@ -16,7 +16,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat, from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, - search_docs, DocumentWithScore) + search_docs, DocumentWithScore, update_info) from server.llm_api import (list_running_models, list_config_models, change_llm_model, stop_llm_model, get_model_config, list_search_engines) @@ -115,6 +115,11 @@ def create_app(): summary="删除知识库内指定文件" )(delete_docs) + app.post("/knowledge_base/update_info", + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="更新知识库介绍" + )(update_info) app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 1f518025..c78add68 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -1,33 +1,34 @@ from langchain.memory import ConversationBufferWindowMemory -from server.agent.tools import tools, tool_names -from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps +from server.agent.tools_select import tools, tool_names +from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status from langchain.agents import AgentExecutor, LLMSingleActionAgent from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate from fastapi import Body from fastapi.responses import StreamingResponse -from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN +from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template from langchain.chains import LLMChain -from typing import AsyncIterable, Optional +from typing import AsyncIterable, Optional, Dict import asyncio from typing import List from server.chat.utils import History import json - +from server.agent import model_container +from server.knowledge_base.kb_service.base import get_kb_details async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), history: List[History] = Body([], description="历史对话", examples=[[ - {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", "content": "虎头虎脑"}]] + {"role": "user", "content": "请使用知识库工具查询今天北京天气"}, + {"role": "assistant", "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]] ), stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 - prompt_name: str = Body("agent_chat", - description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), + # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), ): history = [History.from_data(h) for h in history] @@ -43,19 +44,26 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples model_name=model_name, temperature=temperature, max_tokens=max_tokens, + callbacks=[callback], ) - prompt_template = CustomPromptTemplate( - template=get_prompt_template(prompt_name), + ## 传入全局变量来实现agent调用 + kb_list = {x["kb_name"]: x for x in get_kb_details()} + model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()} + model_container.MODEL = model + + prompt_template = get_prompt_template("agent_chat", prompt_name) + prompt_template_agent = CustomPromptTemplate( + template=prompt_template, tools=tools, input_variables=["input", "intermediate_steps", "history"] ) output_parser = CustomOutputParser() - llm_chain = LLMChain(llm=model, prompt=prompt_template) + llm_chain = LLMChain(llm=model, prompt=prompt_template_agent) agent = LLMSingleActionAgent( llm_chain=llm_chain, output_parser=output_parser, - stop=["\nObservation:", "Observation:", "<|im_end|>"], # Qwen模型中使用这个 + stop=["\nObservation:", "Observation:", "<|im_end|>"], # Qwen模型中使用这个 allowed_tools=tool_names, ) # 把history转成agent的memory @@ -73,15 +81,15 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples verbose=True, memory=memory, ) - input_msg = History(role="user", content="{{ input }}").to_msg_template(False) while True: try: task = asyncio.create_task(wrap_done( - agent_executor.acall(query, callbacks=[callback], include_run_info=True), - callback.done)) + agent_executor.acall(query, callbacks=[callback], include_run_info=True), + callback.done)) break except: pass + if stream: async for chunk in callback.aiter(): tools_use = [] @@ -89,46 +97,55 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples data = json.loads(chunk) if data["status"] == Status.start or data["status"] == Status.complete: continue - if data["status"] == Status.error: + elif data["status"] == Status.error: + tools_use.append("\n```\n") tools_use.append("工具名称: " + data["tool_name"]) tools_use.append("工具状态: " + "调用失败") tools_use.append("错误信息: " + data["error"]) tools_use.append("重新开始尝试") tools_use.append("\n```\n") yield json.dumps({"tools": tools_use}, ensure_ascii=False) - if data["status"] == Status.agent_action: - yield json.dumps({"answer": "\n\n```\n\n"}, ensure_ascii=False) - if data["status"] == Status.tool_finish: + elif data["status"] == Status.tool_finish: + tools_use.append("\n```\n") tools_use.append("工具名称: " + data["tool_name"]) tools_use.append("工具状态: " + "调用成功") tools_use.append("工具输入: " + data["input_str"]) tools_use.append("工具输出: " + data["output_str"]) tools_use.append("\n```\n") yield json.dumps({"tools": tools_use}, ensure_ascii=False) - if data["status"] == Status.agent_finish: + elif data["status"] == Status.agent_finish: yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False) else: yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False) - else: - pass - # agent必须要steram=True,这部分暂时没有完成 - # result = [] - # async for chunk in callback.aiter(): - # data = json.loads(chunk) - # status = data["status"] - # if status == Status.start: - # result.append(chunk) - # elif status == Status.running: - # result[-1]["llm_token"] += chunk["llm_token"] - # elif status == Status.complete: - # result[-1]["status"] = Status.complete - # elif status == Status.agent_finish: - # result.append(chunk) - # elif status == Status.agent_finish: - # pass - # yield dumps(result) + else: + answer = "" + final_answer = "" + async for chunk in callback.aiter(): + # Use server-sent-events to stream the response + data = json.loads(chunk) + if data["status"] == Status.start or data["status"] == Status.complete: + continue + if data["status"] == Status.error: + answer += "\n```\n" + answer += "工具名称: " + data["tool_name"] + "\n" + answer += "工具状态: " + "调用失败" + "\n" + answer += "错误信息: " + data["error"] + "\n" + answer += "\n```\n" + if data["status"] == Status.tool_finish: + answer += "\n```\n" + answer += "工具名称: " + data["tool_name"] + "\n" + answer += "工具状态: " + "调用成功" + "\n" + answer += "工具输入: " + data["input_str"] + "\n" + answer += "工具输出: " + data["output_str"] + "\n" + answer += "\n```\n" + if data["status"] == Status.agent_finish: + final_answer = data["final_answer"] + else: + answer += data["llm_token"] + + yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False) await task return StreamingResponse(agent_chat_iterator(query=query, diff --git a/server/chat/chat.py b/server/chat/chat.py index 2f37f99b..3ec68558 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -22,9 +22,10 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), + # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), - prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): history = [History.from_data(h) for h in history] @@ -41,7 +42,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callbacks=[callback], ) - prompt_template = get_prompt_template(prompt_name) + prompt_template = get_prompt_template("llm_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 6ae90e29..c39b147e 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -31,9 +31,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 - prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - request: Request = None, + max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), + # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -57,7 +57,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) - prompt_template = get_prompt_template(prompt_name) + prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) @@ -74,10 +74,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", for inum, doc in enumerate(docs): filename = os.path.split(doc.metadata["source"])[-1] parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) - url = f"{request.base_url}knowledge_base/download_doc?" + parameters + url = f"/knowledge_base/download_doc?" + parameters text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(text) - if stream: async for token in callback.aiter(): # Use server-sent-events to stream the response diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 930c9bdd..83ed65e4 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -72,8 +72,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 - prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), + # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -101,7 +102,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", docs = await lookup_search_engine(query, search_engine_name, top_k) context = "\n".join([doc.page_content for doc in docs]) - prompt_template = get_prompt_template(prompt_name) + prompt_template = get_prompt_template("search_engine_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py index 478bc1f3..f9035af4 100644 --- a/server/db/models/knowledge_base_model.py +++ b/server/db/models/knowledge_base_model.py @@ -10,10 +10,11 @@ class KnowledgeBaseModel(Base): __tablename__ = 'knowledge_base' id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID') kb_name = Column(String(50), comment='知识库名称') + kb_info = Column(String(200), comment='知识库简介(用于Agent)') vs_type = Column(String(50), comment='向量库类型') embed_model = Column(String(50), comment='嵌入模型名称') file_count = Column(Integer, default=0, comment='文件数量') create_time = Column(DateTime, default=func.now(), comment='创建时间') def __repr__(self): - return f"" + return f"" diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index 585fd9b3..d20a973d 100644 --- a/server/db/repository/knowledge_base_repository.py +++ b/server/db/repository/knowledge_base_repository.py @@ -3,13 +3,14 @@ from server.db.session import with_session @with_session -def add_kb_to_db(session, kb_name, vs_type, embed_model): +def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): # 创建知识库实例 kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() if not kb: - kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) + kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model) session.add(kb) - else: # update kb with new vs_type and embed_model + else: # update kb with new vs_type and embed_model + kb.kb_info = kb_info kb.vs_type = vs_type kb.embed_model = embed_model return True @@ -53,6 +54,7 @@ def get_kb_detail(session, kb_name: str) -> dict: if kb: return { "kb_name": kb.kb_name, + "kb_info": kb.kb_info, "vs_type": kb.vs_type, "embed_model": kb.embed_model, "file_count": kb.file_count, diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index e158ad06..3d01f9e2 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -203,6 +203,20 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) +def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]), + ): + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb.update_info(kb_info) + + return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info}) + + def update_docs( knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index a725a78e..701f0cdb 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -19,7 +19,7 @@ from server.db.repository.knowledge_file_repository import ( ) from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - EMBEDDING_MODEL) + EMBEDDING_MODEL, KB_INFO) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, @@ -42,11 +42,11 @@ class KBService(ABC): embed_model: str = EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name + self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) self.do_init() - def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) @@ -63,7 +63,7 @@ class KBService(ABC): if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) self.do_create_kb() - status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model) + status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) return status def clear_vs(self): @@ -116,6 +116,14 @@ class KBService(ABC): os.remove(kb_file.filepath) return status + def update_info(self, kb_info: str): + """ + 更新知识库介绍 + """ + self.kb_info = kb_info + status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) + return status + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 @@ -127,7 +135,7 @@ class KBService(ABC): def exist_doc(self, file_name: str): return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name, - filename=file_name)) + filename=file_name)) def list_files(self): return list_files_from_db(self.kb_name) @@ -271,6 +279,7 @@ def get_kb_details() -> List[Dict]: result[kb] = { "kb_name": kb, "vs_type": "", + "kb_info": "", "embed_model": "", "file_count": 0, "create_time": None, diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py new file mode 100644 index 00000000..679e7d9b --- /dev/null +++ b/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -0,0 +1,98 @@ +from typing import List, Dict, Optional +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import Zilliz +from configs import kbs_config +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \ + score_threshold_process +from server.knowledge_base.utils import KnowledgeFile + + +class ZillizKBService(KBService): + zilliz: Zilliz + + @staticmethod + def get_collection(zilliz_name): + from pymilvus import Collection + return Collection(zilliz_name) + + # def save_vector_store(self): + # if self.zilliz.col: + # self.zilliz.col.flush() + + def get_doc_by_id(self, id: str) -> Optional[Document]: + if self.zilliz.col: + data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"]) + if len(data_list) > 0: + data = data_list[0] + text = data.pop("text") + return Document(page_content=text, metadata=data) + + @staticmethod + def search(zilliz_name, content, limit=3): + search_params = { + "metric_type": "IP", + "params": {}, + } + c = ZillizKBService.get_collection(zilliz_name) + return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"]) + + def do_create_kb(self): + pass + + def vs_type(self) -> str: + return SupportedVSType.ZILLIZ + + def _load_zilliz(self, embeddings: Embeddings = None): + if embeddings is None: + embeddings = self._load_embeddings() + zilliz_args = kbs_config.get("zilliz") + self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings), + collection_name=self.kb_name, connection_args=zilliz_args) + + + def do_init(self): + self._load_zilliz() + + def do_drop_kb(self): + if self.zilliz.col: + self.zilliz.col.release() + self.zilliz.col.drop() + + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): + self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings)) + return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k)) + + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + for doc in docs: + for k, v in doc.metadata.items(): + doc.metadata[k] = str(v) + for field in self.zilliz.fields: + doc.metadata.setdefault(field, "") + doc.metadata.pop(self.zilliz._text_field, None) + doc.metadata.pop(self.zilliz._vector_field, None) + + ids = self.zilliz.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos + + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): + if self.zilliz.col: + filepath = kb_file.filepath.replace('\\', '\\\\') + delete_list = [item.get("pk") for item in + self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] + self.zilliz.col.delete(expr=f'pk in {delete_list}') + + def do_clear_vs(self): + if self.zilliz.col: + self.do_drop_kb() + self.do_init() + + +if __name__ == '__main__': + + from server.db.base import Base, engine + + Base.metadata.create_all(bind=engine) + zillizService = ZillizKBService("test") + diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 499e915d..517ed406 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -37,9 +37,10 @@ def folder2db( kb_names: List[str], mode: Literal["recreate_vs", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, + kb_info: dict[str, Any] = {}, embed_model: str = EMBEDDING_MODEL, chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, zh_title_enhance: bool = ZH_TITLE_ENHANCE, ): ''' diff --git a/server/llm_api.py b/server/llm_api.py index 2b1ce456..dc9ddced 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -32,7 +32,6 @@ def list_config_models() -> BaseResponse: 从本地获取configs中配置的模型列表 ''' configs = list_config_llm_models() - # 删除ONLINE_MODEL配置中的敏感信息 for config in configs["online"].values(): del_keys = set(["worker_class"]) diff --git a/server/utils.py b/server/utils.py index f62cccfe..2f6dfc49 100644 --- a/server/utils.py +++ b/server/utils.py @@ -389,15 +389,16 @@ def webui_address() -> str: return f"http://{host}:{port}" -def get_prompt_template(name: str) -> Optional[str]: +def get_prompt_template(type:str,name: str) -> Optional[str]: ''' 从prompt_config中加载模板内容 + type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。 ''' + from configs import prompt_config import importlib importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载 - - return prompt_config.PROMPT_TEMPLATES.get(name) + return prompt_config.PROMPT_TEMPLATES[type].get(name) def set_httpx_config( @@ -409,6 +410,7 @@ def set_httpx_config( 将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效) 对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。 ''' + import httpx import os diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 975f8bcc..82ef3a10 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -137,6 +137,14 @@ def test_search_docs(api="/knowledge_base/search_docs"): assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K +def test_update_info(api="/knowledge_base/update_info"): + url = api_base_url + api + print("\n更新知识库介绍") + r = requests.post(url, json={"knowledge_base_name": "samples", "kb_info": "你好"}) + data = r.json() + pprint(data) + assert data["code"] == 200 + def test_update_docs(api="/knowledge_base/update_docs"): url = api_base_url + api diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index deadc32d..adc60293 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -3,10 +3,9 @@ from webui_pages.utils import * from streamlit_chatbox import * from datetime import datetime import os -from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN +from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES from typing import List, Dict - chat_box = ChatBox( assistant_avatar=os.path.join( "img", @@ -47,11 +46,11 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool): if LLM_MODEL in running_models: return LLM_MODEL, True - + local_models = [k for k, v in running_models.items() if not v.get("online_api")] if local_models: return local_models[0], True - + return running_models[0], False @@ -94,15 +93,14 @@ def dialogue_page(api: ApiRequest): running_models = list(api.list_running_models()) available_models = [] config_models = api.list_config_models() - worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型 + worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型 for m in worker_models: if m not in running_models and m != "default": available_models.append(m) - for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT) + for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT) if not v.get("provider") and k not in running_models: print(k, v) available_models.append(k) - llm_models = running_models + available_models index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0])) llm_model = st.selectbox("选择LLM模型:", @@ -124,11 +122,33 @@ def dialogue_page(api: ApiRequest): st.success(msg) st.session_state["prev_llm_model"] = llm_model + index_prompt = { + "LLM 对话": "llm_chat", + "自定义Agent问答": "agent_chat", + "搜索引擎问答": "search_engine_chat", + "知识库问答": "knowledge_base_chat", + } + prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys()) + prompt_template_name = prompt_templates_kb_list[0] + if "prompt_template_select" not in st.session_state: + st.session_state.prompt_template_select = prompt_templates_kb_list[0] + + def prompt_change(): + text = f"已切换为 {prompt_template_name} 模板。" + st.toast(text) + + prompt_template_select = st.selectbox( + "请选择Prompt模板:", + prompt_templates_kb_list, + index=0, + on_change=prompt_change, + key="prompt_template_select", + ) + prompt_template_name = st.session_state.prompt_template_select + temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05) history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN) - LLM_MODEL_WEBUI = llm_model - TEMPERATURE_WEBUI = temperature def on_kb_change(): st.toast(f"已加载知识库: {st.session_state.selected_kb}") @@ -144,8 +164,7 @@ def dialogue_page(api: ApiRequest): ) kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) - # chunk_content = st.checkbox("关联上下文", False, disabled=True) - # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) + elif dialogue_mode == "搜索引擎问答": search_engine_list = api.list_search_engines() with st.expander("搜索引擎配置", True): @@ -168,7 +187,11 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" - r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature) + r = api.chat_chat(prompt, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature) for t in r: if error_msg := check_error_msg(t): # check whether error occured st.error(error_msg) @@ -178,37 +201,38 @@ def dialogue_page(api: ApiRequest): chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 + elif dialogue_mode == "自定义Agent问答": chat_box.ai_say([ f"正在思考...", Markdown("...", in_expander=True, title="思考过程", state="complete"), + ]) text = "" ans = "" - support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型 + support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型 if not any(agent in llm_model for agent in support_agent): ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n" chat_box.update_msg(ans, element_index=0, streaming=False) - - for d in api.agent_chat(prompt, history=history, model=llm_model, - temperature=temperature): + prompt_name=prompt_template_name, + temperature=temperature, + ): try: d = json.loads(d) except: pass if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) - - elif chunk := d.get("final_answer"): - ans += chunk - chat_box.update_msg(ans, element_index=0) - elif chunk := d.get("answer"): + if chunk := d.get("answer"): text += chunk chat_box.update_msg(text, element_index=1) - elif chunk := d.get("tools"): + if chunk := d.get("final_answer"): + ans += chunk + chat_box.update_msg(ans, element_index=0) + if chunk := d.get("tools"): text += "\n\n".join(d.get("tools", [])) chat_box.update_msg(text, element_index=1) chat_box.update_msg(ans, element_index=0, streaming=False) @@ -225,6 +249,7 @@ def dialogue_page(api: ApiRequest): score_threshold=score_threshold, history=history, model=llm_model, + prompt_name=prompt_template_name, temperature=temperature): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) @@ -244,6 +269,7 @@ def dialogue_page(api: ApiRequest): top_k=se_top_k, history=history, model=llm_model, + prompt_name=prompt_template_name, temperature=temperature): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index bf8f0894..95a1fcab 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -63,6 +63,9 @@ def knowledge_base_page(api: ApiRequest): else: selected_kb_index = 0 + if "selected_kb_info" not in st.session_state: + st.session_state["selected_kb_info"] = "" + def format_selected_kb(kb_name: str) -> str: if kb := kb_list.get(kb_name): return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})" @@ -84,6 +87,11 @@ def knowledge_base_page(api: ApiRequest): placeholder="新知识库名称,不支持中文命名", key="kb_name", ) + kb_info = st.text_input( + "知识库简介", + placeholder="知识库简介,方便Agent查找", + key="kb_info", + ) cols = st.columns(2) @@ -123,18 +131,23 @@ def knowledge_base_page(api: ApiRequest): ) st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name + st.session_state["selected_kb_info"] = kb_info st.experimental_rerun() elif selected_kb: kb = selected_kb - - + st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] # 上传文件 files = st.file_uploader("上传知识文件:", [i for ls in LOADER_DICT.values() for i in ls], accept_multiple_files=True, ) + kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None, + help=None, on_change=None, args=None, kwargs=None) + if kb_info != st.session_state["selected_kb_info"]: + st.session_state["selected_kb_info"] = kb_info + api.update_kb_info(kb, kb_info) # with st.sidebar: with st.expander( diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 3e077667..7b9e161c 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -279,7 +279,7 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = 1024, - prompt_name: str = "llm_chat", + prompt_name: str = "default", **kwargs, ): ''' @@ -309,6 +309,7 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = 1024, + prompt_name: str = "default", ): ''' 对应api.py/chat/agent_chat 接口 @@ -320,6 +321,7 @@ class ApiRequest: "model_name": model, "temperature": temperature, "max_tokens": max_tokens, + "prompt_name": prompt_name, } print(f"received input message:") @@ -339,7 +341,7 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = 1024, - prompt_name: str = "knowledge_base_chat", + prompt_name: str = "default", ): ''' 对应api.py/chat/knowledge_base_chat接口 @@ -377,7 +379,7 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = 1024, - prompt_name: str = "knowledge_base_chat", + prompt_name: str = "default", ): ''' 对应api.py/chat/search_engine_chat接口 @@ -558,6 +560,22 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) + + def update_kb_info(self,knowledge_base_name,kb_info): + ''' + 对应api.py/knowledge_base/update_info接口 + ''' + data = { + "knowledge_base_name": knowledge_base_name, + "kb_info": kb_info, + } + + response = self.post( + "/knowledge_base/update_info", + json=data, + ) + return self._get_response_value(response, as_json=True) + def update_kb_docs( self, knowledge_base_name: str, @@ -652,7 +670,7 @@ class ApiRequest: def get_model_config( self, - model_name: str, + model_name: str = None, ) -> Dict: ''' 获取服务器上模型配置 @@ -662,6 +680,7 @@ class ApiRequest: } response = self.post( "/llm_model/get_model_config", + json=data, ) return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))