From 42aa900566864f5fb5643d77032cf0a93bfd6695 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:55:32 +0800 Subject: [PATCH] Optimize tool definitions; add an openai-compatible unified chat interface (#3570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fixes: - Qwen Agent's OutputParser no longer raises an exception; non-COT text is returned as-is - the CallbackHandler now handles tool-call information correctly - Reworked how tools are defined: - added regist_tool to simplify tool definitions: - a user-friendly title can be specified - the function's __doc__ is used as tool.description automatically - parameters can be declared with Field, so a separate ModelSchema is no longer needed - added BaseToolOutput to wrap tool results, so both the raw value and the string handed to the LLM are available - tools support hot reloading (still to be tested) - Added an openai-compatible unified chat interface; different combinations of the tools/tool_choice/extra_body parameters support: - Agent chat - forced tool calls (e.g. knowledge base RAG) - plain LLM chat - Updated the webui to match the backend features --- .../server/agent/agent_factory/qwen_agent.py | 3 +- .../agent/tools_factory/aqa_processor.py | 12 +- .../server/agent/tools_factory/arxiv.py | 6 +- .../server/agent/tools_factory/calculate.py | 10 +- .../agent/tools_factory/search_internet.py | 6 +- .../search_local_knowledgebase.py | 52 ++-- .../agent/tools_factory/search_youtube.py | 6 +- .../server/agent/tools_factory/shell.py | 6 +- .../server/agent/tools_factory/text2image.py | 6 +- .../agent/tools_factory/tools_registry.py | 44 ++- .../agent/tools_factory/vqa_processor.py | 10 +- .../agent/tools_factory/weather_check.py | 6 +- .../server/agent/tools_factory/wolfram.py | 4 +- .../chatchat/server/api_server/api_schemas.py | 75 +++++- .../chatchat/server/api_server/chat_routes.py | 120 ++++++++- .../server/api_server/openai_routes.py | 68 +++-- .../chatchat/server/api_server/tool_routes.py | 18 +- .../agent_callback_handler.py | 13 +- chatchat-server/chatchat/server/chat/chat.py | 85 ++++-- .../kb_service/es_kb_service.py | 6 +- .../kb_service/faiss_kb_service.py | 1 - .../chatchat/webui_pages/dialogue/dialogue.py | 252 ++++++++++-------- chatchat-server/chatchat/webui_pages/utils.py | 32 +++ 23 files changed, 614 insertions(+), 227 deletions(-) diff --git a/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py b/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py index 3847158e..621b4849 100644 --- a/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py +++ b/chatchat-server/chatchat/server/agent/agent_factory/qwen_agent.py @@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser): s = s[-1] return AgentFinish({"output": s}, log=text) else: - raise OutputParserException(f"Could not parse LLM output: {text}") + return AgentFinish({"output": text}, log=text) + # raise OutputParserException(f"Could not parse LLM output: {text}") @property def _type(self) -> str: diff --git a/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py b/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py index e5fec534..d6a1d70a 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py @@ -1,7 +1,7 @@ import base64 from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput def save_base64_audio(base64_audio, file_path): @@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query): return response -@regist_tool -def aqa_processor(query: 
str = Field(description="The question of the image in English")): +@regist_tool(title="音频问答") +def aqa_processor(query: str = Field(description="The question of the audio in English")): '''use this tool to get answer for audio question''' from chatchat.server.agent.container import container @@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E "audio": file_path, "text": query, } - return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model) + ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model) else: - return "No Audio, Please Try Again" + ret = "No Audio, Please Try Again" + + return BaseToolOutput(ret) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py b/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py index 8c0bc4f4..b5da1cf7 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py @@ -1,12 +1,12 @@ # LangChain 的 ArxivQueryRun 工具 from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput -@regist_tool +@regist_tool(title="ARXIV论文") def arxiv(query: str = Field(description="The search query title")): '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.''' from langchain.tools.arxiv.tool import ArxivQueryRun tool = ArxivQueryRun() - return tool.run(tool_input=query) + return BaseToolOutput(tool.run(tool_input=query)) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/calculate.py b/chatchat-server/chatchat/server/agent/tools_factory/calculate.py index caec1536..69a0ae4a 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/calculate.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/calculate.py @@ -1,8 +1,8 @@ from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput -@regist_tool +@regist_tool(title="数学计算器") def calculate(text: str = Field(description="a math expression")) -> float: ''' Useful to answer questions about simple calculations. 
@@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float: import numexpr try: - return str(numexpr.evaluate(text)) + ret = str(numexpr.evaluate(text)) except Exception as e: - return f"wrong: {e}" + ret = f"wrong: {e}" + + return BaseToolOutput(ret) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py b/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py index 58624960..344e8901 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py @@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput def bing_search(text, config): @@ -95,8 +95,8 @@ def search_engine(query: str, return context -@regist_tool +@regist_tool(title="互联网搜索") def search_internet(query: str = Field(description="query for Internet search")): '''Use this tool to use bing search engine to search the internet and get information.''' tool_config = get_tool_config("search_internet") - return search_engine(query=query, config=tool_config) + return BaseToolOutput(search_engine(query=query, config=tool_config)) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 22fdb35e..1f83d94a 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,8 +1,9 @@ from urllib.parse import urlencode from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool -from chatchat.server.knowledge_base.kb_doc_api import search_docs +from .tools_registry import regist_tool, BaseToolOutput +from chatchat.server.knowledge_base.kb_api import list_kbs +from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.configs import KB_INFO @@ -11,35 +12,44 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") +class KBToolOutput(BaseToolOutput): + def __str__(self) -> str: + context = "" + docs = self.data + source_documents = [] + + for inum, doc in enumerate(docs): + doc = DocumentWithVSId.parse_obj(doc) + filename = doc.metadata.get("source") + parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename}) + url = f"download_doc?" + parameters + text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" + source_documents.append(text) + + if len(source_documents) == 0: + context = "没有找到相关文档,请更换关键词重试" + else: + for doc in source_documents: + context += doc + "\n" + + return context + + def search_knowledgebase(query: str, database: str, config: dict): docs = search_docs( query=query, knowledge_base_name=database, top_k=config["top_k"], score_threshold=config["score_threshold"]) - context = "" - source_documents = [] - for inum, doc in enumerate(docs): - filename = doc.metadata.get("source") - parameters = urlencode({"knowledge_base_name": database, "file_name": filename}) - url = f"download_doc?" 
+ parameters - text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" - source_documents.append(text) - - if len(source_documents) == 0: - context = "没有找到相关文档,请更换关键词重试" - else: - for doc in source_documents: - context += doc + "\n" - - return context + return docs -@regist_tool(description=template_knowledge) +@regist_tool(description=template_knowledge, title="本地知识库") def search_local_knowledgebase( - database: str = Field(description="Database for Knowledge Search"), + database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data), query: str = Field(description="Query for Knowledge Search"), ): '''''' tool_config = get_tool_config("search_local_knowledgebase") - return search_knowledgebase(query=query, database=database, config=tool_config) + ret = search_knowledgebase(query=query, database=database, config=tool_config) + return BaseToolOutput(ret, database=database) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py b/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py index a706bef3..ab3cb03e 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py @@ -1,10 +1,10 @@ from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput -@regist_tool +@regist_tool(title="油管视频") def search_youtube(query: str = Field(description="Query for Videos search")): '''use this tools_factory to search youtube videos''' from langchain_community.tools import YouTubeSearchTool tool = YouTubeSearchTool() - return tool.run(tool_input=query) + return BaseToolOutput(tool.run(tool_input=query)) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/shell.py b/chatchat-server/chatchat/server/agent/tools_factory/shell.py index efae7cd0..910a0552 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/shell.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/shell.py @@ -2,11 +2,11 @@ from langchain.tools.shell import ShellTool from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput -@regist_tool +@regist_tool(title="系统命令") def shell(query: str = Field(description="The command to execute")): '''Use Shell to execute system shell commands''' tool = ShellTool() - return tool.run(tool_input=query) + return BaseToolOutput(tool.run(tool_input=query)) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/text2image.py b/chatchat-server/chatchat/server/agent/tools_factory/text2image.py index 29afff24..b190067b 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/text2image.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/text2image.py @@ -7,7 +7,7 @@ import uuid from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput import openai from chatchat.configs.basic_config import MEDIA_PATH @@ -26,7 +26,7 @@ def get_image_model_config() -> dict: return config -@regist_tool(return_direct=True) +@regist_tool(title="文生图", return_direct=True) def text2images( prompt: str, n: int = Field(1, description="需生成图片的数量"), @@ -56,7 +56,7 @@ def text2images( with open(os.path.join(MEDIA_PATH, filename), "wb") as fp: fp.write(base64.b64decode(x.b64_json)) images.append(filename) - 
return json.dumps({"message_type": MsgType.IMAGE, "images": images}) + return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json") if __name__ == "__main__": diff --git a/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py b/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py index 969e091d..bd2e661b 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py @@ -1,15 +1,22 @@ +import json import re from typing import Any, Union, Dict, Tuple, Callable, Optional, Type from langchain.agents import tool from langchain_core.tools import BaseTool -from chatchat.server.pydantic_v1 import BaseModel +from chatchat.server.pydantic_v1 import BaseModel, Extra + + +__all__ = ["regist_tool", "BaseToolOutput"] _TOOLS_REGISTRY = {} +# patch BaseTool to support extra fields e.g. a title +BaseTool.Config.extra = Extra.allow + ################################### TODO: workaround to langchain #15855 # patch BaseTool to support tool parameters defined using pydantic Field @@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs def regist_tool( *args: Any, + title: str = "", description: str = "", return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, @@ -70,9 +78,10 @@ def regist_tool( add tool to regstiry automatically ''' def _parse_tool(t: BaseTool): - nonlocal description + nonlocal description, title _TOOLS_REGISTRY[t.name] = t + # change default description if not description: if t.func is not None: @@ -80,6 +89,10 @@ def regist_tool( elif t.coroutine is not None: description = t.coroutine.__doc__ t.description = " ".join(re.split(r"\n+\s*", description)) + # set a default title for human + if not title: + title = "".join([x.capitalize() for x in t.name.split("_")]) + t.title = title def wrapper(def_func: Callable) -> BaseTool: partial_ = tool(*args, @@ -101,3 +114,30 @@ def regist_tool( ) _parse_tool(t) return t + + +class BaseToolOutput: + ''' + The LLM requires a Tool's output to be a str, but callers elsewhere may want the structured data back. + Wrapping a Tool's return value in this class satisfies both needs. + The base class simply stringifies the value, or serializes it to json when format="json" is given. + Users can also subclass it to define their own conversion. + ''' + def __init__( + self, + data: Any, + format: str="", + data_alias: str="", + **extras: Any, + ) -> None: + self.data = data + self.format = format + self.extras = extras + if data_alias: + setattr(self, data_alias, data)  # a property set on an instance is never triggered; bind the alias to the data directly + + def __str__(self) -> str: + if self.format == "json": + return json.dumps(self.data, ensure_ascii=False, indent=2) + else: + return str(self.data) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py b/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py index 13b965ea..c1269c7c 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py @@ -6,7 +6,7 @@ from io import BytesIO from PIL import Image, ImageDraw from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput import re from chatchat.server.agent.container import container @@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m return response -@regist_tool +@regist_tool(title="图片对话") def vqa_processor(query: str = Field(description="The question of the image in English")): '''use this tool to 
get answer for image question''' @@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E # end_marker = "Grounded Operation:" # ans = extract_between_markers(ans, start_marker, end_marker) - return ans + ret = ans else: - return "No Image, Please Try Again" + ret = "No Image, Please Try Again" + + return BaseToolOutput(ret) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py b/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py index 967264aa..1f3cef1b 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py @@ -3,11 +3,11 @@ """ from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput import requests -@regist_tool +@regist_tool(title="天气查询") def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")): '''Use this tool to check the weather at a specific city''' @@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun "temperature": data["results"][0]["now"]["temperature"], "description": data["results"][0]["now"]["text"], } - return weather + return BaseToolOutput(weather) else: raise Exception( f"Failed to retrieve weather: {response.status_code}") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py b/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py index 8015ef76..0bfa2f71 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py @@ -2,7 +2,7 @@ from chatchat.server.pydantic_v1 import Field from chatchat.server.utils import get_tool_config -from .tools_registry import regist_tool +from .tools_registry import regist_tool, BaseToolOutput @regist_tool @@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")): wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid")) ans = wolfram.run(query) - return ans + return BaseToolOutput(ans) diff --git a/chatchat-server/chatchat/server/api_server/api_schemas.py b/chatchat-server/chatchat/server/api_server/api_schemas.py index 1db27e2a..973c4709 100644 --- a/chatchat-server/chatchat/server/api_server/api_schemas.py +++ b/chatchat-server/chatchat/server/api_server/api_schemas.py @@ -1,10 +1,10 @@ from __future__ import annotations -import re +import json +import time from typing import Dict, List, Literal, Optional, Union from fastapi import UploadFile -from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator from openai.types.chat import ( ChatCompletionMessageParam, ChatCompletionToolChoiceOptionParam, @@ -12,8 +12,10 @@ from openai.types.chat import ( completion_create_params, ) -from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG - +from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE +from chatchat.server.callback_handler.agent_callback_handler import AgentStatus +from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl +from chatchat.server.utils import MsgType class OpenAIBaseInput(BaseModel): user: Optional[str] = None @@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel): # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Optional[Dict] = None extra_query: Optional[Dict] = None - extra_body: Optional[Dict] = None + extra_json: Optional[Dict] = Field(None, alias="extra_body") timeout: Optional[float] = None class Config: @@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput): stop: Union[Optional[str], List[str]] = None stream: Optional[bool] = None temperature: Optional[float] = TEMPERATURE - tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None - tools: List[ChatCompletionToolParam] = None + tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None + tools: Optional[List[Union[ChatCompletionToolParam, str]]] = None top_logprobs: Optional[int] = None top_p: Optional[float] = None @@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput): voice: str response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None speed: Optional[float] = None + + +class OpenAIBaseOutput(BaseModel): + id: Optional[str] = None + content: Optional[str] = None + model: Optional[str] = None + object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk" + role: Literal["assistant"] = "assistant" + finish_reason: Optional[str] = None + created: int = Field(default_factory=lambda: int(time.time())) + tool_calls: List[Dict] = [] + + status: Optional[int] = None # AgentStatus + message_type: int = MsgType.TEXT + message_id: Optional[str] = None # id in database table + is_ref: bool = False # whether to show in a separate expander + + class Config: + extra = "allow" + + def model_dump(self) -> dict: + result = { + "id": self.id, + "object": self.object, + "model": self.model, + "created": self.created, + + "status": self.status, + "message_type": self.message_type, + "message_id": self.message_id, + "is_ref": self.is_ref, + **(self.model_extra or {}), + } + + if self.object == "chat.completion.chunk": + result["choices"] = [{ + "delta": { + "role": self.role, + "content": self.content, + "tool_calls": self.tool_calls, + }, + "finish_reason": self.finish_reason, + }] + elif self.object == "chat.completion": + result["choices"] = [{ + "message": { + "role": self.role, + "content": self.content, + "tool_calls": self.tool_calls, + }, + "finish_reason": self.finish_reason, + }] + return result + + def model_dump_json(self): + return json.dumps(self.model_dump(), ensure_ascii=False) + + +class OpenAIChatOutput(OpenAIBaseOutput): + ... 
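For orientation before the routing code below: a minimal sketch of how a client would drive the three modes of the new unified chat endpoint. The base_url and api_key="NONE" mirror the webui code later in this patch; the server address (default API port assumed), model name, and knowledge base name are placeholders.

import openai

# the endpoint lives under the /chat router, so the openai client appends /chat/completions
client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")

# 1. plain LLM chat: neither tools nor tool_choice
resp = client.chat.completions.create(
    model="my-llm",  # placeholder model name
    messages=[{"role": "user", "content": "你好"}],
)

# 2. agent chat: tool names are enough; the server expands them to full schemas
resp = client.chat.completions.create(
    model="my-llm",
    messages=[{"role": "user", "content": "查一下厦门天气"}],
    tools=["weather_check"],
    stream=True,
)

# 3. forced tool call (e.g. knowledge base RAG): tool_choice plus tool_input via extra_body
resp = client.chat.completions.create(
    model="my-llm",
    messages=[{"role": "user", "content": "什么是 Chatchat?"}],
    tool_choice="search_local_knowledgebase",
    extra_body={"tool_input": {"database": "samples", "query": "什么是 Chatchat?"}},
)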
diff --git a/chatchat-server/chatchat/server/api_server/chat_routes.py b/chatchat-server/chatchat/server/api_server/chat_routes.py index f954b8c6..11a561ff 100644 --- a/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -1,12 +1,17 @@ from __future__ import annotations -from typing import List +from typing import List, Dict from fastapi import APIRouter, Request +from langchain.prompts.prompt import PromptTemplate +from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus from chatchat.server.chat.chat import chat from chatchat.server.chat.feedback import chat_feedback from chatchat.server.chat.file_chat import file_chat +from chatchat.server.db.repository import add_message_to_db +from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template +from .openai_routes import openai_request chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"]) @@ -21,4 +26,115 @@ chat_router.post("/feedback", chat_router.post("/file_chat", summary="文件对话" - )(file_chat) \ No newline at end of file + )(file_chat) + + +@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") +async def chat_completions( + request: Request, + body: OpenAIChatInput, +) -> Dict: + ''' + The request parameters are identical to openai.chat.completions.create; extra parameters can be passed in via extra_body. + tools and tool_choice also accept plain tool names, which are converted against the tools registered in the project. + Different parameter combinations invoke different chat features: + - tool_choice + - extra_body contains tool_input: call tool_choice(tool_input) directly + - extra_body contains no tool_input: call tool_choice through the agent + - tools: agent chat + - otherwise: plain LLM chat + Other combinations (e.g. file chat) remain to be supported. + Returns an openai-compatible Dict. + ''' + client = get_OpenAIClient(model_name=body.model, is_async=True) + extra = dict(body.model_extra or {}) + for key in list(extra): + delattr(body, key) + + # check tools & tool_choice in request body + if isinstance(body.tool_choice, str): + if t := get_tool(body.tool_choice): + body.tool_choice = {"function": {"name": t.name}, "type": "function"} + if isinstance(body.tools, list): + for i in range(len(body.tools)): + if isinstance(body.tools[i], str): + if t := get_tool(body.tools[i]): + body.tools[i] = { + "type": "function", + "function": { + "name": t.name, + "description": t.description, + "parameters": t.args, + } + } + + conversation_id = extra.get("conversation_id") + + # chat based on the result from one chosen tool + if body.tool_choice: + tool = get_tool(body.tool_choice["function"]["name"]) + if not body.tools: + body.tools = [{ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.args, + } + }] + if tool_input := extra.get("tool_input"): + message_id = add_message_to_db( + chat_type="tool_call", + query=body.messages[-1]["content"], + conversation_id=conversation_id + ) if conversation_id else None + + tool_result = await tool.ainvoke(tool_input) + prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag")) + body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"]) + del body.tools + del body.tool_choice + extra_json = { + "message_id": message_id, + "status": None, + } + header = [{**extra_json, + "content": f"知识库参考资料:\n\n{tool_result}\n\n", + "tool_output": tool_result.data, + "is_ref": True, + }] + return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header) + + # agent chat with tool calls + if body.tools: + message_id = add_message_to_db( 
chat_type="agent_chat", + query=body.messages[-1]["content"], + conversation_id=conversation_id + ) if conversation_id else None + + chat_model_config = {} # TODO: 前端支持配置模型 + tool_names = [x["function"]["name"] for x in body.tools] + tool_config = {name: get_tool_config(name) for name in tool_names} + result = await chat(query=body.messages[-1]["content"], + metadata=extra.get("metadata", {}), + conversation_id=extra.get("conversation_id", ""), + message_id=message_id, + history_len=-1, + history=body.messages[:-1], + stream=body.stream, + chat_model_config=extra.get("chat_model_config", chat_model_config), + tool_config=extra.get("tool_config", tool_config), + ) + return result + else: # LLM chat directly + message_id = add_message_to_db( + chat_type="llm_chat", + query=body.messages[-1]["content"], + conversation_id=conversation_id + ) if conversation_id else None + extra_json = { + "message_id": message_id, + "status": None, + } + return await openai_request(client.chat.completions.create, body, extra_json=extra_json) diff --git a/chatchat-server/chatchat/server/api_server/openai_routes.py b/chatchat-server/chatchat/server/api_server/openai_routes.py index 2f81e584..3e2a4d5a 100644 --- a/chatchat-server/chatchat/server/api_server/openai_routes.py +++ b/chatchat-server/chatchat/server/api_server/openai_routes.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio from contextlib import asynccontextmanager -from typing import Dict, Tuple, AsyncGenerator +from typing import Dict, Tuple, AsyncGenerator, Iterable from fastapi import APIRouter, Request from openai import AsyncClient @@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"] @asynccontextmanager -async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: +async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: ''' 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 ''' @@ -49,19 +49,47 @@ async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: semaphore.release() -async def openai_request(method, body): +async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]): ''' - helper function to make openai request + helper function to make openai request with extra fields ''' async def generator(): - async for chunk in await method(**params): - yield {"data": chunk.json()} + for x in header: + if isinstance(x, str): + x = OpenAIChatOutput(content=x, object="chat.completion.chunk") + elif isinstance(x, dict): + x = OpenAIChatOutput.model_validate(x) + else: + raise RuntimeError(f"unsupported value: {header}") + for k, v in extra_json.items(): + setattr(x, k, v) + yield x.model_dump_json() + + async for chunk in await method(**params): + for k, v in extra_json.items(): + setattr(chunk, k, v) + yield chunk.model_dump_json() + + for x in tail: + if isinstance(x, str): + x = OpenAIChatOutput(content=x, object="chat.completion.chunk") + elif isinstance(x, dict): + x = OpenAIChatOutput.model_validate(x) + else: + raise RuntimeError(f"unsupported value: {tail}") + for k, v in extra_json.items(): + setattr(x, k, v) + yield x.model_dump_json() + + params = body.model_dump(exclude_unset=True) - params = body.dict(exclude_unset=True) if hasattr(body, "stream") and body.stream: return EventSourceResponse(generator()) else: - return (await method(**params)).dict() + result = await method(**params) + for k, v in extra_json.items(): + setattr(result, k, v) + return result.model_dump() @openai_router.get("/models") @@ -74,12 
+102,12 @@ async def list_models() -> List: client = get_OpenAIClient(name, is_async=True) models = await client.models.list() if config.get("platform_type") == "xinference": - models = models.dict(exclude={"data":..., "object":...}) + models = models.model_dump(exclude={"data":..., "object":...}) for x in models: models[x]["platform_name"] = name return [{**v, "id": k} for k, v in models.items()] elif config.get("platform_type") == "oneapi": - return [{**x.dict(), "platform_name": name} for x in models.data] + return [{**x.model_dump(), "platform_name": name} for x in models.data] except Exception: logger.error(f"failed request to platform: {name}", exc_info=True) return {} @@ -97,12 +125,8 @@ async def create_chat_completions( request: Request, body: OpenAIChatInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: result = await openai_request(client.chat.completions.create, body) - # result["related_docs"] = ["doc1"] - # result["choices"][0]["message"]["related_docs"] = ["doc1"] - # print(result) - # breakpoint() return result @@ -111,7 +135,7 @@ async def create_completions( request: Request, body: OpenAIChatInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.completions.create, body) @@ -130,7 +154,7 @@ async def create_image_generations( request: Request, body: OpenAIImageGenerationsInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.images.generate, body) @@ -139,7 +163,7 @@ async def create_image_variations( request: Request, body: OpenAIImageVariationsInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.images.create_variation, body) @@ -148,7 +172,7 @@ async def create_image_edit( request: Request, body: OpenAIImageEditsInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.images.edit, body) @@ -157,7 +181,7 @@ async def create_audio_translations( request: Request, body: OpenAIAudioTranslationsInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.audio.translations.create, body) @@ -166,7 +190,7 @@ async def create_audio_transcriptions( request: Request, body: OpenAIAudioTranscriptionsInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.audio.transcriptions.create, body) @@ -175,7 +199,7 @@ async def create_audio_speech( request: Request, body: OpenAIAudioSpeechInput, ): - async with acquire_model_client(body.model) as client: + async with get_model_client(body.model) as client: return await openai_request(client.audio.speech.create, body) diff --git a/chatchat-server/chatchat/server/api_server/tool_routes.py b/chatchat-server/chatchat/server/api_server/tool_routes.py index 3fa68d86..7355d6b8 100644 --- a/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -5,7 +5,7 @@ from typing import List from fastapi import APIRouter, Request, Body from chatchat.configs import logger -from chatchat.server.utils import BaseResponse, get_tool +from 
chatchat.server.utils import BaseResponse, get_tool, get_tool_config tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) @@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) @tool_router.get("/", response_model=BaseResponse) async def list_tools(): tools = get_tool() - data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools} + data = {t.name: { + "name": t.name, + "title": t.title, + "description": t.description, + "args": t.args, + "config": get_tool_config(t.name), + } for t in tools.values()} return {"data": data} @tool_router.post("/call", response_model=BaseResponse) async def call_tool( name: str = Body(examples=["calculate"]), - kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]), + tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]), ): - tools = get_tool() - - if tool := tools.get(name): + if tool := get_tool(name): try: - result = await tool.ainvoke(kwargs) + result = await tool.ainvoke(tool_input) return {"data": result} except Exception: msg = f"failed to call tool '{name}'" diff --git a/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py b/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py index 015d608e..6f01662c 100644 --- a/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py +++ b/chatchat-server/chatchat/server/callback_handler/agent_callback_handler.py @@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - pass + data = { + "run_id": str(run_id), + "status": AgentStatus.tool_start, + "tool": serialized["name"], + "tool_input": input_str, + } + self.queue.put_nowait(dumps(data)) async def on_tool_end( @@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): ) -> None: """Run when tool ends running.""" data = { + "run_id": str(run_id), "status": AgentStatus.tool_end, "tool_output": output, } @@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): ) -> None: """Run when tool errors.""" data = { + "run_id": str(run_id), "status": AgentStatus.tool_end, - "text": str(error), + "tool_output": str(error), + "is_error": True, } # self.done.clear() self.queue.put_nowait(dumps(data)) diff --git a/chatchat-server/chatchat/server/chat/chat.py b/chatchat-server/chatchat/server/chat/chat.py index e7a4dbad..cc58ed6c 100644 --- a/chatchat-server/chatchat/server/chat/chat.py +++ b/chatchat-server/chatchat/server/chat/chat.py @@ -1,6 +1,8 @@ import asyncio import json +import time from typing import AsyncIterable, List +import uuid from fastapi import Body from sse_starlette.sse import EventSourceResponse @@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage from langchain.chains import LLMChain from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts import PromptTemplate + +from chatchat.configs.model_config import LLM_MODEL_CONFIG from chatchat.server.agent.agent_factory.agents_registry import agents_registry from chatchat.server.agent.container import container - +from chatchat.server.api_server.api_schemas import OpenAIChatOutput from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool from chatchat.server.chat.utils import History from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory -from 
chatchat.server.db.repository import add_message_to_db from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus def create_models_from_config(configs, callbacks, stream): - if configs is None: - configs = {} + configs = configs or LLM_MODEL_CONFIG models = {} prompts = {} for model_type, model_configs in configs.items(): @@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), conversation_id: str = Body("", description="对话框ID"), + message_id: str = Body(None, description="数据库消息ID"), history_len: int = Body(-1, description="从数据库中取历史消息的数量"), history: List[History] = Body( [], @@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]), tool_config: dict = Body({}, description="工具配置", examples=[]), ): - async def chat_iterator() -> AsyncIterable[str]: - message_id = add_message_to_db( - chat_type="llm_chat", - query=query, - conversation_id=conversation_id - ) if conversation_id else None + '''Agent 对话''' + async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: callback = AgentExecutorAsyncIteratorCallbackHandler() callbacks = [callback] models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config, @@ -145,9 +144,32 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 } ), callback.done)) + last_tool = {} async for chunk in callback.aiter(): data = json.loads(chunk) - if data["status"] == AgentStatus.tool_end: + data["tool_calls"] = [] + data["message_type"] = MsgType.TEXT + + if data["status"] == AgentStatus.tool_start: + last_tool = { + "index": 0, + "id": data["run_id"], + "type": "function", + "function": { + "name": data["tool"], + "arguments": data["tool_input"], + }, + "tool_output": None, + "is_error": False, + } + data["tool_calls"].append(last_tool) + if data["status"] in [AgentStatus.tool_end]: + last_tool.update( + tool_output=data["tool_output"], + is_error=data.get("is_error", False) + ) + data["tool_calls"] = [last_tool] + last_tool = {} try: tool_output = json.loads(data["tool_output"]) if message_type := tool_output.get("message_type"): @@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 data["message_type"] = message_type except: ... 
- data.setdefault("message_type", MsgType.TEXT) - data["message_id"] = message_id - yield json.dumps(data, ensure_ascii=False) + ret = OpenAIChatOutput( + id=f"chat{uuid.uuid4()}", + object="chat.completion.chunk", + content=data.get("text", ""), + role="assistant", + tool_calls=data["tool_calls"], + model=models["llm_model"].model_name, + status = data["status"], + message_type = data["message_type"], + message_id=message_id, + ) + yield ret.model_dump_json() await task - return EventSourceResponse(chat_iterator()) + if stream: + return EventSourceResponse(chat_iterator()) + else: + ret = OpenAIChatOutput( + id=f"chat{uuid.uuid4()}", + object="chat.completion", + content="", + role="assistant", + finish_reason="stop", + tool_calls=[], + status = AgentStatus.agent_finish, + message_type = MsgType.TEXT, + message_id=message_id, + ) + + async for chunk in chat_iterator(): + data = json.loads(chunk) + if text := data["choices"][0]["delta"]["content"]: + ret.content += text + if data["status"] == AgentStatus.tool_end: + ret.tool_calls += data["choices"][0]["delta"]["tool_calls"] + ret.model = data["model"] + ret.created = data["created"] + + return ret.model_dump() diff --git a/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index e69a71b1..9b65d81e 100644 --- a/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -96,11 +96,7 @@ class ESKBService(KBService): return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store") def do_create_kb(self): - if os.path.exists(self.doc_path): - if not os.path.exists(os.path.join(self.kb_path, "vector_store")): - os.makedirs(os.path.join(self.kb_path, "vector_store")) - else: - logger.warning("directory `vector_store` already exists.") + ... 
def vs_type(self) -> str: return SupportedVSType.ES diff --git a/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 51196153..95f7cd64 100644 --- a/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -65,7 +65,6 @@ class FaissKBService(KBService): embed_func = get_Embeddings(self.embed_model) embeddings = embed_func.embed_query(query) with self.load_vector_store().acquire() as vs: - embeddings = vs.embeddings.embed_query(query) docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) return docs diff --git a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index 52a5bf46..4a9b3efb 100644 --- a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -1,24 +1,25 @@ import base64 - -from chatchat.server.utils import get_tool_config -import streamlit as st -from streamlit_antd_components.utils import ParseItems - -from chatchat.webui_pages.dialogue.utils import process_files -# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \ -# get_select_model_endpoint -from chatchat.webui_pages.utils import * -from streamlit_chatbox import * -from streamlit_modal import Modal -from datetime import datetime +import uuid import os import re import time +from typing import List, Dict + +import streamlit as st +from streamlit_antd_components.utils import ParseItems + +import openai +from streamlit_chatbox import * +from streamlit_modal import Modal +from datetime import datetime + from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS) from chatchat.server.callback_handler.agent_callback_handler import AgentStatus from chatchat.server.utils import MsgType, get_config_models -import uuid -from typing import List, Dict +from chatchat.server.utils import get_tool_config +from chatchat.webui_pages.utils import * +from chatchat.webui_pages.dialogue.utils import process_files + img_dir = (Path(__file__).absolute().parent.parent.parent) @@ -121,69 +122,66 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): st.write("\n\n".join(cmds)) with st.sidebar: - conv_names = list(st.session_state["conversation_ids"].keys()) - index = 0 + tab1, tab2 = st.tabs(["对话设置", "模型设置"]) - if st.session_state.get("cur_conv_name") in conv_names: - index = conv_names.index(st.session_state.get("cur_conv_name")) - conversation_name = st.selectbox("当前会话", conv_names, index=index) - chat_box.use_chat_name(conversation_name) - conversation_id = st.session_state["conversation_ids"][conversation_name] + with tab1: + use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力") + # 选择工具 + tools = api.list_tools() + if use_agent: + selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"]) + else: + selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"]) + selected_tools = [selected_tool] + selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools} - platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS] - platform = st.selectbox("选择模型平台", platforms) - llm_models = 
list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform)) - llm_model = st.selectbox("选择LLM模型", llm_models) + # 当不启用Agent时,手动生成工具参数 + # TODO: 需要更精细的控制控件 + tool_input = {} + if not use_agent and len(selected_tools) == 1: + with st.expander("工具参数", True): + for k, v in tools[selected_tools[0]]["args"].items(): + if choices := v.get("choices", v.get("enum")): + tool_input[k] = st.selectbox(v["title"], choices) + else: + if v["type"] == "integer": + tool_input[k] = st.slider(v["title"], value=v.get("default")) + elif v["type"] == "number": + tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1) + else: + tool_input[k] = st.text_input(v["title"], v.get("default")) - # 传入后端的内容 - chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} - tool_use = True - for key in LLM_MODEL_CONFIG: - if key == 'llm_model': - continue - if key == 'action_model': - first_key = next(iter(LLM_MODEL_CONFIG[key])) - if first_key not in SUPPORT_AGENT_MODELS: - st.warning("不支持Agent的模型,无法执行任何工具调用") - tool_use = False - continue - if LLM_MODEL_CONFIG[key]: - first_key = next(iter(LLM_MODEL_CONFIG[key])) - chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key] - # 选择工具 - selected_tool_configs = {} - if tool_use: - from chatchat.configs import model_config as model_config_py - import importlib - importlib.reload(model_config_py) + uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) + files_upload = process_files(files=[uploaded_file]) if uploaded_file else None + + with tab2: + # 会话 + conv_names = list(st.session_state["conversation_ids"].keys()) + index = 0 + if st.session_state.get("cur_conv_name") in conv_names: + index = conv_names.index(st.session_state.get("cur_conv_name")) + conversation_name = st.selectbox("当前会话", conv_names, index=index) + chat_box.use_chat_name(conversation_name) + conversation_id = st.session_state["conversation_ids"][conversation_name] - tools = get_tool_config() - with st.expander("工具栏"): - for tool in tools: - is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool) - if is_selected: - selected_tool_configs[tool] = tools[tool] + # 模型 + platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS] + platform = st.selectbox("选择模型平台", platforms) + llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform)) + llm_model = st.selectbox("选择LLM模型", llm_models) - if llm_model is not None: - chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) + # 传入后端的内容 + chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} + for key in LLM_MODEL_CONFIG: + if LLM_MODEL_CONFIG[key]: + first_key = next(iter(LLM_MODEL_CONFIG[key])) + chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key] - uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) - files_upload = process_files(files=[uploaded_file]) if uploaded_file else None - # print(len(files_upload["audios"])) if files_upload else None + if llm_model is not None: + chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) - # if dialogue_mode == "文件对话": - # with st.expander("文件对话配置", True): - # files = st.file_uploader("上传知识文件:", - # [i for ls in LOADER_DICT.values() for i in ls], - # accept_multiple_files=True, - # ) - # kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) - # score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01) - # if st.button("开始上传", 
disabled=len(files) == 0): # st.session_state["file_chat_id"] = upload_temp_docs(files, api) # Display chat messages from history on app rerun - chat_box.output_messages chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 " @@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.ai_say("正在思考...") text = "" - text_action = "" - element_index = 0 + started = False - for d in api.chat_chat(query=prompt, - metadata=files_upload, - history=history, - chat_model_config=chat_model_config, - conversation_id=conversation_id, - tool_config=selected_tool_configs, - ): - message_id = d.get("message_id", "") + client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE") + messages = history + [{"role": "user", "content": prompt}] + tools = list(selected_tool_configs) + if len(selected_tools) == 1: + tool_choice = selected_tools[0] + else: + tool_choice = None + # fill any empty tool_input field with the user input + for k in tool_input: + if tool_input[k] in [None, ""]: + tool_input[k] = prompt + + extra_body = dict( + metadata=files_upload, + chat_model_config=chat_model_config, + conversation_id=conversation_id, + tool_input=tool_input, + ) + for d in client.chat.completions.create( + messages=messages, + model=llm_model, + stream=True, + tools=tools, + tool_choice=tool_choice, + extra_body=extra_body, + ): + message_id = d.message_id metadata = { "message_id": message_id, } - print(d) - if d["status"] == AgentStatus.error: - st.error(d["text"]) - elif d["status"] == AgentStatus.agent_action: - formatted_data = { - "Function": d["tool_name"], - "function_input": d["tool_input"] - } - element_index += 1 - formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False) - chat_box.insert_msg( - Markdown(title="Function call", in_expander=True, expanded=True, state="running")) - text = """\n```{}\n```\n""".format(formatted_json) - chat_box.update_msg(Markdown(text), element_index=element_index) - elif d["status"] == AgentStatus.tool_end: - text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"]) - chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete") - elif d["status"] == AgentStatus.llm_new_token: - text += d["text"] - chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata) - elif d["status"] == AgentStatus.llm_end: - chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata) - elif d["status"] == AgentStatus.agent_finish: - if d["message_type"] == MsgType.IMAGE: - for url in json.loads(d["text"]).get("images", []): - url = f"{api.base_url}/media/{url}" - chat_box.insert_msg(Image(url)) - chat_box.update_msg(element_index=element_index, expanded=False, state="complete") + + if d.status == AgentStatus.error: + st.error(d.choices[0].delta.content) + elif d.status == AgentStatus.llm_start: + if not started: + started = True else: - chat_box.insert_msg(Markdown(d["text"], expanded=True)) + chat_box.insert_msg("正在解读工具输出结果...") + text = d.choices[0].delta.content or "" + elif d.status == AgentStatus.llm_new_token: + text += d.choices[0].delta.content or "" + chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata) + elif d.status == AgentStatus.llm_end: + text += d.choices[0].delta.content or "" + chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata) + # the tool output duplicates the llm output + # elif d.status == AgentStatus.tool_start: + # formatted_data = 
{ + # "Function": d.choices[0].delta.tool_calls[0].function.name, + # "function_input": d.choices[0].delta.tool_calls[0].function.arguments, + # } + # formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False) + # text = """\n```{}\n```\n""".format(formatted_json) + # chat_box.insert_msg( # TODO: insert text directly not shown + # Markdown(text, title="Function call", in_expander=True, expanded=True, state="running")) + # elif d.status == AgentStatus.tool_end: + # tool_output = d.choices[0].delta.tool_calls[0].tool_output + # if d.message_type == MsgType.IMAGE: + # for url in json.loads(tool_output).get("images", []): + # url = f"{api.base_url}/media/{url}" + # chat_box.insert_msg(Image(url)) + # chat_box.update_msg(expanded=False, state="complete") + # else: + # text += """\n```\nObservation:\n{}\n```\n""".format(tool_output) + # chat_box.update_msg(text, streaming=False, expanded=False, state="complete") + elif d.status == AgentStatus.agent_finish: + text = d.choices[0].delta.content or "" + chat_box.update_msg(text.replace("\n", "\n\n")) + elif d.status is None: # not agent chat + if getattr(d, "is_ref", False): + chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料")) + chat_box.insert_msg("") + else: + text += d.choices[0].delta.content or "" + chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata) + chat_box.update_msg(text, streaming=False, metadata=metadata) if os.path.exists("tmp/image.jpg"): with open("tmp/image.jpg", "rb") as image_file: @@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): st.rerun() now = datetime.now() - with st.sidebar: - + with tab1: cols = st.columns(2) export_btn = cols[0] if cols[1].button( @@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): mime="text/markdown", use_container_width=True, ) + + # st.write(chat_box.history) diff --git a/chatchat-server/chatchat/webui_pages/utils.py b/chatchat-server/chatchat/webui_pages/utils.py index 4f9c7522..53fdb2fc 100644 --- a/chatchat-server/chatchat/webui_pages/utils.py +++ b/chatchat-server/chatchat/webui_pages/utils.py @@ -50,6 +50,14 @@ class ApiRequest: timeout=self.timeout) return self._client + def _check_url(self, url: str) -> str: + ''' + newer httpx requires the url to end with /, otherwise a 307 is returned + ''' + if not url.endswith("/"): + url = url + "/" + return url + def get( self, url: str, @@ -58,6 +66,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: + url = self._check_url(url) while retry > 0: try: if stream: @@ -79,6 +88,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: + url = self._check_url(url) while retry > 0: try: # print(kwargs) @@ -101,6 +111,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: + url = self._check_url(url) while retry > 0: try: if stream: @@ -638,6 +649,27 @@ class ApiRequest: resp = self.post("/chat/feedback", json=data) return self._get_response_value(resp) + def list_tools(self) -> Dict: + ''' + list all registered tools + ''' + resp = self.get("/tools") + return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {})) + + def call_tool( + self, + name: str, + tool_input: Dict = {}, + ): + ''' + call a tool by name + ''' + data = { + "name": name, + "tool_input": tool_input, + } + resp = self.post("/tools/call", json=data) + return 
self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) class AsyncApiRequest(ApiRequest): def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
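To close the loop on the new registry, here is a minimal sketch of what defining a tool looks like after this change. The word_count tool and its title are hypothetical; the decorator arguments, the Field-declared parameter, the docstring-as-description behavior, and the BaseToolOutput wrapper all follow tools_registry.py above.

from chatchat.server.pydantic_v1 import Field
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput


@regist_tool(title="字数统计")  # hypothetical tool; title is the user-friendly name
def word_count(text: str = Field(description="The text to count words in")):
    '''Use this tool to count the number of words in a text.'''
    # the docstring above is picked up as tool.description by regist_tool
    ret = len(text.split())
    # BaseToolOutput keeps the raw value in .data and stringifies it for the LLM
    return BaseToolOutput(ret)

Once registered, the tool shows up in GET /tools and can be exercised through POST /tools/call with a body like {"name": "word_count", "tool_input": {"text": "hello world"}}, or via api.call_tool("word_count", {"text": "hello world"}) using the webui helper above.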