Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-08 16:10:18 +08:00)
Improve tool definitions; add an OpenAI-compatible unified chat endpoint (#3570)
- Fixes:
  - Qwen Agent's OutputParser no longer raises an exception; non-CoT text is now returned directly
  - the CallbackHandler now handles tool-call information correctly
- Rework how tools are defined:
  - add regist_tool to simplify tool definitions:
    - a user-friendly title can be specified
    - the function's __doc__ is used as tool.description automatically
    - parameters can be declared with Field; a separate ModelSchema is no longer needed
  - add BaseToolOutput to wrap tool results, so both the raw value and the string shown to the LLM are available (a usage sketch follows this list)
  - support hot-reloading of tools (still to be tested)
- Add an OpenAI-compatible unified chat endpoint; different combinations of the tools/tool_choice/extra_body parameters select:
  - Agent chat
  - a specified tool call (e.g. knowledge-base RAG)
  - plain LLM chat
- Update the webui to match the new backend features (a client-side sketch follows the commit metadata)
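
A minimal sketch of the new tool-definition style, assembled from the diff below (regist_tool, Field, and BaseToolOutput are from this commit; the body is condensed from the weather tool and the returned values are placeholders):

    from chatchat.server.pydantic_v1 import Field
    from .tools_registry import regist_tool, BaseToolOutput

    @regist_tool(title="天气查询")  # optional human-friendly title; defaults to the capitalized function name
    def weather_check(city: str = Field(description="City name, like '厦门'")):
        '''Use this tool to check the weather at a specific city'''  # becomes tool.description
        weather = {"temperature": "20", "description": "sunny"}  # placeholder result
        # str() of the wrapper is what the LLM sees; .data keeps the raw value for other callers
        return BaseToolOutput(weather)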
This commit is contained in:
parent 5e70aff522
commit 42aa900566
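
For reference, a hedged sketch of how a client could reach the three chat modes through the new endpoint; the base_url, model name, and queries are assumptions, while the parameter handling (string tool names, tool_input via extra_body) follows the chat_completions route in the diff below:

    import openai

    # hypothetical server address; the endpoint is mounted under the /chat router
    client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="EMPTY")

    # 1. Agent chat: tool names can be passed directly via `tools`
    client.chat.completions.create(
        model="qwen",  # placeholder model name
        messages=[{"role": "user", "content": "厦门天气如何?"}],
        tools=["weather_check"],
    )

    # 2. Specified tool call (e.g. knowledge-base RAG): `tool_choice` plus `tool_input` in extra_body
    client.chat.completions.create(
        model="qwen",
        messages=[{"role": "user", "content": "如何提问?"}],
        tool_choice="search_local_knowledgebase",
        extra_body={"tool_input": {"database": "samples", "query": "如何提问?"}},
    )

    # 3. Plain LLM chat: no tools and no tool_choice
    client.chat.completions.create(
        model="qwen",
        messages=[{"role": "user", "content": "你好"}],
    )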
@@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
             s = s[-1]
             return AgentFinish({"output": s}, log=text)
         else:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            return AgentFinish({"output": text}, log=text)
+            # raise OutputParserException(f"Could not parse LLM output: {text}")

     @property
     def _type(self) -> str:
@@ -1,7 +1,7 @@
 import base64
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


 def save_base64_audio(base64_audio, file_path):
@@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query):
     return response


-@regist_tool
-def aqa_processor(query: str = Field(description="The question of the image in English")):
+@regist_tool(title="音频问答")
+def aqa_processor(query: str = Field(description="The question of the audio in English")):
     '''use this tool to get answer for audio question'''

     from chatchat.server.agent.container import container
@@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E
             "audio": file_path,
             "text": query,
         }
-        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
+        ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
     else:
-        return "No Audio, Please Try Again"
+        ret = "No Audio, Please Try Again"
+
+    return BaseToolOutput(ret)
@@ -1,12 +1,12 @@
 # LangChain 的 ArxivQueryRun 工具
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


-@regist_tool
+@regist_tool(title="ARXIV论文")
 def arxiv(query: str = Field(description="The search query title")):
     '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
     from langchain.tools.arxiv.tool import ArxivQueryRun

     tool = ArxivQueryRun()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))
@@ -1,8 +1,8 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


-@regist_tool
+@regist_tool(title="数学计算器")
 def calculate(text: str = Field(description="a math expression")) -> float:
     '''
     Useful to answer questions about simple calculations.
@@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float:
     import numexpr

     try:
-        return str(numexpr.evaluate(text))
+        ret = str(numexpr.evaluate(text))
     except Exception as e:
-        return f"wrong: {e}"
+        ret = f"wrong: {e}"
+
+    return BaseToolOutput(ret)
@@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein

 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


 def bing_search(text, config):
@@ -95,8 +95,8 @@ def search_engine(query: str,
     return context


-@regist_tool
+@regist_tool(title="互联网搜索")
 def search_internet(query: str = Field(description="query for Internet search")):
     '''Use this tool to use bing search engine to search the internet and get information.'''
     tool_config = get_tool_config("search_internet")
-    return search_engine(query=query, config=tool_config)
+    return BaseToolOutput(search_engine(query=query, config=tool_config))
@@ -1,8 +1,9 @@
 from urllib.parse import urlencode
 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
-from chatchat.server.knowledge_base.kb_doc_api import search_docs
+from .tools_registry import regist_tool, BaseToolOutput
+from chatchat.server.knowledge_base.kb_api import list_kbs
+from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
 from chatchat.configs import KB_INFO


@@ -11,17 +12,16 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
 template_knowledge = template.format(KB_info=KB_info_str, key="samples")


-def search_knowledgebase(query: str, database: str, config: dict):
-    docs = search_docs(
-        query=query,
-        knowledge_base_name=database,
-        top_k=config["top_k"],
-        score_threshold=config["score_threshold"])
+class KBToolOutput(BaseToolOutput):
+    def __str__(self) -> str:
         context = ""
+        docs = self.data
         source_documents = []

         for inum, doc in enumerate(docs):
+            doc = DocumentWithVSId.parse_obj(doc)
             filename = doc.metadata.get("source")
-            parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
+            parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename})
             url = f"download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)
@@ -35,11 +35,21 @@ def search_knowledgebase(query: str, database: str, config: dict):
         return context


-@regist_tool(description=template_knowledge)
+def search_knowledgebase(query: str, database: str, config: dict):
+    docs = search_docs(
+        query=query,
+        knowledge_base_name=database,
+        top_k=config["top_k"],
+        score_threshold=config["score_threshold"])
+    return docs
+
+
+@regist_tool(description=template_knowledge, title="本地知识库")
 def search_local_knowledgebase(
-        database: str = Field(description="Database for Knowledge Search"),
+        database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
         query: str = Field(description="Query for Knowledge Search"),
 ):
     ''''''
     tool_config = get_tool_config("search_local_knowledgebase")
-    return search_knowledgebase(query=query, database=database, config=tool_config)
+    ret = search_knowledgebase(query=query, database=database, config=tool_config)
+    return BaseToolOutput(ret, database=database)
@@ -1,10 +1,10 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


-@regist_tool
+@regist_tool(title="油管视频")
 def search_youtube(query: str = Field(description="Query for Videos search")):
     '''use this tools_factory to search youtube videos'''
     from langchain_community.tools import YouTubeSearchTool
     tool = YouTubeSearchTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))
@@ -2,11 +2,11 @@
 from langchain.tools.shell import ShellTool

 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


-@regist_tool
+@regist_tool(title="系统命令")
 def shell(query: str = Field(description="The command to execute")):
     '''Use Shell to execute system shell commands'''
     tool = ShellTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))
@@ -7,7 +7,7 @@ import uuid

 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import openai

 from chatchat.configs.basic_config import MEDIA_PATH
@@ -26,7 +26,7 @@ def get_image_model_config() -> dict:
     return config


-@regist_tool(return_direct=True)
+@regist_tool(title="文生图", return_direct=True)
 def text2images(
         prompt: str,
         n: int = Field(1, description="需生成图片的数量"),
@@ -56,7 +56,7 @@ def text2images(
         with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
             fp.write(base64.b64decode(x.b64_json))
         images.append(filename)
-    return json.dumps({"message_type": MsgType.IMAGE, "images": images})
+    return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")


 if __name__ == "__main__":
@@ -1,15 +1,22 @@
+import json
 import re
 from typing import Any, Union, Dict, Tuple, Callable, Optional, Type

 from langchain.agents import tool
 from langchain_core.tools import BaseTool

-from chatchat.server.pydantic_v1 import BaseModel
+from chatchat.server.pydantic_v1 import BaseModel, Extra


+__all__ = ["regist_tool", "BaseToolOutput"]
+
+
 _TOOLS_REGISTRY = {}


+# patch BaseTool to support extra fields e.g. a title
+BaseTool.Config.extra = Extra.allow
+
 ################################### TODO: workaround to langchain #15855
 # patch BaseTool to support tool parameters defined using pydantic Field

@@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs

 def regist_tool(
         *args: Any,
+        title: str = "",
         description: str = "",
         return_direct: bool = False,
         args_schema: Optional[Type[BaseModel]] = None,
@@ -70,9 +78,10 @@ def regist_tool(
     add tool to regstiry automatically
     '''
     def _parse_tool(t: BaseTool):
-        nonlocal description
+        nonlocal description, title

         _TOOLS_REGISTRY[t.name] = t

         # change default description
         if not description:
             if t.func is not None:
@@ -80,6 +89,10 @@ def regist_tool(
             elif t.coroutine is not None:
                 description = t.coroutine.__doc__
         t.description = " ".join(re.split(r"\n+\s*", description))
+        # set a default title for human
+        if not title:
+            title = "".join([x.capitalize() for x in t.name.split("_")])
+        t.title = title

     def wrapper(def_func: Callable) -> BaseTool:
         partial_ = tool(*args,
@@ -101,3 +114,30 @@ def regist_tool(
                 )
         _parse_tool(t)
         return t
+
+
+class BaseToolOutput:
+    '''
+    LLM 要求 Tool 的输出为 str,但 Tool 用在别处时希望它正常返回结构化数据。
+    只需要将 Tool 返回值用该类封装,能同时满足两者的需要。
+    基类简单的将返回值字符串化,或指定 format="json" 将其转为 json。
+    用户也可以继承该类定义自己的转换方法。
+    '''
+    def __init__(
+        self,
+        data: Any,
+        format: str="",
+        data_alias: str="",
+        **extras: Any,
+    ) -> None:
+        self.data = data
+        self.format = format
+        self.extras = extras
+        if data_alias:
+            setattr(self, data_alias, property(lambda obj: obj.data))
+
+    def __str__(self) -> str:
+        if self.format == "json":
+            return json.dumps(self.data, ensure_ascii=False, indent=2)
+        else:
+            return str(self.data)
@@ -6,7 +6,7 @@ from io import BytesIO
 from PIL import Image, ImageDraw
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import re
 from chatchat.server.agent.container import container

@@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
     return response


-@regist_tool
+@regist_tool(title="图片对话")
 def vqa_processor(query: str = Field(description="The question of the image in English")):
     '''use this tool to get answer for image question'''

@@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E
         # end_marker = "Grounded Operation:"
         # ans = extract_between_markers(ans, start_marker, end_marker)

-        return ans
+        ret = ans
     else:
-        return "No Image, Please Try Again"
+        ret = "No Image, Please Try Again"
+
+    return BaseToolOutput(ret)
@@ -3,11 +3,11 @@
 """
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import requests


-@regist_tool
+@regist_tool(title="天气查询")
 def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
     '''Use this tool to check the weather at a specific city'''

@@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun
             "temperature": data["results"][0]["now"]["temperature"],
             "description": data["results"][0]["now"]["text"],
         }
-        return weather
+        return BaseToolOutput(weather)
     else:
         raise Exception(
             f"Failed to retrieve weather: {response.status_code}")
@@ -2,7 +2,7 @@

 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput


 @regist_tool
@@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")):

     wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
     ans = wolfram.run(query)
-    return ans
+    return BaseToolOutput(ans)
@@ -1,10 +1,10 @@
 from __future__ import annotations

-import re
+import json
+import time
 from typing import Dict, List, Literal, Optional, Union

 from fastapi import UploadFile
-from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator
 from openai.types.chat import (
     ChatCompletionMessageParam,
     ChatCompletionToolChoiceOptionParam,
@@ -12,8 +12,10 @@ from openai.types.chat import (
     completion_create_params,
 )

-from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
+from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
+from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
+from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
+from chatchat.server.utils import MsgType

 class OpenAIBaseInput(BaseModel):
     user: Optional[str] = None
@@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel):
     # The extra values given here take precedence over values defined on the client or passed to this method.
     extra_headers: Optional[Dict] = None
     extra_query: Optional[Dict] = None
-    extra_body: Optional[Dict] = None
+    extra_json: Optional[Dict] = Field(None, alias="extra_body")
     timeout: Optional[float] = None

     class Config:
@@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput):
     stop: Union[Optional[str], List[str]] = None
     stream: Optional[bool] = None
     temperature: Optional[float] = TEMPERATURE
-    tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
-    tools: List[ChatCompletionToolParam] = None
+    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
+    tools: List[Union[ChatCompletionToolParam, str]] = None
     top_logprobs: Optional[int] = None
     top_p: Optional[float] = None

@@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
     voice: str
     response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
     speed: Optional[float] = None
+
+
+class OpenAIBaseOutput(BaseModel):
+    id: Optional[str] = None
+    content: Optional[str] = None
+    model: Optional[str] = None
+    object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk"
+    role: Literal["assistant"] = "assistant"
+    finish_reason: Optional[str] = None
+    created: int = Field(default_factory=lambda : int(time.time()))
+    tool_calls: List[Dict] = []
+
+    status: Optional[int] = None  # AgentStatus
+    message_type: int = MsgType.TEXT
+    message_id: Optional[str] = None  # id in database table
+    is_ref: bool = False  # wheather show in seperated expander
+
+    class Config:
+        extra = "allow"
+
+    def model_dump(self) -> dict:
+        result = {
+            "id": self.id,
+            "object": self.object,
+            "model": self.model,
+            "created": self.created,
+
+            "status": self.status,
+            "message_type": self.message_type,
+            "message_id": self.message_id,
+            "is_ref": self.is_ref,
+            **(self.model_extra or {}),
+        }
+
+        if self.object == "chat.completion.chunk":
+            result["choices"] = [{
+                "delta": {
+                    "content": self.content,
+                    "tool_calls": self.tool_calls,
+                },
+                "role": self.role,
+            }]
+        elif self.object == "chat.completion":
+            result["choices"] = [{
+                "message": {
+                    "role": self.role,
+                    "content": self.content,
+                    "finish_reason": self.finish_reason,
+                    "tool_calls": self.tool_calls,
+                }
+            }]
+        return result
+
+    def model_dump_json(self):
+        return json.dumps(self.model_dump(), ensure_ascii=False)
+
+
+class OpenAIChatOutput(OpenAIBaseOutput):
+    ...
@@ -1,12 +1,17 @@
 from __future__ import annotations

-from typing import List
+from typing import List, Dict

 from fastapi import APIRouter, Request
+from langchain.prompts.prompt import PromptTemplate
+
+from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
 from chatchat.server.chat.chat import chat
 from chatchat.server.chat.feedback import chat_feedback
 from chatchat.server.chat.file_chat import file_chat
+from chatchat.server.db.repository import add_message_to_db
+from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template
+from .openai_routes import openai_request


 chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
@@ -22,3 +27,114 @@ chat_router.post("/feedback",
 chat_router.post("/file_chat",
                  summary="文件对话"
                  )(file_chat)
+
+
+@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
+async def chat_completions(
+    request: Request,
+    body: OpenAIChatInput,
+) -> Dict:
+    '''
+    请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数
+    tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换
+    通过不同的参数组合调用不同的 chat 功能:
+    - tool_choice
+        - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input)
+        - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice
+    - tools: agent 对话
+    - 其它:LLM 对话
+    以后还要考虑其它的组合(如文件对话)
+    返回与 openai 兼容的 Dict
+    '''
+    client = get_OpenAIClient(model_name=body.model, is_async=True)
+    extra = {**body.model_extra} or {}
+    for key in list(extra):
+        delattr(body, key)
+
+    # check tools & tool_choice in request body
+    if isinstance(body.tool_choice, str):
+        if t := get_tool(body.tool_choice):
+            body.tool_choice = {"function": {"name": t.name}, "type": "function"}
+    if isinstance(body.tools, list):
+        for i in range(len(body.tools)):
+            if isinstance(body.tools[i], str):
+                if t := get_tool(body.tools[i]):
+                    body.tools[i] = {
+                        "type": "function",
+                        "function": {
+                            "name": t.name,
+                            "description": t.description,
+                            "parameters": t.args,
+                        }
+                    }
+
+    conversation_id = extra.get("conversation_id")
+
+    # chat based on result from one choiced tool
+    if body.tool_choice:
+        tool = get_tool(body.tool_choice["function"]["name"])
+        if not body.tools:
+            body.tools = [{
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.args,
+                }
+            }]
+        if tool_input := extra.get("tool_input"):
+            message_id = add_message_to_db(
+                chat_type="tool_call",
+                query=body.messages[-1]["content"],
+                conversation_id=conversation_id
+            ) if conversation_id else None
+
+            tool_result = await tool.ainvoke(tool_input)
+            prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
+            body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"])
+            del body.tools
+            del body.tool_choice
+            extra_json = {
+                "message_id": message_id,
+                "status": None,
+            }
+            header = [{**extra_json,
+                       "content": f"知识库参考资料:\n\n{tool_result}\n\n",
+                       "tool_output": tool_result.data,
+                       "is_ref": True,
+                       }]
+            return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header)
+
+    # agent chat with tool calls
+    if body.tools:
+        message_id = add_message_to_db(
+            chat_type="agent_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+
+        chat_model_config = {}  # TODO: 前端支持配置模型
+        tool_names = [x["function"]["name"] for x in body.tools]
+        tool_config = {name: get_tool_config(name) for name in tool_names}
+        result = await chat(query=body.messages[-1]["content"],
+                            metadata=extra.get("metadata", {}),
+                            conversation_id=extra.get("conversation_id", ""),
+                            message_id=message_id,
+                            history_len=-1,
+                            history=body.messages[:-1],
+                            stream=body.stream,
+                            chat_model_config=extra.get("chat_model_config", chat_model_config),
+                            tool_config=extra.get("tool_config", tool_config),
+                            )
+        return result
+    else:  # LLM chat directly
+        message_id = add_message_to_db(
+            chat_type="llm_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+        extra_json = {
+            "message_id": message_id,
+            "status": None,
+        }
+        return await openai_request(client.chat.completions.create, body, extra_json=extra_json)
@@ -2,7 +2,7 @@ from __future__ import annotations

 import asyncio
 from contextlib import asynccontextmanager
-from typing import Dict, Tuple, AsyncGenerator
+from typing import Dict, Tuple, AsyncGenerator, Iterable

 from fastapi import APIRouter, Request
 from openai import AsyncClient
@@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]


 @asynccontextmanager
-async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
+async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
     '''
     对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
     '''
@@ -49,19 +49,47 @@ async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
             semaphore.release()


-async def openai_request(method, body):
+async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]):
     '''
-    helper function to make openai request
+    helper function to make openai request with extra fields
     '''
     async def generator():
-        async for chunk in await method(**params):
-            yield {"data": chunk.json()}
+        for x in header:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {header}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()
+
+        async for chunk in await method(**params):
+            for k, v in extra_json.items():
+                setattr(chunk, k, v)
+            yield chunk.model_dump_json()
+
+        for x in tail:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {tail}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()
+
+    params = body.model_dump(exclude_unset=True)

-    params = body.dict(exclude_unset=True)
     if hasattr(body, "stream") and body.stream:
         return EventSourceResponse(generator())
     else:
-        return (await method(**params)).dict()
+        result = await method(**params)
+        for k, v in extra_json.items():
+            setattr(result, k, v)
+        return result.model_dump()


 @openai_router.get("/models")
@@ -74,12 +102,12 @@ async def list_models() -> List:
         client = get_OpenAIClient(name, is_async=True)
         models = await client.models.list()
         if config.get("platform_type") == "xinference":
-            models = models.dict(exclude={"data":..., "object":...})
+            models = models.model_dump(exclude={"data":..., "object":...})
             for x in models:
                 models[x]["platform_name"] = name
             return [{**v, "id": k} for k, v in models.items()]
         elif config.get("platform_type") == "oneapi":
-            return [{**x.dict(), "platform_name": name} for x in models.data]
+            return [{**x.model_dump(), "platform_name": name} for x in models.data]
     except Exception:
         logger.error(f"failed request to platform: {name}", exc_info=True)
         return {}
@@ -97,12 +125,8 @@ async def create_chat_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         result = await openai_request(client.chat.completions.create, body)
-        # result["related_docs"] = ["doc1"]
-        # result["choices"][0]["message"]["related_docs"] = ["doc1"]
-        # print(result)
-        # breakpoint()
         return result


@@ -111,7 +135,7 @@ async def create_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.completions.create, body)


@@ -130,7 +154,7 @@ async def create_image_generations(
     request: Request,
     body: OpenAIImageGenerationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.generate, body)


@@ -139,7 +163,7 @@ async def create_image_variations(
     request: Request,
     body: OpenAIImageVariationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.create_variation, body)


@@ -148,7 +172,7 @@ async def create_image_edit(
     request: Request,
     body: OpenAIImageEditsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.edit, body)


@@ -157,7 +181,7 @@ async def create_audio_translations(
     request: Request,
     body: OpenAIAudioTranslationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.translations.create, body)


@@ -166,7 +190,7 @@ async def create_audio_transcriptions(
     request: Request,
     body: OpenAIAudioTranscriptionsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.transcriptions.create, body)


@@ -175,7 +199,7 @@ async def create_audio_speech(
     request: Request,
     body: OpenAIAudioSpeechInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.speech.create, body)
@@ -5,7 +5,7 @@ from typing import List
 from fastapi import APIRouter, Request, Body

 from chatchat.configs import logger
-from chatchat.server.utils import BaseResponse, get_tool
+from chatchat.server.utils import BaseResponse, get_tool, get_tool_config


 tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
 @tool_router.get("/", response_model=BaseResponse)
 async def list_tools():
     tools = get_tool()
-    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
+    data = {t.name: {
+        "name": t.name,
+        "title": t.title,
+        "description": t.description,
+        "args": t.args,
+        "config": get_tool_config(t.name),
+    } for t in tools.values()}
     return {"data": data}


 @tool_router.post("/call", response_model=BaseResponse)
 async def call_tool(
     name: str = Body(examples=["calculate"]),
-    kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
+    tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]),
 ):
-    tools = get_tool()
-
-    if tool := tools.get(name):
+    if tool := get_tool(name):
         try:
-            result = await tool.ainvoke(kwargs)
+            result = await tool.ainvoke(tool_input)
             return {"data": result}
         except Exception:
             msg = f"failed to call tool '{name}'"
@@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        pass
+        data = {
+            "run_id": str(run_id),
+            "status": AgentStatus.tool_start,
+            "tool": serialized["name"],
+            "tool_input": input_str,
+        }
+        self.queue.put_nowait(dumps(data))


     async def on_tool_end(
@@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool ends running."""
         data = {
+            "run_id": str(run_id),
             "status": AgentStatus.tool_end,
             "tool_output": output,
         }
@@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool errors."""
         data = {
+            "run_id": str(run_id),
             "status": AgentStatus.tool_end,
-            "text": str(error),
+            "tool_output": str(error),
+            "is_error": True,
         }
         # self.done.clear()
         self.queue.put_nowait(dumps(data))
@@ -1,6 +1,8 @@
 import asyncio
 import json
+import time
 from typing import AsyncIterable, List
+import uuid

 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
@@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate

+from chatchat.configs.model_config import LLM_MODEL_CONFIG
 from chatchat.server.agent.agent_factory.agents_registry import agents_registry
 from chatchat.server.agent.container import container
+from chatchat.server.api_server.api_schemas import OpenAIChatOutput
 from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
 from chatchat.server.chat.utils import History
 from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
-from chatchat.server.db.repository import add_message_to_db
 from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus


 def create_models_from_config(configs, callbacks, stream):
-    if configs is None:
-        configs = {}
+    configs = configs or LLM_MODEL_CONFIG
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
@@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
+               message_id: str = Body(None, description="数据库消息ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: List[History] = Body(
                    [],
@@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
                tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
-    async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(
-            chat_type="llm_chat",
-            query=query,
-            conversation_id=conversation_id
-        ) if conversation_id else None
+    '''Agent 对话'''

+    async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
@@ -145,9 +144,32 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                 }
             ), callback.done))

+        last_tool = {}
         async for chunk in callback.aiter():
             data = json.loads(chunk)
-            if data["status"] == AgentStatus.tool_end:
+            data["tool_calls"] = []
+            data["message_type"] = MsgType.TEXT
+
+            if data["status"] == AgentStatus.tool_start:
+                last_tool = {
+                    "index": 0,
+                    "id": data["run_id"],
+                    "type": "function",
+                    "function": {
+                        "name": data["tool"],
+                        "arguments": data["tool_input"],
+                    },
+                    "tool_output": None,
+                    "is_error": False,
+                }
+                data["tool_calls"].append(last_tool)
+            if data["status"] in [AgentStatus.tool_end]:
+                last_tool.update(
+                    tool_output=data["tool_output"],
+                    is_error=data.get("is_error", False)
+                )
+                data["tool_calls"] = [last_tool]
+                last_tool = {}
                 try:
                     tool_output = json.loads(data["tool_output"])
                     if message_type := tool_output.get("message_type"):
@@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                         data["message_type"] = message_type
                 except:
                     ...
-            data.setdefault("message_type", MsgType.TEXT)
-            data["message_id"] = message_id
-            yield json.dumps(data, ensure_ascii=False)

+            ret = OpenAIChatOutput(
+                id=f"chat{uuid.uuid4()}",
+                object="chat.completion.chunk",
+                content=data.get("text", ""),
+                role="assistant",
+                tool_calls=data["tool_calls"],
+                model=models["llm_model"].model_name,
+                status = data["status"],
+                message_type = data["message_type"],
+                message_id=message_id,
+            )
+            yield ret.model_dump_json()
         await task

+    if stream:
         return EventSourceResponse(chat_iterator())
+    else:
+        ret = OpenAIChatOutput(
+            id=f"chat{uuid.uuid4()}",
+            object="chat.completion",
+            content="",
+            role="assistant",
+            finish_reason="stop",
+            tool_calls=[],
+            status = AgentStatus.agent_finish,
+            message_type = MsgType.TEXT,
+            message_id=message_id,
+        )

+        async for chunk in chat_iterator():
+            data = json.loads(chunk)
+            if text := data["choices"][0]["delta"]["content"]:
+                ret.content += text
+            if data["status"] == AgentStatus.tool_end:
+                ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
+            ret.model = data["model"]
+            ret.created = data["created"]

+        return ret.model_dump()
|
|||||||
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
|
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
|
||||||
|
|
||||||
def do_create_kb(self):
|
def do_create_kb(self):
|
||||||
if os.path.exists(self.doc_path):
|
...
|
||||||
if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
|
|
||||||
os.makedirs(os.path.join(self.kb_path, "vector_store"))
|
|
||||||
else:
|
|
||||||
logger.warning("directory `vector_store` already exists.")
|
|
||||||
|
|
||||||
def vs_type(self) -> str:
|
def vs_type(self) -> str:
|
||||||
return SupportedVSType.ES
|
return SupportedVSType.ES
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ class FaissKBService(KBService):
         embed_func = get_Embeddings(self.embed_model)
         embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            embeddings = vs.embeddings.embed_query(query)
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs
@ -1,24 +1,25 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import uuid
|
||||||
from chatchat.server.utils import get_tool_config
|
|
||||||
import streamlit as st
|
|
||||||
from streamlit_antd_components.utils import ParseItems
|
|
||||||
|
|
||||||
from chatchat.webui_pages.dialogue.utils import process_files
|
|
||||||
# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
|
|
||||||
# get_select_model_endpoint
|
|
||||||
from chatchat.webui_pages.utils import *
|
|
||||||
from streamlit_chatbox import *
|
|
||||||
from streamlit_modal import Modal
|
|
||||||
from datetime import datetime
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
from streamlit_antd_components.utils import ParseItems
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from streamlit_chatbox import *
|
||||||
|
from streamlit_modal import Modal
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
|
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
|
||||||
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
|
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
|
||||||
from chatchat.server.utils import MsgType, get_config_models
|
from chatchat.server.utils import MsgType, get_config_models
|
||||||
import uuid
|
from chatchat.server.utils import get_tool_config
|
||||||
from typing import List, Dict
|
from chatchat.webui_pages.utils import *
|
||||||
|
from chatchat.webui_pages.dialogue.utils import process_files
|
||||||
|
|
||||||
|
|
||||||
img_dir = (Path(__file__).absolute().parent.parent.parent)
|
img_dir = (Path(__file__).absolute().parent.parent.parent)
|
||||||
|
|
||||||
@@ -121,15 +122,50 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         st.write("\n\n".join(cmds))

     with st.sidebar:
+        tab1, tab2 = st.tabs(["对话设置", "模型设置"])
+
+        with tab1:
+            use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
+            # tool selection
+            tools = api.list_tools()
+            if use_agent:
+                selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+            else:
+                selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+                selected_tools = [selected_tool]
+            selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
+
+            # when Agent is disabled, build the tool arguments manually
+            # TODO: needs finer-grained control widgets
+            tool_input = {}
+            if not use_agent and len(selected_tools) == 1:
+                with st.expander("工具参数", True):
+                    for k, v in tools[selected_tools[0]]["args"].items():
+                        if choices := v.get("choices", v.get("enum")):
+                            tool_input[k] = st.selectbox(v["title"], choices)
+                        else:
+                            if v["type"] == "integer":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"))
+                            elif v["type"] == "number":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1)
+                            else:
+                                tool_input[k] = st.text_input(v["title"], v.get("default"))
+
+            uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
+            files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
+
+        with tab2:
+            # conversation
             conv_names = list(st.session_state["conversation_ids"].keys())
             index = 0

             if st.session_state.get("cur_conv_name") in conv_names:
                 index = conv_names.index(st.session_state.get("cur_conv_name"))
             conversation_name = st.selectbox("当前会话", conv_names, index=index)
             chat_box.use_chat_name(conversation_name)
             conversation_id = st.session_state["conversation_ids"][conversation_name]

+            # model
             platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
             platform = st.selectbox("选择模型平台", platforms)
             llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
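The widget generation in this hunk is driven by the JSON-schema-style `args` field that each registered tool now exposes (via `regist_tool` and `Field` on the server side). A minimal sketch of the structure this UI code expects back from `api.list_tools()` — the tool name and field values below are illustrative assumptions, only the keys read here ("title", "config", "args") come from the diff:

# Hypothetical shape of api.list_tools()'s return value.
tools = {
    "search_internet": {                       # assumed tool name
        "title": "互联网搜索",                  # display name used by format_func
        "config": {"use": False},               # forwarded as selected_tool_configs
        "args": {                               # pydantic-style parameter schema
            "query": {"title": "Query", "type": "string", "default": ""},
            "top_k": {"title": "Top K", "type": "integer", "default": 3},
        },
    },
}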
@@ -137,53 +173,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

     # content passed to the backend
     chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
-    tool_use = True
     for key in LLM_MODEL_CONFIG:
-        if key == 'llm_model':
-            continue
-        if key == 'action_model':
-            first_key = next(iter(LLM_MODEL_CONFIG[key]))
-            if first_key not in SUPPORT_AGENT_MODELS:
-                st.warning("不支持Agent的模型,无法执行任何工具调用")
-                tool_use = False
-            continue
         if LLM_MODEL_CONFIG[key]:
             first_key = next(iter(LLM_MODEL_CONFIG[key]))
             chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]

-    # tool selection
-    selected_tool_configs = {}
-    if tool_use:
-        from chatchat.configs import model_config as model_config_py
-        import importlib
-        importlib.reload(model_config_py)
-
-        tools = get_tool_config()
-        with st.expander("工具栏"):
-            for tool in tools:
-                is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
-                if is_selected:
-                    selected_tool_configs[tool] = tools[tool]
-
     if llm_model is not None:
         chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})

-    uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
-    files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
-    # print(len(files_upload["audios"])) if files_upload else None
-
-    # if dialogue_mode == "文件对话":
-    #     with st.expander("文件对话配置", True):
-    #         files = st.file_uploader("上传知识文件:",
-    #                                  [i for ls in LOADER_DICT.values() for i in ls],
-    #                                  accept_multiple_files=True,
-    #                                  )
-    #         kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
-    #         score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-    #     if st.button("开始上传", disabled=len(files) == 0):
-    #         st.session_state["file_chat_id"] = upload_temp_docs(files, api)
     # Display chat messages from history on app rerun

     chat_box.output_messages()
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

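After the simplified loop above, `chat_model_config` carries at most one model entry per configured role; the Agent-capability warning and the old checkbox toolbar are gone because tool selection now lives in the sidebar tabs. A sketch of the resulting structure — only the key names "llm_model" and "action_model" appear in this diff, the model names and parameters below are made up:

# Assumed illustration of chat_model_config after the loop.
chat_model_config = {
    "llm_model": {"qwen1.5-chat": {"temperature": 0.7}},      # filled from the user's selection
    "action_model": {"qwen1.5-chat": {"temperature": 0.01}},
    # ...one dict per remaining key of LLM_MODEL_CONFIG, first entry only
}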
@@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

         chat_box.ai_say("正在思考...")
         text = ""
-        text_action = ""
-        element_index = 0
+        started = False

-        for d in api.chat_chat(query=prompt,
+        client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
+        messages = history + [{"role": "user", "content": prompt}]
+        tools = list(selected_tool_configs)
+        if len(selected_tools) == 1:
+            tool_choice = selected_tools[0]
+        else:
+            tool_choice = None
+        # if any tool_input field is empty, fill it with the user prompt
+        for k in tool_input:
+            if tool_input[k] in [None, ""]:
+                tool_input[k] = prompt
+
+        extra_body = dict(
             metadata=files_upload,
-            history=history,
             chat_model_config=chat_model_config,
             conversation_id=conversation_id,
-            tool_config=selected_tool_configs,
+            tool_input=tool_input,
+        )
+        for d in client.chat.completions.create(
+            messages=messages,
+            model=llm_model,
+            stream=True,
+            tools=tools,
+            tool_choice=tool_choice,
+            extra_body=extra_body,
         ):
-            message_id = d.get("message_id", "")
+            print("\n\n", d.status, "\n", d, "\n\n")
+            message_id = d.message_id
             metadata = {
                 "message_id": message_id,
             }
-            print(d)
-            if d["status"] == AgentStatus.error:
-                st.error(d["text"])
-            elif d["status"] == AgentStatus.agent_action:
-                formatted_data = {
-                    "Function": d["tool_name"],
-                    "function_input": d["tool_input"]
-                }
-                element_index += 1
-                formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                chat_box.insert_msg(
-                    Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
-                text = """\n```{}\n```\n""".format(formatted_json)
-                chat_box.update_msg(Markdown(text), element_index=element_index)
-            elif d["status"] == AgentStatus.tool_end:
-                text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
-                chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
-            elif d["status"] == AgentStatus.llm_new_token:
-                text += d["text"]
-                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.llm_end:
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.agent_finish:
-                if d["message_type"] == MsgType.IMAGE:
-                    for url in json.loads(d["text"]).get("images", []):
-                        url = f"{api.base_url}/media/{url}"
-                        chat_box.insert_msg(Image(url))
-                    chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
-                else:
-                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
+            if d.status == AgentStatus.error:
+                st.error(d.choices[0].delta.content)
+            elif d.status == AgentStatus.llm_start:
+                if not started:
+                    started = True
+                else:
+                    chat_box.insert_msg("正在解读工具输出结果...")
+                text = d.choices[0].delta.content or ""
+            elif d.status == AgentStatus.llm_new_token:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+            elif d.status == AgentStatus.llm_end:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
+            # the tool output duplicates the llm output
+            # elif d.status == AgentStatus.tool_start:
+            #     formatted_data = {
+            #         "Function": d.choices[0].delta.tool_calls[0].function.name,
+            #         "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
+            #     }
+            #     formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
+            #     text = """\n```{}\n```\n""".format(formatted_json)
+            #     chat_box.insert_msg( # TODO: insert text directly not shown
+            #         Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
+            # elif d.status == AgentStatus.tool_end:
+            #     tool_output = d.choices[0].delta.tool_calls[0].tool_output
+            #     if d.message_type == MsgType.IMAGE:
+            #         for url in json.loads(tool_output).get("images", []):
+            #             url = f"{api.base_url}/media/{url}"
+            #             chat_box.insert_msg(Image(url))
+            #         chat_box.update_msg(expanded=False, state="complete")
+            #     else:
+            #         text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
+            #         chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
+            elif d.status == AgentStatus.agent_finish:
+                text = d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"))
+            elif d.status == None:  # not agent chat
+                if getattr(d, "is_ref", False):
+                    chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
+                    chat_box.insert_msg("")
+                else:
+                    text += d.choices[0].delta.content or ""
+                    chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+        chat_box.update_msg(text, streaming=False, metadata=metadata)

         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file:
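The loop above is the heart of the new OpenAI-compatible unified chat interface: instead of calling `api.chat_chat`, the webui streams from `{api_address()}/chat` through the standard `openai` client, with the `tools`, `tool_choice`, and `extra_body` combinations selecting Agent chat, a pinned tool call (e.g. knowledge-base RAG), or plain LLM chat. A minimal standalone sketch against an assumed local server — the address, model name, tool name, and `status`/`tool_input` fields are this project's extensions, not standard OpenAI API:

# Hedged sketch: stream from the unified /chat endpoint of a local chatchat server.
import openai

client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")
for chunk in client.chat.completions.create(
    model="qwen1.5-chat",                           # assumed model name
    messages=[{"role": "user", "content": "北京天气如何?"}],
    stream=True,
    tools=["weather_check"],                        # [] -> plain LLM chat; several -> Agent chat
    tool_choice="weather_check",                    # pin one tool to bypass the Agent
    extra_body={"tool_input": {"city": "北京"}},    # pre-filled tool arguments
):
    # chunk.status is a chatchat extension (AgentStatus); absent for plain chat
    print(getattr(chunk, "status", None), chunk.choices[0].delta.content)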
@@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         st.rerun()

     now = datetime.now()
-    with st.sidebar:
-
+    with tab1:
         cols = st.columns(2)
         export_btn = cols[0]
         if cols[1].button(
@@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             mime="text/markdown",
             use_container_width=True,
         )
+
+    # st.write(chat_box.history)
@@ -50,6 +50,14 @@ class ApiRequest:
                                        timeout=self.timeout)
         return self._client

+    def _check_url(self, url: str) -> str:
+        '''
+        Newer httpx versions effectively require URLs to end with "/", otherwise a 307 is returned.
+        '''
+        if not url.endswith("/"):
+            url = url + "/"
+        return url
+
     def get(
         self,
         url: str,
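Context for the helper above: httpx does not follow redirects by default, so when the FastAPI server answers a slash-less route with a 307 redirect, the client sees the redirect instead of the payload; normalizing the URL up front avoids that. The helper itself is a pure string fix-up, as this standalone mirror of its behavior shows (the host and port are assumptions):

# Standalone mirror of _check_url's normalization.
def check_url(url: str) -> str:
    if not url.endswith("/"):
        url = url + "/"
    return url

assert check_url("http://127.0.0.1:7861/tools") == "http://127.0.0.1:7861/tools/"
assert check_url("http://127.0.0.1:7861/tools/") == "http://127.0.0.1:7861/tools/"  # already normalized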
@@ -58,6 +66,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any,
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -79,6 +88,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 # print(kwargs)
@@ -101,6 +111,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -638,6 +649,27 @@ class ApiRequest:
         resp = self.post("/chat/feedback", json=data)
         return self._get_response_value(resp)

+    def list_tools(self) -> Dict:
+        '''
+        List all tools.
+        '''
+        resp = self.get("/tools")
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
+
+    def call_tool(
+        self,
+        name: str,
+        tool_input: Dict = {},
+    ):
+        '''
+        Call a tool.
+        '''
+        data = {
+            "name": name,
+            "tool_input": tool_input,
+        }
+        resp = self.post("/tools/call", json=data)
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+

 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
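These two methods wrap the new `/tools` and `/tools/call` server endpoints that the tab1 UI consumes. A hedged usage sketch, assuming a running chatchat server and at least one registered tool (the address and tool name are placeholders):

# Sketch: enumerate tools, then invoke one directly.
from chatchat.webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")        # assumed address
tools = api.list_tools()                                   # {name: {"title", "config", "args"}, ...}
for name, spec in tools.items():
    print(name, "-", spec.get("title"))

result = api.call_tool("weather_check", tool_input={"city": "上海"})  # assumed tool name
print(result)                                              # the "data" payload of /tools/call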