Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-21 14:23:33 +08:00
* Fix the bug where Azure does not set max tokens
* Rewrite agent:
  1. Change the Agent implementation to support multiple parameters; only ChatGLM3-6b and OpenAI GPT4 remain supported, and the remaining models will temporarily lack Agent functionality
  2. Remove agent_chat and merge it into llm_chat
  3. Rewrite most tools to fit the new Agent
* Update the architecture
* Remove web_chat; it is merged in automatically
* Remove all standalone chats; everything is now Agent-controlled
* Update configuration files
* Update configuration templates and prompts
* Fix parameter-selection bug
163 lines
5.2 KiB
Python
from __future__ import annotations

import asyncio
import json
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, Union, cast
from uuid import UUID

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult


def dumps(obj: Dict) -> str:
    # Serialize a status event as JSON without escaping non-ASCII characters,
    # so Chinese text streams through unchanged.
    return json.dumps(obj, ensure_ascii=False)


class Status:
    start: int = 1
    running: int = 2
    complete: int = 3
    agent_action: int = 4
    agent_finish: int = 5
    error: int = 6
    tool_finish: int = 7

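
# Example event shapes pushed onto the queue (illustrative; the exact field
# set varies by callback):
#   {"status": 2, "llm_token": "Hello"}                         # streaming token
#   {"status": 4, "tool_name": "weather", "tool_input": "北京"}  # agent action
#   {"status": 6, "error": "division by zero"}                   # error event
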
class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
    def __init__(self):
        super().__init__()
        self.queue = asyncio.Queue()  # JSON-serialized status events for the consumer
        self.done = asyncio.Event()   # expected to be set by the caller once the run ends
        self.cur_tool = {}            # fields of the event currently being built
        self.out = True               # whether raw LLM tokens are forwarded to the queue

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        # Debug placeholder; no event is pushed to the queue for tool start.
        print("on_tool_start")

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        # Debug placeholder; no event is pushed to the queue for tool end.
        print("on_tool_end")

    async def on_tool_error(
        self,
        error: Exception | KeyboardInterrupt,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # "Action" (ReAct-style tool call) and ChatGLM3's "<|observation|>"
        # mark the start of agent syntax rather than answer text: flush the
        # text before the marker, then mute raw token output until the next
        # on_llm_end re-enables it.
        special_tokens = ["Action", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                before_action = token.split(stoken)[0]
                self.cur_tool.update(
                    status=Status.running,
                    llm_token=before_action + "\n",
                )
                self.queue.put_nowait(dumps(self.cur_tool))
                self.out = False
                break

        if token is not None and token != "" and self.out:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self.queue.put_nowait(dumps(self.cur_tool))

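    # For example, a chunk like "The answer is 4.\nAction" pushes an event with
    # llm_token="The answer is 4.\n\n" and sets self.out = False, so the agent
    # syntax that follows never reaches the client.
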
    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.complete,
            llm_token="",
        )
        self.out = True  # re-enable token forwarding muted by a special token
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.agent_action,
            tool_name=action.tool,
            tool_input=action.tool_input,
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.agent_finish,
            agent_finish=finish.return_values["output"],
        )
        self.queue.put_nowait(dumps(self.cur_tool))

    async def aiter(self) -> AsyncIterator[str]:
        # Stream queued events until the queue is drained and `done` is set.
        while not self.queue.empty() or not self.done.is_set():
            # Race a queue read against the done flag so the iterator can
            # terminate even when no further events arrive.
            done, pending = await asyncio.wait(
                [
                    asyncio.ensure_future(self.queue.get()),
                    asyncio.ensure_future(self.done.wait()),
                ],
                return_when=asyncio.FIRST_COMPLETED,
            )
            if pending:
                # Cancel the losing task so it neither leaks nor consumes an event.
                pending.pop().cancel()
            # done.wait() resolves to True; queue.get() yields a JSON string.
            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
            if token_or_done is True:
                break
            yield token_or_done
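
# A minimal consumption sketch, assuming a LangChain AgentExecutor. The names
# `wrap_done`, `agent_executor`, and `stream_agent` are illustrative, not part
# of this module; the key point is that the caller, not the handler, must set
# `callback.done` once the agent run finishes, or aiter() will never terminate.
async def wrap_done(fn, event: asyncio.Event):
    try:
        await fn
    finally:
        event.set()  # unblock aiter() once the run is over


async def stream_agent(agent_executor, query: str) -> AsyncIterator[str]:
    callback = CustomAsyncIteratorCallbackHandler()
    task = asyncio.create_task(
        wrap_done(agent_executor.acall(query, callbacks=[callback]), callback.done)
    )
    async for event_json in callback.aiter():
        yield event_json  # each item is a dumps()-serialized status event
    await task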