Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-31 11:28:28 +08:00
Rework tool definitions; add an OpenAI-compatible unified chat endpoint (#3570)
- Fixes:
  - The Qwen Agent OutputParser no longer raises an exception; non-CoT text is returned directly as the final answer
  - The CallbackHandler now handles tool-call information correctly
- Reworked the way tools are defined:
  - Added regist_tool to simplify tool definitions:
    - A user-friendly title can be specified
    - The function's __doc__ is used as tool.description automatically
    - Parameters can be declared with Field; a separate ModelSchema is no longer needed
  - Added BaseToolOutput to wrap tool results, so the raw value and the string handed to the LLM are both available
  - Tools support hot reloading (still to be tested)
- Added an OpenAI-compatible unified chat endpoint; different combinations of the tools/tool_choice/extra_body parameters select:
  - Agent conversation
  - A designated tool call (e.g. knowledge-base RAG)
  - Plain LLM conversation
- Updated the webui to match the backend changes (see the sketches below)
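
Two quick sketches of the new usage before the diff (illustrative only: the echo tool is hypothetical, and the server address and model name are assumptions — adapt them to your deployment).

Defining a tool now takes a single decorator, with Field-declared parameters and a BaseToolOutput-wrapped result; this sketch assumes it lives inside the tools package next to tools_registry:

    from chatchat.server.pydantic_v1 import Field
    from .tools_registry import regist_tool, BaseToolOutput

    @regist_tool(title="示例回声")  # title: the user-friendly name shown in the webui
    def echo(text: str = Field(description="text to echo back")):
        '''use this tool to echo the input text'''  # __doc__ becomes tool.description
        # wrap the raw value: callers can read .data, while the LLM gets the
        # stringified form (here JSON, because format="json")
        return BaseToolOutput({"echo": text}, format="json")

Calling the unified endpoint goes through the stock openai client, as the new webui code does; tool names may be passed as plain strings (expanded server-side), and extra_body carries the chatchat-specific fields:

    import openai

    # base_url/port and model name are assumptions for a local deployment
    client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")
    for chunk in client.chat.completions.create(
        model="qwen",  # any LLM configured on the server
        messages=[{"role": "user", "content": "厦门今天天气怎么样?"}],
        stream=True,
        tools=["weather_check"],      # plain names are converted server-side
        tool_choice="weather_check",  # force this tool instead of letting the agent pick
        extra_body={"tool_input": {"city": "厦门"}},  # with tool_input, the tool is called directly (no agent)
    ):
        print(chunk)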
This commit is contained in:
parent 5e70aff522
commit 42aa900566
@@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
             s = s[-1]
             return AgentFinish({"output": s}, log=text)
         else:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            return AgentFinish({"output": text}, log=text)
+            # raise OutputParserException(f"Could not parse LLM output: {text}")
 
     @property
     def _type(self) -> str:
@@ -1,7 +1,7 @@
 import base64
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
 def save_base64_audio(base64_audio, file_path):
@@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query):
     return response
 
 
-@regist_tool
-def aqa_processor(query: str = Field(description="The question of the image in English")):
+@regist_tool(title="音频问答")
+def aqa_processor(query: str = Field(description="The question of the audio in English")):
     '''use this tool to get answer for audio question'''
 
     from chatchat.server.agent.container import container
@@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E
             "audio": file_path,
             "text": query,
         }
-        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
+        ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
     else:
-        return "No Audio, Please Try Again"
+        ret = "No Audio, Please Try Again"
+
+    return BaseToolOutput(ret)
@ -1,12 +1,12 @@
|
||||
# LangChain 的 ArxivQueryRun 工具
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
|
||||
|
||||
@regist_tool
|
||||
@regist_tool(title="ARXIV论文")
|
||||
def arxiv(query: str = Field(description="The search query title")):
|
||||
'''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
|
||||
from langchain.tools.arxiv.tool import ArxivQueryRun
|
||||
|
||||
tool = ArxivQueryRun()
|
||||
return tool.run(tool_input=query)
|
||||
return BaseToolOutput(tool.run(tool_input=query))
|
||||
|
||||
@@ -1,8 +1,8 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
-@regist_tool
+@regist_tool(title="数学计算器")
 def calculate(text: str = Field(description="a math expression")) -> float:
     '''
     Useful to answer questions about simple calculations.
@@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float:
     import numexpr
 
     try:
-        return str(numexpr.evaluate(text))
+        ret = str(numexpr.evaluate(text))
     except Exception as e:
-        return f"wrong: {e}"
+        ret = f"wrong: {e}"
+
+    return BaseToolOutput(ret)
@@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein
 
 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
 def bing_search(text, config):
@@ -95,8 +95,8 @@ def search_engine(query: str,
     return context
 
 
-@regist_tool
+@regist_tool(title="互联网搜索")
 def search_internet(query: str = Field(description="query for Internet search")):
     '''Use this tool to use bing search engine to search the internet and get information.'''
     tool_config = get_tool_config("search_internet")
-    return search_engine(query=query, config=tool_config)
+    return BaseToolOutput(search_engine(query=query, config=tool_config))
@@ -1,8 +1,9 @@
 from urllib.parse import urlencode
 from chatchat.server.utils import get_tool_config
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
-from chatchat.server.knowledge_base.kb_doc_api import search_docs
+from .tools_registry import regist_tool, BaseToolOutput
+from chatchat.server.knowledge_base.kb_api import list_kbs
+from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
 from chatchat.configs import KB_INFO
 
 
@@ -11,35 +12,44 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
 template_knowledge = template.format(KB_info=KB_info_str, key="samples")
 
 
+class KBToolOutput(BaseToolOutput):
+    def __str__(self) -> str:
+        context = ""
+        docs = self.data
+        source_documents = []
+
+        for inum, doc in enumerate(docs):
+            doc = DocumentWithVSId.parse_obj(doc)
+            filename = doc.metadata.get("source")
+            parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename})
+            url = f"download_doc?" + parameters
+            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
+            source_documents.append(text)
+
+        if len(source_documents) == 0:
+            context = "没有找到相关文档,请更换关键词重试"
+        else:
+            for doc in source_documents:
+                context += doc + "\n"
+
+        return context
+
+
 def search_knowledgebase(query: str, database: str, config: dict):
     docs = search_docs(
         query=query,
         knowledge_base_name=database,
         top_k=config["top_k"],
         score_threshold=config["score_threshold"])
-    context = ""
-    source_documents = []
-    for inum, doc in enumerate(docs):
-        filename = doc.metadata.get("source")
-        parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
-        url = f"download_doc?" + parameters
-        text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
-        source_documents.append(text)
-
-    if len(source_documents) == 0:
-        context = "没有找到相关文档,请更换关键词重试"
-    else:
-        for doc in source_documents:
-            context += doc + "\n"
-
-    return context
+    return docs
 
 
-@regist_tool(description=template_knowledge)
+@regist_tool(description=template_knowledge, title="本地知识库")
 def search_local_knowledgebase(
-    database: str = Field(description="Database for Knowledge Search"),
+    database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
     query: str = Field(description="Query for Knowledge Search"),
 ):
     ''''''
     tool_config = get_tool_config("search_local_knowledgebase")
-    return search_knowledgebase(query=query, database=database, config=tool_config)
+    ret = search_knowledgebase(query=query, database=database, config=tool_config)
+    return BaseToolOutput(ret, database=database)
@@ -1,10 +1,10 @@
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
-@regist_tool
+@regist_tool(title="油管视频")
 def search_youtube(query: str = Field(description="Query for Videos search")):
     '''use this tools_factory to search youtube videos'''
     from langchain_community.tools import YouTubeSearchTool
     tool = YouTubeSearchTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))
@@ -2,11 +2,11 @@
 from langchain.tools.shell import ShellTool
 
 from chatchat.server.pydantic_v1 import Field
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
-@regist_tool
+@regist_tool(title="系统命令")
 def shell(query: str = Field(description="The command to execute")):
     '''Use Shell to execute system shell commands'''
     tool = ShellTool()
-    return tool.run(tool_input=query)
+    return BaseToolOutput(tool.run(tool_input=query))
@@ -7,7 +7,7 @@ import uuid
 
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import openai
 
 from chatchat.configs.basic_config import MEDIA_PATH
@@ -26,7 +26,7 @@ def get_image_model_config() -> dict:
     return config
 
 
-@regist_tool(return_direct=True)
+@regist_tool(title="文生图", return_direct=True)
 def text2images(
     prompt: str,
     n: int = Field(1, description="需生成图片的数量"),
@@ -56,7 +56,7 @@ def text2images(
         with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
             fp.write(base64.b64decode(x.b64_json))
         images.append(filename)
-    return json.dumps({"message_type": MsgType.IMAGE, "images": images})
+    return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")
 
 
 if __name__ == "__main__":
@@ -1,15 +1,22 @@
 import json
 import re
 from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
 
 from langchain.agents import tool
 from langchain_core.tools import BaseTool
 
-from chatchat.server.pydantic_v1 import BaseModel
+from chatchat.server.pydantic_v1 import BaseModel, Extra
 
 
+__all__ = ["regist_tool", "BaseToolOutput"]
+
+
 _TOOLS_REGISTRY = {}
 
 
+# patch BaseTool to support extra fields e.g. a title
+BaseTool.Config.extra = Extra.allow
+
 ################################### TODO: workaround to langchain #15855
 # patch BaseTool to support tool parameters defined using pydantic Field
@@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
 
 def regist_tool(
     *args: Any,
+    title: str = "",
     description: str = "",
     return_direct: bool = False,
     args_schema: Optional[Type[BaseModel]] = None,
@@ -70,9 +78,10 @@ def regist_tool(
     add tool to regstiry automatically
     '''
     def _parse_tool(t: BaseTool):
-        nonlocal description
+        nonlocal description, title
 
         _TOOLS_REGISTRY[t.name] = t
 
         # change default description
         if not description:
             if t.func is not None:
@@ -80,6 +89,10 @@ def regist_tool(
             elif t.coroutine is not None:
                 description = t.coroutine.__doc__
             t.description = " ".join(re.split(r"\n+\s*", description))
+        # set a default title for human
+        if not title:
+            title = "".join([x.capitalize() for x in t.name.split("_")])
+        t.title = title
 
     def wrapper(def_func: Callable) -> BaseTool:
         partial_ = tool(*args,
@@ -101,3 +114,30 @@ def regist_tool(
         )
         _parse_tool(t)
         return t
+
+
+class BaseToolOutput:
+    '''
+    LLM 要求 Tool 的输出为 str,但 Tool 用在别处时希望它正常返回结构化数据。
+    只需要将 Tool 返回值用该类封装,能同时满足两者的需要。
+    基类简单的将返回值字符串化,或指定 format="json" 将其转为 json。
+    用户也可以继承该类定义自己的转换方法。
+    '''
+    def __init__(
+        self,
+        data: Any,
+        format: str="",
+        data_alias: str="",
+        **extras: Any,
+    ) -> None:
+        self.data = data
+        self.format = format
+        self.extras = extras
+        if data_alias:
+            setattr(self, data_alias, property(lambda obj: obj.data))
+
+    def __str__(self) -> str:
+        if self.format == "json":
+            return json.dumps(self.data, ensure_ascii=False, indent=2)
+        else:
+            return str(self.data)
@@ -6,7 +6,7 @@ from io import BytesIO
 from PIL import Image, ImageDraw
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import re
 from chatchat.server.agent.container import container
 
@@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
     return response
 
 
-@regist_tool
+@regist_tool(title="图片对话")
 def vqa_processor(query: str = Field(description="The question of the image in English")):
     '''use this tool to get answer for image question'''
 
@@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E
         # end_marker = "Grounded Operation:"
         # ans = extract_between_markers(ans, start_marker, end_marker)
 
-        return ans
+        ret = ans
     else:
-        return "No Image, Please Try Again"
+        ret = "No Image, Please Try Again"
+
+    return BaseToolOutput(ret)
@@ -3,11 +3,11 @@
 """
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 import requests
 
 
-@regist_tool
+@regist_tool(title="天气查询")
 def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
     '''Use this tool to check the weather at a specific city'''
 
@@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun
             "temperature": data["results"][0]["now"]["temperature"],
             "description": data["results"][0]["now"]["text"],
         }
-        return weather
+        return BaseToolOutput(weather)
     else:
         raise Exception(
             f"Failed to retrieve weather: {response.status_code}")
@@ -2,7 +2,7 @@
 
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config
-from .tools_registry import regist_tool
+from .tools_registry import regist_tool, BaseToolOutput
 
 
 @regist_tool
@@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")):
 
     wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
     ans = wolfram.run(query)
-    return ans
+    return BaseToolOutput(ans)
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
+import re
+import json
+import time
 from typing import Dict, List, Literal, Optional, Union
 
 from fastapi import UploadFile
-from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator
 from openai.types.chat import (
     ChatCompletionMessageParam,
     ChatCompletionToolChoiceOptionParam,
@@ -12,8 +12,10 @@ from openai.types.chat import (
     completion_create_params,
 )
 
-from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
-
+from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
+from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
+from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
+from chatchat.server.utils import MsgType
 
 class OpenAIBaseInput(BaseModel):
     user: Optional[str] = None
@@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel):
     # The extra values given here take precedence over values defined on the client or passed to this method.
     extra_headers: Optional[Dict] = None
     extra_query: Optional[Dict] = None
-    extra_body: Optional[Dict] = None
+    extra_json: Optional[Dict] = Field(None, alias="extra_body")
     timeout: Optional[float] = None
 
     class Config:
@@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput):
     stop: Union[Optional[str], List[str]] = None
     stream: Optional[bool] = None
     temperature: Optional[float] = TEMPERATURE
-    tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
-    tools: List[ChatCompletionToolParam] = None
+    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
+    tools: List[Union[ChatCompletionToolParam, str]] = None
     top_logprobs: Optional[int] = None
     top_p: Optional[float] = None
 
@@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
     voice: str
     response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
     speed: Optional[float] = None
+
+
+class OpenAIBaseOutput(BaseModel):
+    id: Optional[str] = None
+    content: Optional[str] = None
+    model: Optional[str] = None
+    object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk"
+    role: Literal["assistant"] = "assistant"
+    finish_reason: Optional[str] = None
+    created: int = Field(default_factory=lambda : int(time.time()))
+    tool_calls: List[Dict] = []
+
+    status: Optional[int] = None  # AgentStatus
+    message_type: int = MsgType.TEXT
+    message_id: Optional[str] = None  # id in database table
+    is_ref: bool = False  # wheather show in seperated expander
+
+    class Config:
+        extra = "allow"
+
+    def model_dump(self) -> dict:
+        result = {
+            "id": self.id,
+            "object": self.object,
+            "model": self.model,
+            "created": self.created,
+
+            "status": self.status,
+            "message_type": self.message_type,
+            "message_id": self.message_id,
+            "is_ref": self.is_ref,
+            **(self.model_extra or {}),
+        }
+
+        if self.object == "chat.completion.chunk":
+            result["choices"] = [{
+                "delta": {
+                    "content": self.content,
+                    "tool_calls": self.tool_calls,
+                },
+                "role": self.role,
+            }]
+        elif self.object == "chat.completion":
+            result["choices"] = [{
+                "message": {
+                    "role": self.role,
+                    "content": self.content,
+                    "finish_reason": self.finish_reason,
+                    "tool_calls": self.tool_calls,
+                }
+            }]
+        return result
+
+    def model_dump_json(self):
+        return json.dumps(self.model_dump(), ensure_ascii=False)
+
+
+class OpenAIChatOutput(OpenAIBaseOutput):
+    ...
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
-from typing import List
+from typing import List, Dict
 
 from fastapi import APIRouter, Request
+from langchain.prompts.prompt import PromptTemplate
 
+from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
 from chatchat.server.chat.chat import chat
 from chatchat.server.chat.feedback import chat_feedback
 from chatchat.server.chat.file_chat import file_chat
+from chatchat.server.db.repository import add_message_to_db
+from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template
+from .openai_routes import openai_request
 
 
 chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
@@ -21,4 +26,115 @@ chat_router.post("/feedback",
 
 chat_router.post("/file_chat",
                  summary="文件对话"
-                 )(file_chat)
+                 )(file_chat)
+
+
+@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
+async def chat_completions(
+    request: Request,
+    body: OpenAIChatInput,
+) -> Dict:
+    '''
+    请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数
+    tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换
+    通过不同的参数组合调用不同的 chat 功能:
+    - tool_choice
+        - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input)
+        - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice
+    - tools: agent 对话
+    - 其它:LLM 对话
+    以后还要考虑其它的组合(如文件对话)
+    返回与 openai 兼容的 Dict
+    '''
+    client = get_OpenAIClient(model_name=body.model, is_async=True)
+    extra = {**body.model_extra} or {}
+    for key in list(extra):
+        delattr(body, key)
+
+    # check tools & tool_choice in request body
+    if isinstance(body.tool_choice, str):
+        if t := get_tool(body.tool_choice):
+            body.tool_choice = {"function": {"name": t.name}, "type": "function"}
+    if isinstance(body.tools, list):
+        for i in range(len(body.tools)):
+            if isinstance(body.tools[i], str):
+                if t := get_tool(body.tools[i]):
+                    body.tools[i] = {
+                        "type": "function",
+                        "function": {
+                            "name": t.name,
+                            "description": t.description,
+                            "parameters": t.args,
+                        }
+                    }
+
+    conversation_id = extra.get("conversation_id")
+
+    # chat based on result from one choiced tool
+    if body.tool_choice:
+        tool = get_tool(body.tool_choice["function"]["name"])
+        if not body.tools:
+            body.tools = [{
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.args,
+                }
+            }]
+        if tool_input := extra.get("tool_input"):
+            message_id = add_message_to_db(
+                chat_type="tool_call",
+                query=body.messages[-1]["content"],
+                conversation_id=conversation_id
+            ) if conversation_id else None
+
+            tool_result = await tool.ainvoke(tool_input)
+            prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
+            body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"])
+            del body.tools
+            del body.tool_choice
+            extra_json = {
+                "message_id": message_id,
+                "status": None,
+            }
+            header = [{**extra_json,
+                       "content": f"知识库参考资料:\n\n{tool_result}\n\n",
+                       "tool_output": tool_result.data,
+                       "is_ref": True,
+                       }]
+            return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header)
+
+    # agent chat with tool calls
+    if body.tools:
+        message_id = add_message_to_db(
+            chat_type="agent_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+
+        chat_model_config = {}  # TODO: 前端支持配置模型
+        tool_names = [x["function"]["name"] for x in body.tools]
+        tool_config = {name: get_tool_config(name) for name in tool_names}
+        result = await chat(query=body.messages[-1]["content"],
+                            metadata=extra.get("metadata", {}),
+                            conversation_id=extra.get("conversation_id", ""),
+                            message_id=message_id,
+                            history_len=-1,
+                            history=body.messages[:-1],
+                            stream=body.stream,
+                            chat_model_config=extra.get("chat_model_config", chat_model_config),
+                            tool_config=extra.get("tool_config", tool_config),
+                            )
+        return result
+    else:  # LLM chat directly
+        message_id = add_message_to_db(
+            chat_type="llm_chat",
+            query=body.messages[-1]["content"],
+            conversation_id=conversation_id
+        ) if conversation_id else None
+        extra_json = {
+            "message_id": message_id,
+            "status": None,
+        }
+        return await openai_request(client.chat.completions.create, body, extra_json=extra_json)
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 from contextlib import asynccontextmanager
-from typing import Dict, Tuple, AsyncGenerator
+from typing import Dict, Tuple, AsyncGenerator, Iterable
 
 from fastapi import APIRouter, Request
 from openai import AsyncClient
@@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]
 
 
 @asynccontextmanager
-async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
+async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
     '''
     对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
     '''
@@ -49,19 +49,47 @@
         semaphore.release()
 
 
-async def openai_request(method, body):
+async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]):
     '''
-    helper function to make openai request
+    helper function to make openai request with extra fields
     '''
     async def generator():
-        async for chunk in await method(**params):
-            yield {"data": chunk.json()}
+        for x in header:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {header}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()
+
+        async for chunk in await method(**params):
+            for k, v in extra_json.items():
+                setattr(chunk, k, v)
+            yield chunk.model_dump_json()
+
+        for x in tail:
+            if isinstance(x, str):
+                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
+            elif isinstance(x, dict):
+                x = OpenAIChatOutput.model_validate(x)
+            else:
+                raise RuntimeError(f"unsupported value: {tail}")
+            for k, v in extra_json.items():
+                setattr(x, k, v)
+            yield x.model_dump_json()
 
-    params = body.dict(exclude_unset=True)
+    params = body.model_dump(exclude_unset=True)
+
     if hasattr(body, "stream") and body.stream:
         return EventSourceResponse(generator())
     else:
-        return (await method(**params)).dict()
+        result = await method(**params)
+        for k, v in extra_json.items():
+            setattr(result, k, v)
+        return result.model_dump()
 
 
 @openai_router.get("/models")
@@ -74,12 +102,12 @@ async def list_models() -> List:
             client = get_OpenAIClient(name, is_async=True)
             models = await client.models.list()
             if config.get("platform_type") == "xinference":
-                models = models.dict(exclude={"data":..., "object":...})
+                models = models.model_dump(exclude={"data":..., "object":...})
                 for x in models:
                     models[x]["platform_name"] = name
                 return [{**v, "id": k} for k, v in models.items()]
             elif config.get("platform_type") == "oneapi":
-                return [{**x.dict(), "platform_name": name} for x in models.data]
+                return [{**x.model_dump(), "platform_name": name} for x in models.data]
         except Exception:
             logger.error(f"failed request to platform: {name}", exc_info=True)
         return {}
@@ -97,12 +125,8 @@ async def create_chat_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         result = await openai_request(client.chat.completions.create, body)
-        # result["related_docs"] = ["doc1"]
-        # result["choices"][0]["message"]["related_docs"] = ["doc1"]
-        # print(result)
-        # breakpoint()
         return result
 
 
@@ -111,7 +135,7 @@ async def create_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.completions.create, body)
 
 
@@ -130,7 +154,7 @@ async def create_image_generations(
    request: Request,
     body: OpenAIImageGenerationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.generate, body)
 
 
@@ -139,7 +163,7 @@ async def create_image_variations(
     request: Request,
     body: OpenAIImageVariationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.create_variation, body)
 
 
@@ -148,7 +172,7 @@ async def create_image_edit(
     request: Request,
     body: OpenAIImageEditsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.images.edit, body)
 
 
@@ -157,7 +181,7 @@ async def create_audio_translations(
     request: Request,
     body: OpenAIAudioTranslationsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.translations.create, body)
 
 
@@ -166,7 +190,7 @@ async def create_audio_transcriptions(
     request: Request,
     body: OpenAIAudioTranscriptionsInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.transcriptions.create, body)
 
 
@@ -175,7 +199,7 @@ async def create_audio_speech(
     request: Request,
     body: OpenAIAudioSpeechInput,
 ):
-    async with acquire_model_client(body.model) as client:
+    async with get_model_client(body.model) as client:
         return await openai_request(client.audio.speech.create, body)
@@ -5,7 +5,7 @@ from typing import List
 from fastapi import APIRouter, Request, Body
 
 from chatchat.configs import logger
-from chatchat.server.utils import BaseResponse, get_tool
+from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
 
 
 tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
 @tool_router.get("/", response_model=BaseResponse)
 async def list_tools():
     tools = get_tool()
-    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
+    data = {t.name: {
+        "name": t.name,
+        "title": t.title,
+        "description": t.description,
+        "args": t.args,
+        "config": get_tool_config(t.name),
+    } for t in tools.values()}
     return {"data": data}
 
 
 @tool_router.post("/call", response_model=BaseResponse)
 async def call_tool(
     name: str = Body(examples=["calculate"]),
-    kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
+    tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]),
 ):
-    tools = get_tool()
-
-    if tool := tools.get(name):
+    if tool := get_tool(name):
         try:
-            result = await tool.ainvoke(kwargs)
+            result = await tool.ainvoke(tool_input)
             return {"data": result}
         except Exception:
             msg = f"failed to call tool '{name}'"
@@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        pass
+        data = {
+            "run_id": str(run_id),
+            "status": AgentStatus.tool_start,
+            "tool": serialized["name"],
+            "tool_input": input_str,
+        }
+        self.queue.put_nowait(dumps(data))
 
 
     async def on_tool_end(
@@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool ends running."""
         data = {
             "run_id": str(run_id),
             "status": AgentStatus.tool_end,
+            "tool_output": output,
         }
@@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     ) -> None:
         """Run when tool errors."""
         data = {
             "run_id": str(run_id),
-            "text": str(error),
+            "status": AgentStatus.tool_end,
+            "tool_output": str(error),
+            "is_error": True,
         }
         # self.done.clear()
         self.queue.put_nowait(dumps(data))
@@ -1,6 +1,8 @@
 import asyncio
 import json
+import time
 from typing import AsyncIterable, List
+import uuid
 
 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
@@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
 
 from chatchat.configs.model_config import LLM_MODEL_CONFIG
 from chatchat.server.agent.agent_factory.agents_registry import agents_registry
 from chatchat.server.agent.container import container
 
+from chatchat.server.api_server.api_schemas import OpenAIChatOutput
 from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
 from chatchat.server.chat.utils import History
 from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
-from chatchat.server.db.repository import add_message_to_db
 from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
 
 
 def create_models_from_config(configs, callbacks, stream):
-    if configs is None:
-        configs = {}
+    configs = configs or LLM_MODEL_CONFIG
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
@@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
+               message_id: str = Body(None, description="数据库消息ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: List[History] = Body(
                    [],
@@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
                tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
-    async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(
-            chat_type="llm_chat",
-            query=query,
-            conversation_id=conversation_id
-        ) if conversation_id else None
+    '''Agent 对话'''
+
+    async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
@@ -145,9 +144,32 @@
             }
         ), callback.done))
 
+        last_tool = {}
         async for chunk in callback.aiter():
             data = json.loads(chunk)
-            if data["status"] == AgentStatus.tool_end:
+            data["tool_calls"] = []
+            data["message_type"] = MsgType.TEXT
+
+            if data["status"] == AgentStatus.tool_start:
+                last_tool = {
+                    "index": 0,
+                    "id": data["run_id"],
+                    "type": "function",
+                    "function": {
+                        "name": data["tool"],
+                        "arguments": data["tool_input"],
+                    },
+                    "tool_output": None,
+                    "is_error": False,
+                }
+                data["tool_calls"].append(last_tool)
+            if data["status"] in [AgentStatus.tool_end]:
+                last_tool.update(
+                    tool_output=data["tool_output"],
+                    is_error=data.get("is_error", False)
+                )
+                data["tool_calls"] = [last_tool]
+                last_tool = {}
                 try:
                     tool_output = json.loads(data["tool_output"])
                     if message_type := tool_output.get("message_type"):
@@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                         data["message_type"] = message_type
                 except:
                     ...
-            data.setdefault("message_type", MsgType.TEXT)
-            data["message_id"] = message_id
-            yield json.dumps(data, ensure_ascii=False)
+
+            ret = OpenAIChatOutput(
+                id=f"chat{uuid.uuid4()}",
+                object="chat.completion.chunk",
+                content=data.get("text", ""),
+                role="assistant",
+                tool_calls=data["tool_calls"],
+                model=models["llm_model"].model_name,
+                status = data["status"],
+                message_type = data["message_type"],
+                message_id=message_id,
+            )
+            yield ret.model_dump_json()
         await task
 
-    return EventSourceResponse(chat_iterator())
+    if stream:
+        return EventSourceResponse(chat_iterator())
+    else:
+        ret = OpenAIChatOutput(
+            id=f"chat{uuid.uuid4()}",
+            object="chat.completion",
+            content="",
+            role="assistant",
+            finish_reason="stop",
+            tool_calls=[],
+            status = AgentStatus.agent_finish,
+            message_type = MsgType.TEXT,
+            message_id=message_id,
+        )
+
+        async for chunk in chat_iterator():
+            data = json.loads(chunk)
+            if text := data["choices"][0]["delta"]["content"]:
+                ret.content += text
+            if data["status"] == AgentStatus.tool_end:
+                ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
+            ret.model = data["model"]
+            ret.created = data["created"]
+
+        return ret.model_dump()
@@ -96,11 +96,7 @@ class ESKBService(KBService):
         return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
 
     def do_create_kb(self):
-        if os.path.exists(self.doc_path):
-            if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
-                os.makedirs(os.path.join(self.kb_path, "vector_store"))
-            else:
-                logger.warning("directory `vector_store` already exists.")
+        ...
 
     def vs_type(self) -> str:
         return SupportedVSType.ES
@@ -65,7 +65,6 @@ class FaissKBService(KBService):
-        embed_func = get_Embeddings(self.embed_model)
-        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
+            embeddings = vs.embeddings.embed_query(query)
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs
@@ -1,24 +1,25 @@
 import base64
 
-from chatchat.server.utils import get_tool_config
-import streamlit as st
-from streamlit_antd_components.utils import ParseItems
-
-from chatchat.webui_pages.dialogue.utils import process_files
-# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
-#     get_select_model_endpoint
-from chatchat.webui_pages.utils import *
-from streamlit_chatbox import *
-from streamlit_modal import Modal
-from datetime import datetime
-import uuid
+import os
+import re
+import time
+from typing import List, Dict
+
+import streamlit as st
+from streamlit_antd_components.utils import ParseItems
+
+import openai
+from streamlit_chatbox import *
+from streamlit_modal import Modal
+from datetime import datetime
+
+from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
+from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
+from chatchat.server.utils import MsgType, get_config_models
+import uuid
+from typing import List, Dict
+from chatchat.server.utils import get_tool_config
+from chatchat.webui_pages.utils import *
+from chatchat.webui_pages.dialogue.utils import process_files
 
 
 img_dir = (Path(__file__).absolute().parent.parent.parent)
@@ -121,69 +122,66 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.write("\n\n".join(cmds))
 
     with st.sidebar:
-        conv_names = list(st.session_state["conversation_ids"].keys())
-        index = 0
+        tab1, tab2 = st.tabs(["对话设置", "模型设置"])
 
-        if st.session_state.get("cur_conv_name") in conv_names:
-            index = conv_names.index(st.session_state.get("cur_conv_name"))
-        conversation_name = st.selectbox("当前会话", conv_names, index=index)
-        chat_box.use_chat_name(conversation_name)
-        conversation_id = st.session_state["conversation_ids"][conversation_name]
+        with tab1:
+            use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
+            # 选择工具
+            tools = api.list_tools()
+            if use_agent:
+                selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+            else:
+                selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
+                selected_tools = [selected_tool]
+            selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
 
-        platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
-        platform = st.selectbox("选择模型平台", platforms)
-        llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
-        llm_model = st.selectbox("选择LLM模型", llm_models)
+            # 当不启用Agent时,手动生成工具参数
+            # TODO: 需要更精细的控制控件
+            tool_input = {}
+            if not use_agent and len(selected_tools) == 1:
+                with st.expander("工具参数", True):
+                    for k, v in tools[selected_tools[0]]["args"].items():
+                        if choices := v.get("choices", v.get("enum")):
+                            tool_input[k] = st.selectbox(v["title"], choices)
+                        else:
+                            if v["type"] == "integer":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"))
+                            elif v["type"] == "number":
+                                tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1)
+                            else:
+                                tool_input[k] = st.text_input(v["title"], v.get("default"))
 
-        # 传入后端的内容
-        chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
-        tool_use = True
-        for key in LLM_MODEL_CONFIG:
-            if key == 'llm_model':
-                continue
-            if key == 'action_model':
-                first_key = next(iter(LLM_MODEL_CONFIG[key]))
-                if first_key not in SUPPORT_AGENT_MODELS:
-                    st.warning("不支持Agent的模型,无法执行任何工具调用")
-                    tool_use = False
-                continue
-            if LLM_MODEL_CONFIG[key]:
-                first_key = next(iter(LLM_MODEL_CONFIG[key]))
-                chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
-
-        # 选择工具
-        selected_tool_configs = {}
-        if tool_use:
-            from chatchat.configs import model_config as model_config_py
-            import importlib
-            importlib.reload(model_config_py)
+            uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
+            files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
 
+        with tab2:
+            # 会话
+            conv_names = list(st.session_state["conversation_ids"].keys())
+            index = 0
+            if st.session_state.get("cur_conv_name") in conv_names:
+                index = conv_names.index(st.session_state.get("cur_conv_name"))
+            conversation_name = st.selectbox("当前会话", conv_names, index=index)
+            chat_box.use_chat_name(conversation_name)
+            conversation_id = st.session_state["conversation_ids"][conversation_name]
 
-        tools = get_tool_config()
-        with st.expander("工具栏"):
-            for tool in tools:
-                is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
-                if is_selected:
-                    selected_tool_configs[tool] = tools[tool]
+            # 模型
+            platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
+            platform = st.selectbox("选择模型平台", platforms)
+            llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
+            llm_model = st.selectbox("选择LLM模型", llm_models)
 
-        if llm_model is not None:
-            chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
+            # 传入后端的内容
+            chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
+            for key in LLM_MODEL_CONFIG:
+                if LLM_MODEL_CONFIG[key]:
+                    first_key = next(iter(LLM_MODEL_CONFIG[key]))
+                    chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
 
-        uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
-        files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
-        # print(len(files_upload["audios"])) if files_upload else None
+            if llm_model is not None:
+                chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
 
     # if dialogue_mode == "文件对话":
     #     with st.expander("文件对话配置", True):
     #         files = st.file_uploader("上传知识文件:",
     #                                  [i for ls in LOADER_DICT.values() for i in ls],
     #                                  accept_multiple_files=True,
     #                                  )
    #         kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
     #         score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
     #         if st.button("开始上传", disabled=len(files) == 0):
     #             st.session_state["file_chat_id"] = upload_temp_docs(files, api)
     # Display chat messages from history on app rerun
 
     chat_box.output_messages()
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
@@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
         chat_box.ai_say("正在思考...")
         text = ""
-        text_action = ""
-        element_index = 0
+        started = False
 
-        for d in api.chat_chat(query=prompt,
-                               metadata=files_upload,
-                               history=history,
-                               chat_model_config=chat_model_config,
-                               conversation_id=conversation_id,
-                               tool_config=selected_tool_configs,
-                               ):
-            message_id = d.get("message_id", "")
+        client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
+        messages = history + [{"role": "user", "content": prompt}]
+        tools = list(selected_tool_configs)
+        if len(selected_tools) == 1:
+            tool_choice = selected_tools[0]
+        else:
+            tool_choice = None
+        # 如果 tool_input 中有空的字段,设为用户输入
+        for k in tool_input:
+            if tool_input[k] in [None, ""]:
+                tool_input[k] = prompt
 
+        extra_body = dict(
+            metadata=files_upload,
+            chat_model_config=chat_model_config,
+            conversation_id=conversation_id,
+            tool_input=tool_input,
+        )
+        for d in client.chat.completions.create(
+            messages=messages,
+            model=llm_model,
+            stream=True,
+            tools=tools,
+            tool_choice=tool_choice,
+            extra_body=extra_body,
+        ):
+            print("\n\n", d.status, "\n", d, "\n\n")
+            message_id = d.message_id
             metadata = {
                 "message_id": message_id,
             }
-            print(d)
-            if d["status"] == AgentStatus.error:
-                st.error(d["text"])
-            elif d["status"] == AgentStatus.agent_action:
-                formatted_data = {
-                    "Function": d["tool_name"],
-                    "function_input": d["tool_input"]
-                }
-                element_index += 1
-                formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                chat_box.insert_msg(
-                    Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
-                text = """\n```{}\n```\n""".format(formatted_json)
-                chat_box.update_msg(Markdown(text), element_index=element_index)
-            elif d["status"] == AgentStatus.tool_end:
-                text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
-                chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
-            elif d["status"] == AgentStatus.llm_new_token:
-                text += d["text"]
-                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.llm_end:
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
-            elif d["status"] == AgentStatus.agent_finish:
-                if d["message_type"] == MsgType.IMAGE:
-                    for url in json.loads(d["text"]).get("images", []):
-                        url = f"{api.base_url}/media/{url}"
-                        chat_box.insert_msg(Image(url))
-                chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
 
+            if d.status == AgentStatus.error:
+                st.error(d.choices[0].delta.content)
+            elif d.status == AgentStatus.llm_start:
+                if not started:
+                    started = True
+                else:
+                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
+                    chat_box.insert_msg("正在解读工具输出结果...")
+                text = d.choices[0].delta.content or ""
+            elif d.status == AgentStatus.llm_new_token:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+            elif d.status == AgentStatus.llm_end:
+                text += d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
+            # tool 的输出与 llm 输出重复了
+            # elif d.status == AgentStatus.tool_start:
+            #     formatted_data = {
+            #         "Function": d.choices[0].delta.tool_calls[0].function.name,
+            #         "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
+            #     }
+            #     formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
+            #     text = """\n```{}\n```\n""".format(formatted_json)
+            #     chat_box.insert_msg( # TODO: insert text directly not shown
+            #         Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
+            # elif d.status == AgentStatus.tool_end:
+            #     tool_output = d.choices[0].delta.tool_calls[0].tool_output
+            #     if d.message_type == MsgType.IMAGE:
+            #         for url in json.loads(tool_output).get("images", []):
+            #             url = f"{api.base_url}/media/{url}"
+            #             chat_box.insert_msg(Image(url))
+            #         chat_box.update_msg(expanded=False, state="complete")
+            #     else:
+            #         text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
+            #         chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
+            elif d.status == AgentStatus.agent_finish:
+                text = d.choices[0].delta.content or ""
+                chat_box.update_msg(text.replace("\n", "\n\n"))
+            elif d.status == None: # not agent chat
+                if getattr(d, "is_ref", False):
+                    chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
+                    chat_box.insert_msg("")
+                else:
+                    text += d.choices[0].delta.content or ""
+                    chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
+        chat_box.update_msg(text, streaming=False, metadata=metadata)
 
         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file:
@@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         st.rerun()
 
     now = datetime.now()
-    with st.sidebar:
-
+    with tab1:
         cols = st.columns(2)
         export_btn = cols[0]
         if cols[1].button(
@@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             mime="text/markdown",
             use_container_width=True,
         )
+
+    # st.write(chat_box.history)
@@ -50,6 +50,14 @@ class ApiRequest:
                                        timeout=self.timeout)
         return self._client
 
+    def _check_url(self, url: str) -> str:
+        '''
+        新版 httpx 强制要求 url 以 / 结尾,否则会返回 307
+        '''
+        if not url.endswith("/"):
+            url = url + "/"
+        return url
+
     def get(
         self,
         url: str,
@@ -58,6 +66,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any,
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -79,6 +88,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 # print(kwargs)
@@ -101,6 +111,7 @@ class ApiRequest:
         stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
+        url = self._check_url(url)
         while retry > 0:
             try:
                 if stream:
@@ -638,6 +649,27 @@ class ApiRequest:
         resp = self.post("/chat/feedback", json=data)
         return self._get_response_value(resp)
 
+    def list_tools(self) -> Dict:
+        '''
+        列出所有工具
+        '''
+        resp = self.get("/tools")
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
+
+    def call_tool(
+        self,
+        name: str,
+        tool_input: Dict = {},
+    ):
+        '''
+        调用工具
+        '''
+        data = {
+            "name": name,
+            "tool_input": tool_input,
+        }
+        resp = self.post("/tools/call", json=data)
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+
 
 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):