Improve tool definitions; add an OpenAI-compatible unified chat endpoint (#3570)

- Fixes:
    - Qwen Agent's OutputParser no longer raises an exception; text that is not CoT-formatted is returned directly
    - CallbackHandler now handles tool-call information correctly

- Rework how tools are defined:
    - Add regist_tool to simplify tool definitions (a minimal sketch follows this list):
        - A user-friendly name (title) can be specified
        - The function's __doc__ is used as tool.description automatically
        - Parameters can be defined with Field; defining a separate ModelSchema is no longer needed
        - Add BaseToolOutput to wrap tool return values, so the raw value and the string handed to the LLM are both available
        - Support hot reloading of tools (still to be tested)
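
  A minimal sketch of the new definition style, assembled from the calculator tool changed in this commit (all names appear in the diffs below):

      from chatchat.server.pydantic_v1 import Field
      from .tools_registry import regist_tool, BaseToolOutput

      @regist_tool(title="数学计算器")
      def calculate(text: str = Field(description="a math expression")):
          '''Useful to answer questions about simple calculations.'''
          import numexpr
          try:
              ret = str(numexpr.evaluate(text))  # raw value, kept as BaseToolOutput.data
          except Exception as e:
              ret = f"wrong: {e}"
          return BaseToolOutput(ret)  # str(output) is what the LLM receives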

- Add an OpenAI-compatible unified chat endpoint; different combinations of the tools / tool_choice / extra_body parameters select the behavior (a client-side sketch follows this list):
    - Agent conversation
    - Calling a specified tool directly (e.g. knowledge-base RAG)
    - Plain LLM conversation
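
  A hedged sketch of driving these modes through one call, modeled on the webui changes below (api_address() and llm_model come from the project's webui code; the weather tool and its arguments are illustrative):

      import openai

      client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
      for chunk in client.chat.completions.create(
          messages=[{"role": "user", "content": "厦门今天天气怎么样?"}],
          model=llm_model,
          stream=True,
          tools=["weather_check"],      # plain tool names are resolved server side
          tool_choice="weather_check",  # omit for agent chat; omit tools too for plain LLM chat
          extra_body=dict(tool_input={"city": "厦门"}),  # with tool_input the tool is called directly
      ):
          print(chunk.choices[0].delta.content or "")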

- Update the webui to match the backend changes
liunux4odoo 2024-03-29 11:55:32 +08:00 committed by GitHub
parent 5e70aff522
commit 42aa900566
23 changed files with 614 additions and 227 deletions

View File

@@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
                 s = s[-1]
             return AgentFinish({"output": s}, log=text)
         else:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            return AgentFinish({"output": text}, log=text)
+            # raise OutputParserException(f"Could not parse LLM output: {text}")

     @property
     def _type(self) -> str:

View File

@@ -1,7 +1,7 @@
 import base64
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

 def save_base64_audio(base64_audio, file_path):
@@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query):
     return response

-@regist_tool
-def aqa_processor(query: str = Field(description="The question of the image in English")):
+@regist_tool(title="音频问答")
+def aqa_processor(query: str = Field(description="The question of the audio in English")):
     '''use this tool to get answer for audio question'''
     from chatchat.server.agent.container import container
@@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E
             "audio": file_path,
             "text": query,
         }
-        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
+        ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
     else:
-        return "No Audio, Please Try Again"
+        ret = "No Audio, Please Try Again"
+    return BaseToolOutput(ret)

View File

@@ -1,12 +1,12 @@
 # LangChain 的 ArxivQueryRun 工具
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

-@regist_tool
+@regist_tool(title="ARXIV论文")
 def arxiv(query: str = Field(description="The search query title")):
     '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
     from langchain.tools.arxiv.tool import ArxivQueryRun

     tool = ArxivQueryRun()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))

View File

@@ -1,8 +1,8 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

-@regist_tool
+@regist_tool(title="数学计算器")
 def calculate(text: str = Field(description="a math expression")) -> float:
     '''
     Useful to answer questions about simple calculations.
@@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float:
     import numexpr

     try:
-        return str(numexpr.evaluate(text))
+        ret = str(numexpr.evaluate(text))
     except Exception as e:
-        return f"wrong: {e}"
+        ret = f"wrong: {e}"
+    return BaseToolOutput(ret)

View File

@@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein
 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

 def bing_search(text, config):
@@ -95,8 +95,8 @@ def search_engine(query: str,
     return context

-@regist_tool
+@regist_tool(title="互联网搜索")
 def search_internet(query: str = Field(description="query for Internet search")):
     '''Use this tool to use bing search engine to search the internet and get information.'''
     tool_config = get_tool_config("search_internet")
-    return search_engine(query=query, config=tool_config)
+    return BaseToolOutput(search_engine(query=query, config=tool_config))

View File

@@ -1,8 +1,9 @@
 from urllib.parse import urlencode
 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
-from chatchat.server.knowledge_base.kb_doc_api import search_docs
+from .tools_registry import regist_tool, BaseToolOutput
+from chatchat.server.knowledge_base.kb_api import list_kbs
+from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
 from chatchat.configs import KB_INFO
@@ -11,35 +12,44 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
 template_knowledge = template.format(KB_info=KB_info_str, key="samples")

+class KBToolOutput(BaseToolOutput):
+    def __str__(self) -> str:
+        context = ""
+        docs = self.data
+        source_documents = []
+        for inum, doc in enumerate(docs):
+            doc = DocumentWithVSId.parse_obj(doc)
+            filename = doc.metadata.get("source")
+            parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename})
+            url = f"download_doc?" + parameters
+            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
+            source_documents.append(text)
+        if len(source_documents) == 0:
+            context = "没有找到相关文档,请更换关键词重试"
+        else:
+            for doc in source_documents:
+                context += doc + "\n"
+        return context
+
 def search_knowledgebase(query: str, database: str, config: dict):
     docs = search_docs(
         query=query,
         knowledge_base_name=database,
         top_k=config["top_k"],
         score_threshold=config["score_threshold"])
-    context = ""
-    source_documents = []
-    for inum, doc in enumerate(docs):
-        filename = doc.metadata.get("source")
-        parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
-        url = f"download_doc?" + parameters
-        text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
-        source_documents.append(text)
-    if len(source_documents) == 0:
-        context = "没有找到相关文档,请更换关键词重试"
-    else:
-        for doc in source_documents:
-            context += doc + "\n"
-    return context
+    return docs

-@regist_tool(description=template_knowledge)
+@regist_tool(description=template_knowledge, title="本地知识库")
 def search_local_knowledgebase(
-    database: str = Field(description="Database for Knowledge Search"),
+    database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
     query: str = Field(description="Query for Knowledge Search"),
 ):
     ''''''
     tool_config = get_tool_config("search_local_knowledgebase")
-    return search_knowledgebase(query=query, database=database, config=tool_config)
+    ret = search_knowledgebase(query=query, database=database, config=tool_config)
+    return BaseToolOutput(ret, database=database)

View File

@@ -1,10 +1,10 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

-@regist_tool
+@regist_tool(title="油管视频")
 def search_youtube(query: str = Field(description="Query for Videos search")):
     '''use this tools_factory to search youtube videos'''
     from langchain_community.tools import YouTubeSearchTool

     tool = YouTubeSearchTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))

View File

@@ -2,11 +2,11 @@
 from langchain.tools.shell import ShellTool
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

-@regist_tool
+@regist_tool(title="系统命令")
 def shell(query: str = Field(description="The command to execute")):
     '''Use Shell to execute system shell commands'''
     tool = ShellTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))

View File

@@ -7,7 +7,7 @@ import uuid
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

 import openai
 from chatchat.configs.basic_config import MEDIA_PATH
@@ -26,7 +26,7 @@ def get_image_model_config() -> dict:
     return config

-@regist_tool(return_direct=True)
+@regist_tool(title="文生图", return_direct=True)
 def text2images(
     prompt: str,
     n: int = Field(1, description="需生成图片的数量"),
@@ -56,7 +56,7 @@ def text2images(
         with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
             fp.write(base64.b64decode(x.b64_json))
         images.append(filename)
-    return json.dumps({"message_type": MsgType.IMAGE, "images": images})
+    return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")

 if __name__ == "__main__":

View File

@@ -1,15 +1,22 @@
+import json
 import re
 from typing import Any, Union, Dict, Tuple, Callable, Optional, Type

 from langchain.agents import tool
 from langchain_core.tools import BaseTool

-from chatchat.server.pydantic_v1 import BaseModel
+from chatchat.server.pydantic_v1 import BaseModel, Extra

+__all__ = ["regist_tool", "BaseToolOutput"]

 _TOOLS_REGISTRY = {}

+# patch BaseTool to support extra fields e.g. a title
+BaseTool.Config.extra = Extra.allow

 ################################### TODO: workaround to langchain #15855
 # patch BaseTool to support tool parameters defined using pydantic Field
@@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs

 def regist_tool(
     *args: Any,
+    title: str = "",
     description: str = "",
     return_direct: bool = False,
     args_schema: Optional[Type[BaseModel]] = None,
@@ -70,9 +78,10 @@ def regist_tool(
     add tool to registry automatically
     '''
     def _parse_tool(t: BaseTool):
-        nonlocal description
+        nonlocal description, title

         _TOOLS_REGISTRY[t.name] = t
         # change default description
         if not description:
             if t.func is not None:
@@ -80,6 +89,10 @@ def regist_tool(
             elif t.coroutine is not None:
                 description = t.coroutine.__doc__
         t.description = " ".join(re.split(r"\n+\s*", description))
+        # set a default title for human
+        if not title:
+            title = "".join([x.capitalize() for x in t.name.split("_")])
+        t.title = title

     def wrapper(def_func: Callable) -> BaseTool:
         partial_ = tool(*args,
@@ -101,3 +114,30 @@ def regist_tool(
         )
         _parse_tool(t)
         return t
+
+
+class BaseToolOutput:
+    '''
+    LLM 要求 Tool 的输出为 str,但 Tool 用在别处时希望它正常返回结构化数据。
+    只需要将 Tool 返回值用该类封装,能同时满足两者的需要。
+    基类简单的将返回值字符串化,或指定 format="json" 将其转为 json。
+    用户也可以继承该类定义自己的转换方法。
+    '''
+    def __init__(
+        self,
+        data: Any,
+        format: str = "",
+        data_alias: str = "",
+        **extras: Any,
+    ) -> None:
+        self.data = data
+        self.format = format
+        self.extras = extras
+        if data_alias:
+            setattr(self, data_alias, property(lambda obj: obj.data))
+
+    def __str__(self) -> str:
+        if self.format == "json":
+            return json.dumps(self.data, ensure_ascii=False, indent=2)
+        else:
+            return str(self.data)

View File

@@ -6,7 +6,7 @@ from io import BytesIO
 from PIL import Image, ImageDraw
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import re

 from chatchat.server.agent.container import container
@@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
     return response

-@regist_tool
+@regist_tool(title="图片对话")
 def vqa_processor(query: str = Field(description="The question of the image in English")):
     '''use this tool to get answer for image question'''
@@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E
         # end_marker = "Grounded Operation:"
         # ans = extract_between_markers(ans, start_marker, end_marker)
-        return ans
+        ret = ans
     else:
-        return "No Image, Please Try Again"
+        ret = "No Image, Please Try Again"
+    return BaseToolOutput(ret)

View File

@@ -3,11 +3,11 @@
 """
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

 import requests

-@regist_tool
+@regist_tool(title="天气查询")
 def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
     '''Use this tool to check the weather at a specific city'''
@@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun
             "temperature": data["results"][0]["now"]["temperature"],
             "description": data["results"][0]["now"]["text"],
         }
-        return weather
+        return BaseToolOutput(weather)
     else:
         raise Exception(
             f"Failed to retrieve weather: {response.status_code}")

View File

@@ -2,7 +2,7 @@
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput

 @regist_tool
@@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")):
     wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
     ans = wolfram.run(query)
-    return ans
+    return BaseToolOutput(ans)

View File

@@ -1,10 +1,10 @@
 from __future__ import annotations

-import re
+import json
+import time
 from typing import Dict, List, Literal, Optional, Union

 from fastapi import UploadFile
-from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator
 from openai.types.chat import (
     ChatCompletionMessageParam,
     ChatCompletionToolChoiceOptionParam,
@@ -12,8 +12,10 @@ from openai.types.chat import (
     completion_create_params,
 )

-from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
+from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
+from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
+from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
+from chatchat.server.utils import MsgType

 class OpenAIBaseInput(BaseModel):
     user: Optional[str] = None
@@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel):
     # The extra values given here take precedence over values defined on the client or passed to this method.
     extra_headers: Optional[Dict] = None
     extra_query: Optional[Dict] = None
-    extra_body: Optional[Dict] = None
+    extra_json: Optional[Dict] = Field(None, alias="extra_body")
     timeout: Optional[float] = None

     class Config:
@@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput):
     stop: Union[Optional[str], List[str]] = None
     stream: Optional[bool] = None
     temperature: Optional[float] = TEMPERATURE
-    tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
-    tools: List[ChatCompletionToolParam] = None
+    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
+    tools: List[Union[ChatCompletionToolParam, str]] = None
     top_logprobs: Optional[int] = None
     top_p: Optional[float] = None
@@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
     voice: str
     response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
     speed: Optional[float] = None
+
+
+class OpenAIBaseOutput(BaseModel):
+    id: Optional[str] = None
+    content: Optional[str] = None
+    model: Optional[str] = None
+    object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk"
+    role: Literal["assistant"] = "assistant"
+    finish_reason: Optional[str] = None
+    created: int = Field(default_factory=lambda: int(time.time()))
+    tool_calls: List[Dict] = []
+
+    status: Optional[int] = None  # AgentStatus
+    message_type: int = MsgType.TEXT
+    message_id: Optional[str] = None  # id in database table
+    is_ref: bool = False  # whether to show in a separate expander
+
+    class Config:
+        extra = "allow"
+
+    def model_dump(self) -> dict:
+        result = {
+            "id": self.id,
+            "object": self.object,
+            "model": self.model,
+            "created": self.created,
+            "status": self.status,
+            "message_type": self.message_type,
+            "message_id": self.message_id,
+            "is_ref": self.is_ref,
+            **(self.model_extra or {}),
+        }
+        if self.object == "chat.completion.chunk":
+            result["choices"] = [{
+                "delta": {
+                    "content": self.content,
+                    "tool_calls": self.tool_calls,
+                },
+                "role": self.role,
+            }]
+        elif self.object == "chat.completion":
+            result["choices"] = [{
+                "message": {
+                    "role": self.role,
+                    "content": self.content,
+                    "finish_reason": self.finish_reason,
+                    "tool_calls": self.tool_calls,
+                }
+            }]
+        return result
+
+    def model_dump_json(self):
+        return json.dumps(self.model_dump(), ensure_ascii=False)
+
+
+class OpenAIChatOutput(OpenAIBaseOutput):
+    ...
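
For reference, a minimal sketch of the dict a streamed chunk serializes to under model_dump above (field values are illustrative):

    OpenAIChatOutput(content="你好", model="qwen").model_dump()
    # {"id": None, "object": "chat.completion.chunk", "model": "qwen", "created": ...,
    #  "status": None, "message_type": MsgType.TEXT, "message_id": None, "is_ref": False,
    #  "choices": [{"delta": {"content": "你好", "tool_calls": []}, "role": "assistant"}]}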

View File

@@ -1,12 +1,17 @@
 from __future__ import annotations

-from typing import List
+from typing import List, Dict

 from fastapi import APIRouter, Request
+from langchain.prompts.prompt import PromptTemplate

+from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
 from chatchat.server.chat.chat import chat
 from chatchat.server.chat.feedback import chat_feedback
 from chatchat.server.chat.file_chat import file_chat
+from chatchat.server.db.repository import add_message_to_db
+from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template
+from .openai_routes import openai_request

 chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
@@ -21,4 +26,115 @@ chat_router.post("/feedback",
 chat_router.post("/file_chat",
                  summary="文件对话"
                  )(file_chat)
+
+
+@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
+async def chat_completions(
+    request: Request,
+    body: OpenAIChatInput,
+) -> Dict:
+    '''
+    请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数。
+    tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换。
+    通过不同的参数组合调用不同的 chat 功能:
+    - tool_choice:
+        - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input)
+        - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice
+    - tools: agent 对话
+    - 其它: LLM 对话
+    以后还要考虑其它的组合,如文件对话。
+    返回与 openai 兼容的 Dict。
+    '''
+    client = get_OpenAIClient(model_name=body.model, is_async=True)
+    extra = {**body.model_extra} or {}
+    for key in list(extra):
+        delattr(body, key)
+
+    # check tools & tool_choice in request body
+    if isinstance(body.tool_choice, str):
+        if t := get_tool(body.tool_choice):
+            body.tool_choice = {"function": {"name": t.name}, "type": "function"}
+    if isinstance(body.tools, list):
+        for i in range(len(body.tools)):
+            if isinstance(body.tools[i], str):
+                if t := get_tool(body.tools[i]):
+                    body.tools[i] = {
+                        "type": "function",
+                        "function": {
+                            "name": t.name,
+                            "description": t.description,
+                            "parameters": t.args,
+                        }
+                    }
+
+    conversation_id = extra.get("conversation_id")
+
+    # chat based on the result of one chosen tool
+    if body.tool_choice:
+        tool = get_tool(body.tool_choice["function"]["name"])
+        if not body.tools:
+            body.tools = [{
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.args,
+                }
+            }]
+        if tool_input := extra.get("tool_input"):
+            message_id = add_message_to_db(
+                chat_type="tool_call",
+                query=body.messages[-1]["content"],
+                conversation_id=conversation_id
+            ) if conversation_id else None
+
+            tool_result = await tool.ainvoke(tool_input)
+            prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
+            body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"])
+            del body.tools
+            del body.tool_choice
+
+            extra_json = {
+                "message_id": message_id,
+                "status": None,
+            }
+            header = [{**extra_json,
+                       "content": f"知识库参考资料:\n\n{tool_result}\n\n",
+                       "tool_output": tool_result.data,
+                       "is_ref": True,
+                       }]
+            return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header)
+
+    # agent chat with tool calls
+    if body.tools:
+        message_id = add_message_to_db(
+            chat_type="agent_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+
+        chat_model_config = {}  # TODO: 前端支持配置模型
+        tool_names = [x["function"]["name"] for x in body.tools]
+        tool_config = {name: get_tool_config(name) for name in tool_names}
+        result = await chat(query=body.messages[-1]["content"],
+                            metadata=extra.get("metadata", {}),
+                            conversation_id=extra.get("conversation_id", ""),
+                            message_id=message_id,
+                            history_len=-1,
+                            history=body.messages[:-1],
+                            stream=body.stream,
+                            chat_model_config=extra.get("chat_model_config", chat_model_config),
+                            tool_config=extra.get("tool_config", tool_config),
+                            )
+        return result
+    else:  # LLM chat directly
+        message_id = add_message_to_db(
+            chat_type="llm_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+        extra_json = {
+            "message_id": message_id,
+            "status": None,
+        }
+        return await openai_request(client.chat.completions.create, body, extra_json=extra_json)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations

 import asyncio
 from contextlib import asynccontextmanager
-from typing import Dict, Tuple, AsyncGenerator
+from typing import Dict, Tuple, AsyncGenerator, Iterable

 from fastapi import APIRouter, Request
 from openai import AsyncClient
@@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]

 @asynccontextmanager
-async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
+async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
     '''
     对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
     '''
@@ -49,19 +49,47 @@ async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
         semaphore.release()

-async def openai_request(method, body):
+async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]):
     '''
-    helper function to make openai request
+    helper function to make openai request with extra fields
     '''
     async def generator():
-        async for chunk in await method(**params):
-            yield {"data": chunk.json()}
+        for x in header:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {header}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()
+
+        async for chunk in await method(**params):
+            for k, v in extra_json.items():
+                setattr(chunk, k, v)
+            yield chunk.model_dump_json()
+
+        for x in tail:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {tail}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()

-    params = body.dict(exclude_unset=True)
+    params = body.model_dump(exclude_unset=True)
     if hasattr(body, "stream") and body.stream:
         return EventSourceResponse(generator())
     else:
-        return (await method(**params)).dict()
+        result = await method(**params)
+        for k, v in extra_json.items():
+            setattr(result, k, v)
+        return result.model_dump()

 @openai_router.get("/models")
@@ -74,12 +102,12 @@ async def list_models() -> List:
             client = get_OpenAIClient(name, is_async=True)
             models = await client.models.list()
             if config.get("platform_type") == "xinference":
-                models = models.dict(exclude={"data":..., "object":...})
+                models = models.model_dump(exclude={"data":..., "object":...})
                 for x in models:
                     models[x]["platform_name"] = name
                 return [{**v, "id": k} for k, v in models.items()]
             elif config.get("platform_type") == "oneapi":
-                return [{**x.dict(), "platform_name": name} for x in models.data]
+                return [{**x.model_dump(), "platform_name": name} for x in models.data]
         except Exception:
             logger.error(f"failed request to platform: {name}", exc_info=True)
             return {}
@@ -97,12 +125,8 @@ async def create_chat_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         result = await openai_request(client.chat.completions.create, body)
-        # result["related_docs"] = ["doc1"]
-        # result["choices"][0]["message"]["related_docs"] = ["doc1"]
-        # print(result)
-        # breakpoint()
         return result

@@ -111,7 +135,7 @@ async def create_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.completions.create, body)

@@ -130,7 +154,7 @@ async def create_image_generations(
     request: Request,
     body: OpenAIImageGenerationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.generate, body)

@@ -139,7 +163,7 @@ async def create_image_variations(
     request: Request,
     body: OpenAIImageVariationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.create_variation, body)

@@ -148,7 +172,7 @@ async def create_image_edit(
     request: Request,
     body: OpenAIImageEditsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.edit, body)

@@ -157,7 +181,7 @@ async def create_audio_translations(
     request: Request,
     body: OpenAIAudioTranslationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.translations.create, body)

@@ -166,7 +190,7 @@ async def create_audio_transcriptions(
     request: Request,
     body: OpenAIAudioTranscriptionsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.transcriptions.create, body)

@@ -175,7 +199,7 @@ async def create_audio_speech(
     request: Request,
     body: OpenAIAudioSpeechInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.speech.create, body)

View File

@@ -5,7 +5,7 @@ from typing import List
 from fastapi import APIRouter, Request, Body

 from chatchat.configs import logger
-from chatchat.server.utils import BaseResponse, get_tool
+from chatchat.server.utils import BaseResponse, get_tool, get_tool_config

 tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
 @tool_router.get("/", response_model=BaseResponse)
 async def list_tools():
     tools = get_tool()
-    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
+    data = {t.name: {
+        "name": t.name,
+        "title": t.title,
+        "description": t.description,
+        "args": t.args,
+        "config": get_tool_config(t.name),
+    } for t in tools.values()}
     return {"data": data}

 @tool_router.post("/call", response_model=BaseResponse)
 async def call_tool(
     name: str = Body(examples=["calculate"]),
-    kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
+    tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]),
 ):
-    tools = get_tool()
-    if tool := tools.get(name):
+    if tool := get_tool(name):
         try:
-            result = await tool.ainvoke(kwargs)
+            result = await tool.ainvoke(tool_input)
             return {"data": result}
         except Exception:
             msg = f"failed to call tool '{name}'"
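
A sketch of exercising these two routes over plain HTTP (the base address is an assumption; adjust to your deployment, and note the trailing slash the client now enforces):

    import httpx

    base = "http://127.0.0.1:7861"  # assumed default API address
    tools = httpx.get(f"{base}/tools/").json()["data"]           # name/title/description/args/config per tool
    result = httpx.post(f"{base}/tools/call",
                        json={"name": "calculate", "tool_input": {"text": "3+5/2"}}).json()["data"]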

View File

@@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        pass
+        data = {
+            "run_id": str(run_id),
+            "status": AgentStatus.tool_start,
+            "tool": serialized["name"],
+            "tool_input": input_str,
+        }
+        self.queue.put_nowait(dumps(data))

     async def on_tool_end(
@@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool ends running."""
         data = {
+            "run_id": str(run_id),
             "status": AgentStatus.tool_end,
             "tool_output": output,
         }
@@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool errors."""
         data = {
+            "run_id": str(run_id),
             "status": AgentStatus.tool_end,
-            "text": str(error),
+            "tool_output": str(error),
+            "is_error": True,
         }
         # self.done.clear()
         self.queue.put_nowait(dumps(data))

View File

@@ -1,6 +1,8 @@
 import asyncio
 import json
+import time
 from typing import AsyncIterable, List
+import uuid

 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
@@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate

+from chatchat.configs.model_config import LLM_MODEL_CONFIG
 from chatchat.server.agent.agent_factory.agents_registry import agents_registry
 from chatchat.server.agent.container import container
+from chatchat.server.api_server.api_schemas import OpenAIChatOutput
 from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
 from chatchat.server.chat.utils import History
 from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
-from chatchat.server.db.repository import add_message_to_db
 from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus

 def create_models_from_config(configs, callbacks, stream):
-    if configs is None:
-        configs = {}
+    configs = configs or LLM_MODEL_CONFIG
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
@@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
+               message_id: str = Body(None, description="数据库消息ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: List[History] = Body(
                    [],
@@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
                tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
-    async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(
-            chat_type="llm_chat",
-            query=query,
-            conversation_id=conversation_id
-        ) if conversation_id else None
-
+    '''Agent 对话'''
+    async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
@@ -145,9 +144,32 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             }
         ), callback.done))

+        last_tool = {}
         async for chunk in callback.aiter():
             data = json.loads(chunk)
-            if data["status"] == AgentStatus.tool_end:
+            data["tool_calls"] = []
+            data["message_type"] = MsgType.TEXT
+
+            if data["status"] == AgentStatus.tool_start:
+                last_tool = {
+                    "index": 0,
+                    "id": data["run_id"],
+                    "type": "function",
+                    "function": {
+                        "name": data["tool"],
+                        "arguments": data["tool_input"],
+                    },
+                    "tool_output": None,
+                    "is_error": False,
+                }
+                data["tool_calls"].append(last_tool)
+            if data["status"] in [AgentStatus.tool_end]:
+                last_tool.update(
+                    tool_output=data["tool_output"],
+                    is_error=data.get("is_error", False)
+                )
+                data["tool_calls"] = [last_tool]
+                last_tool = {}
                 try:
                     tool_output = json.loads(data["tool_output"])
                     if message_type := tool_output.get("message_type"):
@@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                         data["message_type"] = message_type
                 except:
                     ...
-            data.setdefault("message_type", MsgType.TEXT)
-            data["message_id"] = message_id
-            yield json.dumps(data, ensure_ascii=False)
+
+            ret = OpenAIChatOutput(
+                id=f"chat{uuid.uuid4()}",
+                object="chat.completion.chunk",
+                content=data.get("text", ""),
+                role="assistant",
+                tool_calls=data["tool_calls"],
+                model=models["llm_model"].model_name,
+                status = data["status"],
+                message_type = data["message_type"],
+                message_id=message_id,
+            )
+            yield ret.model_dump_json()
         await task

-    return EventSourceResponse(chat_iterator())
+    if stream:
+        return EventSourceResponse(chat_iterator())
+    else:
+        ret = OpenAIChatOutput(
+            id=f"chat{uuid.uuid4()}",
+            object="chat.completion",
+            content="",
+            role="assistant",
+            finish_reason="stop",
+            tool_calls=[],
+            status = AgentStatus.agent_finish,
+            message_type = MsgType.TEXT,
+            message_id=message_id,
+        )
+
+        async for chunk in chat_iterator():
+            data = json.loads(chunk)
+            if text := data["choices"][0]["delta"]["content"]:
+                ret.content += text
+            if data["status"] == AgentStatus.tool_end:
+                ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
+            ret.model = data["model"]
+            ret.created = data["created"]
+
+        return ret.model_dump()

View File

@@ -96,11 +96,7 @@ class ESKBService(KBService):
         return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")

     def do_create_kb(self):
-        if os.path.exists(self.doc_path):
-            if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
-                os.makedirs(os.path.join(self.kb_path, "vector_store"))
-            else:
-                logger.warning("directory `vector_store` already exists.")
+        ...

     def vs_type(self) -> str:
         return SupportedVSType.ES

View File

@@ -65,7 +65,6 @@ class FaissKBService(KBService):
         embed_func = get_Embeddings(self.embed_model)
         embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            embeddings = vs.embeddings.embed_query(query)
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs

View File

@@ -1,24 +1,25 @@
 import base64
-import uuid
-from chatchat.server.utils import get_tool_config
-import streamlit as st
-from streamlit_antd_components.utils import ParseItems
-from chatchat.webui_pages.dialogue.utils import process_files
-# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
-#     get_select_model_endpoint
-from chatchat.webui_pages.utils import *
-from streamlit_chatbox import *
-from streamlit_modal import Modal
-from datetime import datetime
 import os
 import re
 import time
+from typing import List, Dict
+
+import streamlit as st
+from streamlit_antd_components.utils import ParseItems
+import openai
+from streamlit_chatbox import *
+from streamlit_modal import Modal
+from datetime import datetime

 from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
 from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
 from chatchat.server.utils import MsgType, get_config_models
-import uuid
-from typing import List, Dict
+from chatchat.server.utils import get_tool_config
+from chatchat.webui_pages.utils import *
+from chatchat.webui_pages.dialogue.utils import process_files

 img_dir = (Path(__file__).absolute().parent.parent.parent)
@@ -121,69 +122,66 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.write("\n\n".join(cmds))

     with st.sidebar:
-        conv_names = list(st.session_state["conversation_ids"].keys())
-        index = 0
-        if st.session_state.get("cur_conv_name") in conv_names:
-            index = conv_names.index(st.session_state.get("cur_conv_name"))
-        conversation_name = st.selectbox("当前会话", conv_names, index=index)
-        chat_box.use_chat_name(conversation_name)
-        conversation_id = st.session_state["conversation_ids"][conversation_name]
-
-        platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
-        platform = st.selectbox("选择模型平台", platforms)
-        llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
-        llm_model = st.selectbox("选择LLM模型", llm_models)
-
-        # 传入后端的内容
-        chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
-        tool_use = True
-        for key in LLM_MODEL_CONFIG:
-            if key == 'llm_model':
-                continue
-            if key == 'action_model':
-                first_key = next(iter(LLM_MODEL_CONFIG[key]))
-                if first_key not in SUPPORT_AGENT_MODELS:
-                    st.warning("不支持Agent的模型无法执行任何工具调用")
-                    tool_use = False
-                continue
-            if LLM_MODEL_CONFIG[key]:
-                first_key = next(iter(LLM_MODEL_CONFIG[key]))
-                chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
-
-        # 选择工具
-        selected_tool_configs = {}
-        if tool_use:
-            from chatchat.configs import model_config as model_config_py
-            import importlib
-            importlib.reload(model_config_py)
-
-            tools = get_tool_config()
-            with st.expander("工具栏"):
-                for tool in tools:
-                    is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
-                    if is_selected:
-                        selected_tool_configs[tool] = tools[tool]
-
-        if llm_model is not None:
-            chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
-
-        uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
-        files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
-        # print(len(files_upload["audios"])) if files_upload else None
-
-    # if dialogue_mode == "文件对话":
-    #     with st.expander("文件对话配置", True):
-    #         files = st.file_uploader("上传知识文件:",
-    #                                 [i for ls in LOADER_DICT.values() for i in ls],
-    #                                 accept_multiple_files=True,
-    #                                 )
-    #         kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
-    #         score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-    #         if st.button("开始上传", disabled=len(files) == 0):
-    #             st.session_state["file_chat_id"] = upload_temp_docs(files, api)
+        tab1, tab2 = st.tabs(["对话设置", "模型设置"])
+
+        with tab1:
+            use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
+            # 选择工具
+            tools = api.list_tools()
+            if use_agent:
+                selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+            else:
+                selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+                selected_tools = [selected_tool]
+            selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
+
+            # 当不启用Agent时手动生成工具参数
+            # TODO: 需要更精细的控制控件
+            tool_input = {}
+            if not use_agent and len(selected_tools) == 1:
+                with st.expander("工具参数", True):
+                    for k, v in tools[selected_tools[0]]["args"].items():
+                        if choices := v.get("choices", v.get("enum")):
+                            tool_input[k] = st.selectbox(v["title"], choices)
+                        else:
+                            if v["type"] == "integer":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"))
+                            elif v["type"] == "number":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1)
+                            else:
+                                tool_input[k] = st.text_input(v["title"], v.get("default"))
+
+            uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
+            files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
+
+        with tab2:
+            # 会话
+            conv_names = list(st.session_state["conversation_ids"].keys())
+            index = 0
+            if st.session_state.get("cur_conv_name") in conv_names:
+                index = conv_names.index(st.session_state.get("cur_conv_name"))
+            conversation_name = st.selectbox("当前会话", conv_names, index=index)
+            chat_box.use_chat_name(conversation_name)
+            conversation_id = st.session_state["conversation_ids"][conversation_name]
+
+            # 模型
+            platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
+            platform = st.selectbox("选择模型平台", platforms)
+            llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
+            llm_model = st.selectbox("选择LLM模型", llm_models)
+
+            # 传入后端的内容
+            chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
+            for key in LLM_MODEL_CONFIG:
+                if LLM_MODEL_CONFIG[key]:
+                    first_key = next(iter(LLM_MODEL_CONFIG[key]))
+                    chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
+
+            if llm_model is not None:
+                chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})

     # Display chat messages from history on app rerun
     chat_box.output_messages()
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
@@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         chat_box.ai_say("正在思考...")
         text = ""
-        text_action = ""
-        element_index = 0
-        for d in api.chat_chat(query=prompt,
-                               metadata=files_upload,
-                               history=history,
-                               chat_model_config=chat_model_config,
-                               conversation_id=conversation_id,
-                               tool_config=selected_tool_configs,
-                               ):
-            message_id = d.get("message_id", "")
+        started = False
+
+        client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
+        messages = history + [{"role": "user", "content": prompt}]
+        tools = list(selected_tool_configs)
+        if len(selected_tools) == 1:
+            tool_choice = selected_tools[0]
+        else:
+            tool_choice = None
+        # 如果 tool_input 中有空的字段,设为用户输入
+        for k in tool_input:
+            if tool_input[k] in [None, ""]:
+                tool_input[k] = prompt
+
+        extra_body = dict(
+            metadata=files_upload,
+            chat_model_config=chat_model_config,
+            conversation_id=conversation_id,
+            tool_input = tool_input,
+        )
+        for d in client.chat.completions.create(
+            messages=messages,
+            model=llm_model,
+            stream=True,
+            tools=tools,
+            tool_choice=tool_choice,
+            extra_body=extra_body,
+        ):
+            print("\n\n", d.status, "\n", d, "\n\n")
+            message_id = d.message_id
             metadata = {
                 "message_id": message_id,
             }
-            print(d)
-            if d["status"] == AgentStatus.error:
-                st.error(d["text"])
-            elif d["status"] == AgentStatus.agent_action:
-                formatted_data = {
-                    "Function": d["tool_name"],
-                    "function_input": d["tool_input"]
-                }
-                element_index += 1
-                formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                chat_box.insert_msg(
-                    Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
-                text = """\n```{}\n```\n""".format(formatted_json)
-                chat_box.update_msg(Markdown(text), element_index=element_index)
-            elif d["status"] == AgentStatus.tool_end:
-                text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
-                chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
-            elif d["status"] == AgentStatus.llm_new_token:
-                text += d["text"]
-                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.llm_end:
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.agent_finish:
-                if d["message_type"] == MsgType.IMAGE:
-                    for url in json.loads(d["text"]).get("images", []):
-                        url = f"{api.base_url}/media/{url}"
-                        chat_box.insert_msg(Image(url))
-                    chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
-                else:
-                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
+            if d.status == AgentStatus.error:
+                st.error(d.choices[0].delta.content)
+            elif d.status == AgentStatus.llm_start:
+                if not started:
+                    started = True
+                else:
+                    chat_box.insert_msg("正在解读工具输出结果...")
+                text = d.choices[0].delta.content or ""
+            elif d.status == AgentStatus.llm_new_token:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+            elif d.status == AgentStatus.llm_end:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
+            # tool 的输出与 llm 输出重复了
+            # elif d.status == AgentStatus.tool_start:
+            #     formatted_data = {
+            #         "Function": d.choices[0].delta.tool_calls[0].function.name,
+            #         "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
+            #     }
+            #     formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
+            #     text = """\n```{}\n```\n""".format(formatted_json)
+            #     chat_box.insert_msg( # TODO: insert text directly not shown
+            #         Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
+            # elif d.status == AgentStatus.tool_end:
+            #     tool_output = d.choices[0].delta.tool_calls[0].tool_output
+            #     if d.message_type == MsgType.IMAGE:
+            #         for url in json.loads(tool_output).get("images", []):
+            #             url = f"{api.base_url}/media/{url}"
+            #             chat_box.insert_msg(Image(url))
+            #         chat_box.update_msg(expanded=False, state="complete")
+            #     else:
+            #         text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
+            #         chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
+            elif d.status == AgentStatus.agent_finish:
+                text = d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"))
+            elif d.status == None: # not agent chat
+                if getattr(d, "is_ref", False):
+                    chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
+                    chat_box.insert_msg("")
+                else:
+                    text += d.choices[0].delta.content or ""
+                    chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+        chat_box.update_msg(text, streaming=False, metadata=metadata)

         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file:
@@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.rerun()

     now = datetime.now()
-    with st.sidebar:
+    with tab1:
         cols = st.columns(2)
         export_btn = cols[0]
         if cols[1].button(
@@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             mime="text/markdown",
             use_container_width=True,
         )
+
+    # st.write(chat_box.history)

View File

@@ -50,6 +50,14 @@ class ApiRequest:
                                        timeout=self.timeout)
         return self._client

+    def _check_url(self, url: str) -> str:
+        '''
+        新版 httpx 强制要求 url 以 / 结尾,否则会返回 307
+        '''
+        if not url.endswith("/"):
+            url = url + "/"
+        return url
+
     def get(
         self,
         url: str,
@@ -58,6 +66,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any,
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -79,6 +88,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 # print(kwargs)
@@ -101,6 +111,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -638,6 +649,27 @@ class ApiRequest:
         resp = self.post("/chat/feedback", json=data)
         return self._get_response_value(resp)

+    def list_tools(self) -> Dict:
+        '''
+        列出所有工具
+        '''
+        resp = self.get("/tools")
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
+
+    def call_tool(
+        self,
+        name: str,
+        tool_input: Dict = {},
+    ):
+        '''
+        调用工具
+        '''
+        data = {
+            "name": name,
+            "tool_input": tool_input,
+        }
+        resp = self.post("/tools/call", json=data)
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+

 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):