Improve tool definitions; add an OpenAI-compatible unified chat interface (#3570)

- Fixes:
    - Qwen Agent's OutputParser no longer raises an exception; non-CoT text is returned as-is
    - CallbackHandler now handles tool-call information correctly

- Reworked how tools are defined:
    - Added regist_tool to simplify tool definitions (see the sketch below):
        - a user-friendly title can be specified
        - the function's __doc__ is used as tool.description automatically
        - parameters are declared with Field; a separate ModelSchema is no longer needed
    - Added BaseToolOutput to wrap tool results, so both the raw value and the string sent to the LLM are available
    - Tools can be hot-reloaded (still to be tested)
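
A minimal sketch of the new style, mirroring the `calculate` tool in the diffs below (`regist_tool`, `Field`, and `BaseToolOutput` are the project helpers this commit introduces; the relative import assumes the file lives in the tools package):

```python
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput

@regist_tool(title="数学计算器")  # human-friendly title; defaults to the capitalized tool name
def calculate(text: str = Field(description="a math expression")):
    '''Useful to answer questions about simple calculations.'''  # becomes tool.description
    import numexpr
    try:
        ret = str(numexpr.evaluate(text))
    except Exception as e:
        ret = f"wrong: {e}"
    # str(output) is what the LLM sees; output.data keeps the raw value
    return BaseToolOutput(ret)
```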

- Added an OpenAI-compatible unified chat interface; different combinations of tools/tool_choice/extra_body select (see the usage sketch after this list):
    - agent conversation
    - a designated tool call (e.g. knowledge-base RAG)
    - plain LLM conversation
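
A usage sketch against the new endpoint, based on the webui changes below (the model name, question, and city are placeholders; the `api_address` import location is an assumption):

```python
import openai

from chatchat.webui_pages.utils import api_address  # assumption: helper exported here

# the unified endpoint lives under the /chat router and speaks the openai protocol
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")

for chunk in client.chat.completions.create(
    model="qwen",  # placeholder model name
    messages=[{"role": "user", "content": "厦门今天天气怎么样?"}],
    stream=True,
    tools=["weather_check"],      # plain tool names are expanded server-side
    tool_choice="weather_check",  # force this tool instead of letting the agent pick one
    extra_body={"tool_input": {"city": "厦门"}},  # with tool_input, the tool is invoked directly (RAG-style)
):
    print(chunk)
```

Passing tools without tool_choice runs an agent conversation; passing neither falls back to a plain LLM chat.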

- Updated the webui to match the backend changes
liunux4odoo committed 2024-03-29 11:55:32 +08:00 (via GitHub)
commit 42aa900566, parent 5e70aff522
23 changed files with 614 additions and 227 deletions


@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
s = s[-1]
return AgentFinish({"output": s}, log=text)
else:
raise OutputParserException(f"Could not parse LLM output: {text}")
return AgentFinish({"output": text}, log=text)
# raise OutputParserException(f"Could not parse LLM output: {text}")
@property
def _type(self) -> str:


@ -1,7 +1,7 @@
import base64
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
def save_base64_audio(base64_audio, file_path):
@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query):
return response
@regist_tool
def aqa_processor(query: str = Field(description="The question of the image in English")):
@regist_tool(title="音频问答")
def aqa_processor(query: str = Field(description="The question of the audio in English")):
'''use this tool to get answer for audio question'''
from chatchat.server.agent.container import container
@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E
"audio": file_path,
"text": query,
}
return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
else:
return "No Audio, Please Try Again"
ret = "No Audio, Please Try Again"
return BaseToolOutput(ret)


@ -1,12 +1,12 @@
# LangChain 的 ArxivQueryRun 工具
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="ARXIV论文")
def arxiv(query: str = Field(description="The search query title")):
'''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
from langchain.tools.arxiv.tool import ArxivQueryRun
tool = ArxivQueryRun()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))


@ -1,8 +1,8 @@
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="数学计算器")
def calculate(text: str = Field(description="a math expression")) -> float:
'''
Useful to answer questions about simple calculations.
@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float:
import numexpr
try:
return str(numexpr.evaluate(text))
ret = str(numexpr.evaluate(text))
except Exception as e:
return f"wrong: {e}"
ret = f"wrong: {e}"
return BaseToolOutput(ret)


@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
def bing_search(text, config):
@ -95,8 +95,8 @@ def search_engine(query: str,
return context
@regist_tool
@regist_tool(title="互联网搜索")
def search_internet(query: str = Field(description="query for Internet search")):
'''Use this tool to use bing search engine to search the internet and get information.'''
tool_config = get_tool_config("search_internet")
return search_engine(query=query, config=tool_config)
return BaseToolOutput(search_engine(query=query, config=tool_config))


@ -1,8 +1,9 @@
from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from chatchat.server.knowledge_base.kb_doc_api import search_docs
from .tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO
@ -11,35 +12,44 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
class KBToolOutput(BaseToolOutput):
def __str__(self) -> str:
context = ""
docs = self.data
source_documents = []
for inum, doc in enumerate(docs):
doc = DocumentWithVSId.parse_obj(doc)
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename})
url = f"download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
return context
def search_knowledgebase(query: str, database: str, config: dict):
docs = search_docs(
query=query,
knowledge_base_name=database,
top_k=config["top_k"],
score_threshold=config["score_threshold"])
context = ""
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
url = f"download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
return context
return docs
@regist_tool(description=template_knowledge)
@regist_tool(description=template_knowledge, title="本地知识库")
def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search"),
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
query: str = Field(description="Query for Knowledge Search"),
):
''''''
tool_config = get_tool_config("search_local_knowledgebase")
return search_knowledgebase(query=query, database=database, config=tool_config)
ret = search_knowledgebase(query=query, database=database, config=tool_config)
return BaseToolOutput(ret, database=database)


@ -1,10 +1,10 @@
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="油管视频")
def search_youtube(query: str = Field(description="Query for Videos search")):
'''use this tools_factory to search youtube videos'''
from langchain_community.tools import YouTubeSearchTool
tool = YouTubeSearchTool()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))


@ -2,11 +2,11 @@
from langchain.tools.shell import ShellTool
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="系统命令")
def shell(query: str = Field(description="The command to execute")):
'''Use Shell to execute system shell commands'''
tool = ShellTool()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))


@ -7,7 +7,7 @@ import uuid
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import openai
from chatchat.configs.basic_config import MEDIA_PATH
@ -26,7 +26,7 @@ def get_image_model_config() -> dict:
return config
@regist_tool(return_direct=True)
@regist_tool(title="文生图", return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
@ -56,7 +56,7 @@ def text2images(
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
fp.write(base64.b64decode(x.b64_json))
images.append(filename)
return json.dumps({"message_type": MsgType.IMAGE, "images": images})
return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")
if __name__ == "__main__":


@ -1,15 +1,22 @@
import json
import re
from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
from langchain.agents import tool
from langchain_core.tools import BaseTool
from chatchat.server.pydantic_v1 import BaseModel
from chatchat.server.pydantic_v1 import BaseModel, Extra
__all__ = ["regist_tool", "BaseToolOutput"]
_TOOLS_REGISTRY = {}
# patch BaseTool to support extra fields e.g. a title
BaseTool.Config.extra = Extra.allow
################################### TODO: workaround to langchain #15855
# patch BaseTool to support tool parameters defined using pydantic Field
@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
def regist_tool(
*args: Any,
title: str = "",
description: str = "",
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
@ -70,9 +78,10 @@ def regist_tool(
add tool to registry automatically
'''
def _parse_tool(t: BaseTool):
nonlocal description
nonlocal description, title
_TOOLS_REGISTRY[t.name] = t
# change default description
if not description:
if t.func is not None:
@ -80,6 +89,10 @@ def regist_tool(
elif t.coroutine is not None:
description = t.coroutine.__doc__
t.description = " ".join(re.split(r"\n+\s*", description))
# set a default human-readable title
if not title:
title = "".join([x.capitalize() for x in t.name.split("_")])
t.title = title
def wrapper(def_func: Callable) -> BaseTool:
partial_ = tool(*args,
@ -101,3 +114,30 @@ def regist_tool(
)
_parse_tool(t)
return t
class BaseToolOutput:
'''
LLM 要求 Tool 的输出为 str,但 Tool 用在别处时,希望它正常返回结构化数据。
只需要将 Tool 返回值用该类封装,即可同时满足两者的需要。
基类简单地将返回值字符串化,或指定 format="json" 将其转为 json。
用户也可以继承该类,定义自己的转换方法。
'''
def __init__(
self,
data: Any,
format: str="",
data_alias: str="",
**extras: Any,
) -> None:
self.data = data
self.format = format
self.extras = extras
if data_alias:
setattr(self, data_alias, property(lambda obj: obj.data))
def __str__(self) -> str:
if self.format == "json":
return json.dumps(self.data, ensure_ascii=False, indent=2)
else:
return str(self.data)
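
For illustration, the class above lets a caller get both representations (hypothetical values):

```python
out = BaseToolOutput({"temperature": "22", "description": "晴"}, format="json")
out.data  # the structured value, for use by other code
str(out)  # the JSON string handed to the LLM
```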


@ -6,7 +6,7 @@ from io import BytesIO
from PIL import Image, ImageDraw
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import re
from chatchat.server.agent.container import container
@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
return response
@regist_tool
@regist_tool(title="图片对话")
def vqa_processor(query: str = Field(description="The question of the image in English")):
'''use this tool to get answer for image question'''
@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E
# end_marker = "Grounded Operation:"
# ans = extract_between_markers(ans, start_marker, end_marker)
return ans
ret = ans
else:
return "No Image, Please Try Again"
ret = "No Image, Please Try Again"
return BaseToolOutput(ret)


@ -3,11 +3,11 @@
"""
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import requests
@regist_tool
@regist_tool(title="天气查询")
def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
'''Use this tool to check the weather at a specific city'''
@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
return BaseToolOutput(weather)
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")


@ -2,7 +2,7 @@
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
ans = wolfram.run(query)
return ans
return BaseToolOutput(ans)


@ -1,10 +1,10 @@
from __future__ import annotations
import re
import json
import time
from typing import Dict, List, Literal, Optional, Union
from fastapi import UploadFile
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
@ -12,8 +12,10 @@ from openai.types.chat import (
completion_create_params,
)
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
from chatchat.server.utils import MsgType
class OpenAIBaseInput(BaseModel):
user: Optional[str] = None
@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel):
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict] = None
extra_query: Optional[Dict] = None
extra_body: Optional[Dict] = None
extra_json: Optional[Dict] = Field(None, alias="extra_body")
timeout: Optional[float] = None
class Config:
@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput):
stop: Union[Optional[str], List[str]] = None
stream: Optional[bool] = None
temperature: Optional[float] = TEMPERATURE
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
tools: List[ChatCompletionToolParam] = None
tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
tools: List[Union[ChatCompletionToolParam, str]] = None
top_logprobs: Optional[int] = None
top_p: Optional[float] = None
@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
voice: str
response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
speed: Optional[float] = None
class OpenAIBaseOutput(BaseModel):
id: Optional[str] = None
content: Optional[str] = None
model: Optional[str] = None
object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk"
role: Literal["assistant"] = "assistant"
finish_reason: Optional[str] = None
created: int = Field(default_factory=lambda : int(time.time()))
tool_calls: List[Dict] = []
status: Optional[int] = None # AgentStatus
message_type: int = MsgType.TEXT
message_id: Optional[str] = None # id in database table
is_ref: bool = False # whether to show in a separate expander
class Config:
extra = "allow"
def model_dump(self) -> dict:
result = {
"id": self.id,
"object": self.object,
"model": self.model,
"created": self.created,
"status": self.status,
"message_type": self.message_type,
"message_id": self.message_id,
"is_ref": self.is_ref,
**(self.model_extra or {}),
}
if self.object == "chat.completion.chunk":
result["choices"] = [{
"delta": {
"content": self.content,
"tool_calls": self.tool_calls,
},
"role": self.role,
}]
elif self.object == "chat.completion":
result["choices"] = [{
"message": {
"role": self.role,
"content": self.content,
"finish_reason": self.finish_reason,
"tool_calls": self.tool_calls,
}
}]
return result
def model_dump_json(self):
return json.dumps(self.model_dump(), ensure_ascii=False)
class OpenAIChatOutput(OpenAIBaseOutput):
...
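
For reference, a streamed chunk built with the class above serializes roughly like this (a sketch; created is a Unix timestamp):

```python
OpenAIChatOutput(content="你好", object="chat.completion.chunk").model_dump()
# {"id": None, "object": "chat.completion.chunk", "model": None, "created": ...,
#  "status": None, "message_type": MsgType.TEXT, "message_id": None, "is_ref": False,
#  "choices": [{"delta": {"content": "你好", "tool_calls": []}, "role": "assistant"}]}
```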


@ -1,12 +1,17 @@
from __future__ import annotations
from typing import List
from typing import List, Dict
from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate
from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
from chatchat.server.chat.chat import chat
from chatchat.server.chat.feedback import chat_feedback
from chatchat.server.chat.file_chat import file_chat
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template
from .openai_routes import openai_request
chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
@ -21,4 +26,115 @@ chat_router.post("/feedback",
chat_router.post("/file_chat",
summary="文件对话"
)(file_chat)
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions(
request: Request,
body: OpenAIChatInput,
) -> Dict:
'''
请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数。
tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换。
通过不同的参数组合调用不同的 chat 功能:
- tool_choice:
  - extra_body 中包含 tool_input:直接调用 tool_choice(tool_input)
  - extra_body 中不包含 tool_input:通过 agent 调用 tool_choice
- tools:agent 对话
- 其它:LLM 对话
以后还要考虑其它的组合,如文件对话。
返回与 openai 兼容的 Dict。
'''
client = get_OpenAIClient(model_name=body.model, is_async=True)
extra = {**body.model_extra} or {}
for key in list(extra):
delattr(body, key)
# check tools & tool_choice in request body
if isinstance(body.tool_choice, str):
if t := get_tool(body.tool_choice):
body.tool_choice = {"function": {"name": t.name}, "type": "function"}
if isinstance(body.tools, list):
for i in range(len(body.tools)):
if isinstance(body.tools[i], str):
if t := get_tool(body.tools[i]):
body.tools[i] = {
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.args,
}
}
conversation_id = extra.get("conversation_id")
# chat based on result from one choiced tool
if body.tool_choice:
tool = get_tool(body.tool_choice["function"]["name"])
if not body.tools:
body.tools = [{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.args,
}
}]
if tool_input := extra.get("tool_input"):
message_id = add_message_to_db(
chat_type="tool_call",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
tool_result = await tool.ainvoke(tool_input)
prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"])
del body.tools
del body.tool_choice
extra_json = {
"message_id": message_id,
"status": None,
}
header = [{**extra_json,
"content": f"知识库参考资料:\n\n{tool_result}\n\n",
"tool_output":tool_result.data,
"is_ref": True,
}]
return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header)
# agent chat with tool calls
if body.tools:
message_id = add_message_to_db(
chat_type="agent_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
chat_model_config = {} # TODO: 前端支持配置模型
tool_names = [x["function"]["name"] for x in body.tools]
tool_config = {name: get_tool_config(name) for name in tool_names}
result = await chat(query=body.messages[-1]["content"],
metadata=extra.get("metadata", {}),
conversation_id=extra.get("conversation_id", ""),
message_id=message_id,
history_len=-1,
history=body.messages[:-1],
stream=body.stream,
chat_model_config=extra.get("chat_model_config", chat_model_config),
tool_config=extra.get("tool_config", tool_config),
)
return result
else: # LLM chat directly
message_id = add_message_to_db(
chat_type="llm_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
extra_json = {
"message_id": message_id,
"status": None,
}
return await openai_request(client.chat.completions.create, body, extra_json=extra_json)


@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple, AsyncGenerator
from typing import Dict, Tuple, AsyncGenerator, Iterable
from fastapi import APIRouter, Request
from openai import AsyncClient
@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]
@asynccontextmanager
async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
'''
对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
'''
@ -49,19 +49,47 @@ async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
semaphore.release()
async def openai_request(method, body):
async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]):
'''
helper function to make openai request
helper function to make openai request with extra fields
'''
async def generator():
async for chunk in await method(**params):
yield {"data": chunk.json()}
for x in header:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {header}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()
async for chunk in await method(**params):
for k, v in extra_json.items():
setattr(chunk, k, v)
yield chunk.model_dump_json()
for x in tail:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {tail}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()
params = body.model_dump(exclude_unset=True)
params = body.dict(exclude_unset=True)
if hasattr(body, "stream") and body.stream:
return EventSourceResponse(generator())
else:
return (await method(**params)).dict()
result = await method(**params)
for k, v in extra_json.items():
setattr(result, k, v)
return result.model_dump()
@openai_router.get("/models")
@ -74,12 +102,12 @@ async def list_models() -> List:
client = get_OpenAIClient(name, is_async=True)
models = await client.models.list()
if config.get("platform_type") == "xinference":
models = models.dict(exclude={"data":..., "object":...})
models = models.model_dump(exclude={"data":..., "object":...})
for x in models:
models[x]["platform_name"] = name
return [{**v, "id": k} for k, v in models.items()]
elif config.get("platform_type") == "oneapi":
return [{**x.dict(), "platform_name": name} for x in models.data]
return [{**x.model_dump(), "platform_name": name} for x in models.data]
except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True)
return {}
@ -97,12 +125,8 @@ async def create_chat_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
result = await openai_request(client.chat.completions.create, body)
# result["related_docs"] = ["doc1"]
# result["choices"][0]["message"]["related_docs"] = ["doc1"]
# print(result)
# breakpoint()
return result
@ -111,7 +135,7 @@ async def create_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.completions.create, body)
@ -130,7 +154,7 @@ async def create_image_generations(
request: Request,
body: OpenAIImageGenerationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.generate, body)
@ -139,7 +163,7 @@ async def create_image_variations(
request: Request,
body: OpenAIImageVariationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.create_variation, body)
@ -148,7 +172,7 @@ async def create_image_edit(
request: Request,
body: OpenAIImageEditsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.edit, body)
@ -157,7 +181,7 @@ async def create_audio_translations(
request: Request,
body: OpenAIAudioTranslationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.translations.create, body)
@ -166,7 +190,7 @@ async def create_audio_transcriptions(
request: Request,
body: OpenAIAudioTranscriptionsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.transcriptions.create, body)
@ -175,7 +199,7 @@ async def create_audio_speech(
request: Request,
body: OpenAIAudioSpeechInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.speech.create, body)


@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Request, Body
from chatchat.configs import logger
from chatchat.server.utils import BaseResponse, get_tool
from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse)
async def list_tools():
tools = get_tool()
data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
data = {t.name: {
"name": t.name,
"title": t.title,
"description": t.description,
"args": t.args,
"config": get_tool_config(t.name),
} for t in tools.values()}
return {"data": data}
@tool_router.post("/call", response_model=BaseResponse)
async def call_tool(
name: str = Body(examples=["calculate"]),
kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]),
):
tools = get_tool()
if tool := tools.get(name):
if tool := get_tool(name):
try:
result = await tool.ainvoke(kwargs)
result = await tool.ainvoke(tool_input)
return {"data": result}
except Exception:
msg = f"failed to call tool '{name}'"


@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
pass
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_start,
"tool": serialized["name"],
"tool_input": input_str,
}
self.queue.put_nowait(dumps(data))
async def on_tool_end(
@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
) -> None:
"""Run when tool ends running."""
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_end,
"tool_output": output,
}
@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
) -> None:
"""Run when tool errors."""
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_end,
"text": str(error),
"tool_output": str(error),
"is_error": True,
}
# self.done.clear()
self.queue.put_nowait(dumps(data))


@ -1,6 +1,8 @@
import asyncio
import json
import time
from typing import AsyncIterable, List
import uuid
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from chatchat.configs.model_config import LLM_MODEL_CONFIG
from chatchat.server.agent.agent_factory.agents_registry import agents_registry
from chatchat.server.agent.container import container
from chatchat.server.api_server.api_schemas import OpenAIChatOutput
from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
from chatchat.server.chat.utils import History
from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
def create_models_from_config(configs, callbacks, stream):
if configs is None:
configs = {}
configs = configs or LLM_MODEL_CONFIG
models = {}
prompts = {}
for model_type, model_configs in configs.items():
@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"),
message_id: str = Body(None, description="数据库消息ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: List[History] = Body(
[],
@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
tool_config: dict = Body({}, description="工具配置", examples=[]),
):
async def chat_iterator() -> AsyncIterable[str]:
message_id = add_message_to_db(
chat_type="llm_chat",
query=query,
conversation_id=conversation_id
) if conversation_id else None
'''Agent 对话'''
async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
@ -145,9 +144,32 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
}
), callback.done))
last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == AgentStatus.tool_end:
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT
if data["status"] == AgentStatus.tool_start:
last_tool = {
"index": 0,
"id": data["run_id"],
"type": "function",
"function": {
"name": data["tool"],
"arguments": data["tool_input"],
},
"tool_output": None,
"is_error": False,
}
data["tool_calls"].append(last_tool)
if data["status"] in [AgentStatus.tool_end]:
last_tool.update(
tool_output=data["tool_output"],
is_error=data.get("is_error", False)
)
data["tool_calls"] = [last_tool]
last_tool = {}
try:
tool_output = json.loads(data["tool_output"])
if message_type := tool_output.get("message_type"):
@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
data["message_type"] = message_type
except:
...
data.setdefault("message_type", MsgType.TEXT)
data["message_id"] = message_id
yield json.dumps(data, ensure_ascii=False)
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion.chunk",
content=data.get("text", ""),
role="assistant",
tool_calls=data["tool_calls"],
model=models["llm_model"].model_name,
status = data["status"],
message_type = data["message_type"],
message_id=message_id,
)
yield ret.model_dump_json()
await task
return EventSourceResponse(chat_iterator())
if stream:
return EventSourceResponse(chat_iterator())
else:
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion",
content="",
role="assistant",
finish_reason="stop",
tool_calls=[],
status = AgentStatus.agent_finish,
message_type = MsgType.TEXT,
message_id=message_id,
)
async for chunk in chat_iterator():
data = json.loads(chunk)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
ret.model = data["model"]
ret.created = data["created"]
return ret.model_dump()


@ -96,11 +96,7 @@ class ESKBService(KBService):
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
def do_create_kb(self):
if os.path.exists(self.doc_path):
if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
os.makedirs(os.path.join(self.kb_path, "vector_store"))
else:
logger.warning("directory `vector_store` already exists.")
...
def vs_type(self) -> str:
return SupportedVSType.ES


@ -65,7 +65,6 @@ class FaissKBService(KBService):
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
embeddings = vs.embeddings.embed_query(query)
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
return docs


@ -1,24 +1,25 @@
import base64
from chatchat.server.utils import get_tool_config
import streamlit as st
from streamlit_antd_components.utils import ParseItems
from chatchat.webui_pages.dialogue.utils import process_files
# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
# get_select_model_endpoint
from chatchat.webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
import uuid
import os
import re
import time
from typing import List, Dict
import streamlit as st
from streamlit_antd_components.utils import ParseItems
import openai
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.utils import MsgType, get_config_models
import uuid
from typing import List, Dict
from chatchat.server.utils import get_tool_config
from chatchat.webui_pages.utils import *
from chatchat.webui_pages.dialogue.utils import process_files
img_dir = (Path(__file__).absolute().parent.parent.parent)
@ -121,69 +122,66 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.write("\n\n".join(cmds))
with st.sidebar:
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
tab1, tab2 = st.tabs(["对话设置", "模型设置"])
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
with tab1:
use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
# 选择工具
tools = api.list_tools()
if use_agent:
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
else:
selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
selected_tools = [selected_tool]
selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms)
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = st.selectbox("选择LLM模型", llm_models)
# 当不启用Agent时手动生成工具参数
# TODO: 需要更精细的控制控件
tool_input = {}
if not use_agent and len(selected_tools) == 1:
with st.expander("工具参数", True):
for k, v in tools[selected_tools[0]]["args"].items():
if choices := v.get("choices", v.get("enum")):
tool_input[k] = st.selectbox(v["title"], choices)
else:
if v["type"] == "integer":
tool_input[k] = st.slider(v["title"], value=v.get("default"))
elif v["type"] == "number":
tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1)
else:
tool_input[k] = st.text_input(v["title"], v.get("default"))
# 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
tool_use = True
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if key == 'action_model':
first_key = next(iter(LLM_MODEL_CONFIG[key]))
if first_key not in SUPPORT_AGENT_MODELS:
st.warning("不支持Agent的模型无法执行任何工具调用")
tool_use = False
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
# 选择工具
selected_tool_configs = {}
if tool_use:
from chatchat.configs import model_config as model_config_py
import importlib
importlib.reload(model_config_py)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
with tab2:
# 会话
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
tools = get_tool_config()
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = tools[tool]
# 模型
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms)
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = st.selectbox("选择LLM模型", llm_models)
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
for key in LLM_MODEL_CONFIG:
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
# print(len(files_upload["audios"])) if files_upload else None
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
# files = st.file_uploader("上传知识文件:",
# [i for ls in LOADER_DICT.values() for i in ls],
# accept_multiple_files=True,
# )
# kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
# score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
# if st.button("开始上传", disabled=len(files) == 0):
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.ai_say("正在思考...")
text = ""
text_action = ""
element_index = 0
started = False
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_config=selected_tool_configs,
):
message_id = d.get("message_id", "")
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
messages = history + [{"role": "user", "content": prompt}]
tools = list(selected_tool_configs)
if len(selected_tools) == 1:
tool_choice = selected_tools[0]
else:
tool_choice = None
# 如果 tool_input 中有空的字段,设为用户输入
for k in tool_input:
if tool_input[k] in [None, ""]:
tool_input[k] = prompt
extra_body = dict(
metadata=files_upload,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_input = tool_input,
)
for d in client.chat.completions.create(
messages=messages,
model=llm_model,
stream=True,
tools=tools,
tool_choice=tool_choice,
extra_body=extra_body,
):
print("\n\n", d.status, "\n", d, "\n\n")
message_id = d.message_id
metadata = {
"message_id": message_id,
}
print(d)
if d["status"] == AgentStatus.error:
st.error(d["text"])
elif d["status"] == AgentStatus.agent_action:
formatted_data = {
"Function": d["tool_name"],
"function_input": d["tool_input"]
}
element_index += 1
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
chat_box.insert_msg(
Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
text = """\n```{}\n```\n""".format(formatted_json)
chat_box.update_msg(Markdown(text), element_index=element_index)
elif d["status"] == AgentStatus.tool_end:
text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
elif d["status"] == AgentStatus.llm_new_token:
text += d["text"]
chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.llm_end:
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.agent_finish:
if d["message_type"] == MsgType.IMAGE:
for url in json.loads(d["text"]).get("images", []):
url = f"{api.base_url}/media/{url}"
chat_box.insert_msg(Image(url))
chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
if d.status == AgentStatus.error:
st.error(d.choices[0].delta.content)
elif d.status == AgentStatus.llm_start:
if not started:
started = True
else:
chat_box.insert_msg(Markdown(d["text"], expanded=True))
chat_box.insert_msg("正在解读工具输出结果...")
text = d.choices[0].delta.content or ""
elif d.status == AgentStatus.llm_new_token:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
elif d.status == AgentStatus.llm_end:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
# tool 的输出与 llm 输出重复了
# elif d.status == AgentStatus.tool_start:
# formatted_data = {
# "Function": d.choices[0].delta.tool_calls[0].function.name,
# "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
# }
# formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
# text = """\n```{}\n```\n""".format(formatted_json)
# chat_box.insert_msg( # TODO: insert text directly not shown
# Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
# elif d.status == AgentStatus.tool_end:
# tool_output = d.choices[0].delta.tool_calls[0].tool_output
# if d.message_type == MsgType.IMAGE:
# for url in json.loads(tool_output).get("images", []):
# url = f"{api.base_url}/media/{url}"
# chat_box.insert_msg(Image(url))
# chat_box.update_msg(expanded=False, state="complete")
# else:
# text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
# chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat
if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
chat_box.insert_msg("")
else:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
chat_box.update_msg(text, streaming=False, metadata=metadata)
if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file:
@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.rerun()
now = datetime.now()
with st.sidebar:
with tab1:
cols = st.columns(2)
export_btn = cols[0]
if cols[1].button(
@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
mime="text/markdown",
use_container_width=True,
)
# st.write(chat_box.history)


@ -50,6 +50,14 @@ class ApiRequest:
timeout=self.timeout)
return self._client
def _check_url(self, url: str) -> str:
'''
新版 httpx 强制要求 url 以 / 结尾,否则会返回 307
'''
if not url.endswith("/"):
url = url + "/"
return url
def get(
self,
url: str,
@ -58,6 +66,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream:
@ -79,6 +88,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
# print(kwargs)
@ -101,6 +111,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream:
@ -638,6 +649,27 @@ class ApiRequest:
resp = self.post("/chat/feedback", json=data)
return self._get_response_value(resp)
def list_tools(self) -> Dict:
'''
列出所有工具
'''
resp = self.get("/tools")
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
def call_tool(
self,
name: str,
tool_input: Dict = {},
):
'''
调用工具
'''
data = {
"name": name,
"tool_input": tool_input,
}
resp = self.post("/tools/call", json=data)
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
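
For illustration, the new ApiRequest helpers above can be used like this (a sketch; the import location is an assumption, and the tool name and input mirror the Body examples in the tool router):

```python
from chatchat.webui_pages.utils import ApiRequest  # assumption: same module as the diff above

api = ApiRequest()
tools = api.list_tools()  # {"calculate": {"name": ..., "title": ..., "args": ..., "config": ...}, ...}
result = api.call_tool("calculate", tool_input={"text": "3+5/2"})
```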