Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git

fix callback handler

commit e2a46a1d0f (parent 6f04e15aed)
@@ -19,7 +19,7 @@ from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
 from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
@@ -169,7 +169,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm: BaseLanguageModel,
             tools: Sequence[BaseTool],
             prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
+            callbacks: List[BaseCallbackHandler] = [],
             output_parser: Optional[AgentOutputParser] = None,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             input_variables: Optional[List[str]] = None,
@@ -187,7 +187,7 @@ class StructuredGLM3ChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
             verbose=True
         )
         tool_names = [tool.name for tool in tools]
@@ -208,6 +208,7 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
        llm: BaseLanguageModel,
         prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -216,6 +217,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
+    llm.callbacks = callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
@@ -225,6 +227,7 @@ def initialize_glm3_agent(
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
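Note on the new signature: callbacks: List[BaseCallbackHandler] = [] replaces the old callback_manager parameter, but it introduces a mutable default argument, which Python creates once at definition time and shares across calls. A minimal self-contained sketch of the pitfall (plain Python, illustrative, not part of the commit); Optional[List[...]] = None avoids it:

from typing import List

def register(handler: str, handlers: List[str] = []) -> List[str]:
    # the default list is created once, when the function is defined
    handlers.append(handler)
    return handlers

print(register("a"))  # ['a']
print(register("b"))  # ['a', 'b']  <- state leaked from the first call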
@@ -14,7 +14,7 @@ from langchain.pydantic_v1 import Field
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
@@ -107,16 +107,16 @@ class QwenChatAgent(LLMSingleActionAgent):
 
     @classmethod
     def from_llm_and_tools(
             cls,
             llm: BaseLanguageModel,
             tools: Sequence[BaseTool],
             prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
+            callbacks: List[BaseCallbackHandler] = [],
             output_parser: Optional[AgentOutputParser] = None,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             input_variables: Optional[List[str]] = None,
             memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
             **kwargs: Any,
     ) -> QwenChatAgent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
@@ -129,7 +129,7 @@ class QwenChatAgent(LLMSingleActionAgent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
         )
         tool_names = [tool.name for tool in tools]
         output_parser = output_parser or QwenChatAgentOutputParser()
@@ -147,34 +147,33 @@ class QwenChatAgent(LLMSingleActionAgent):
 
 
 def initialize_qwen_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
+        callbacks: List[BaseCallbackHandler] = [],
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
         return_direct: Optional[bool] = None,
         tags: Optional[Sequence[str]] = None,
         **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
 
     if isinstance(return_direct, bool):  # can make all tools return directly
         tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
+    llm.callbacks = callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager,
-        **agent_kwargs
+        **agent_kwargs,
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
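Both agent factories now assign llm.callbacks = callbacks before building the chain. In LangChain, handlers attached to the model object (constructor or attribute scope) fire on every run of that object, while per-call callbacks apply only to the invocation they are passed with; attaching to the llm keeps on_llm_new_token flowing to the handler even when an inner chain does not forward per-call callbacks. A hedged sketch of the two scopes (the handler class here is illustrative, not from the commit):

from langchain.callbacks.base import BaseCallbackHandler

class PrintTokens(BaseCallbackHandler):
    # invoked for each streamed token when the model runs with streaming=True
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)

handler = PrintTokens()
# attribute scope, as used in this commit: fires on every run of this llm
#   llm.callbacks = [handler]
# per-call scope: fires only for this single invocation
#   llm.invoke("hi", config={"callbacks": [handler]})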
@@ -13,83 +13,50 @@ def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)
 
 
-class Status:
-    start: int = 1
-    running: int = 2
-    complete: int = 3
+class AgentStatus:
+    llm_start: int = 1
+    llm_new_token: int = 2
+    llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
     error: int = 6
-    tool_finish: int = 7
 
 
-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
         self.done = asyncio.Event()
-        self.cur_tool = {}
         self.out = True
 
-    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
-                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        self.cur_tool = {
-            "tool_name": serialized["name"],
-            "input_str": input_str,
-            "output_str": "",
-            "status": Status.agent_action,
-            "run_id": run_id.hex,
-            "llm_token": "",
-            "final_answer": "",
-            "error": "",
-        }
-        self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
-                          tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.out = True
-        self.cur_tool.update(
-            status=Status.tool_finish,
-            output_str=output.replace("Answer:", ""),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
-                self.cur_tool.update(
-                    status=Status.running,
-                    llm_token=before_action + "\n",
-                )
-                self.queue.put_nowait(dumps(self.cur_tool))
+                data = {
+                    "status": AgentStatus.llm_new_token,
+                    "text": before_action + "\n",
+                }
+                self.queue.put_nowait(dumps(data))
                 self.out = False
                 break
 
         if token is not None and token != "" and self.out:
-            self.cur_tool.update(
-                status=Status.running,
-                llm_token=token,
-            )
-            self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+            data = {
+                "status": AgentStatus.llm_new_token,
+                "text": token,
+            }
+            self.queue.put_nowait(dumps(data))
 
     async def on_chat_model_start(
             self,
@@ -102,26 +69,27 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             metadata: Optional[Dict[str, Any]] = None,
             **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
 
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.complete,
-            llm_token="",
-        )
+        data = {
+            "status": AgentStatus.llm_end,
+            "text": response.generations[0][0].message.content,
+        }
         self.out = True
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.queue.put_nowait(dumps(data))
 
     async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.error,
+            "text": str(error),
+        }
+        self.queue.put_nowait(dumps(data))
 
     async def on_agent_action(
             self,
@@ -132,12 +100,13 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             tags: Optional[List[str]] = None,
             **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.agent_action,
-            tool_name=action.tool,
-            tool_input=action.tool_input,
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.agent_action,
+            "tool_name": action.tool,
+            "tool_input": action.tool_input,
+            "text": action.log,
+        }
+        self.queue.put_nowait(dumps(data))
 
     async def on_agent_finish(
             self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
@@ -147,8 +116,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
 
-        self.cur_tool.update(
-            status=Status.agent_finish,
-            agent_finish=finish.return_values["output"],
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.agent_finish,
+            "text": finish.return_values["output"],
+        }
+        self.done.set()
+        self.queue.put_nowait(dumps(data))
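The reworked handler emits flat JSON events ({"status": ..., "text": ...}) instead of repeatedly mutating a shared cur_tool dict, clears done when an LLM run starts, and sets it on agent_finish so consumers know the stream is over. A minimal sketch of the queue-plus-event pattern it builds on (the class and its aiter loop are illustrative, not the actual AsyncIteratorCallbackHandler internals):

import asyncio
import json

class MiniHandler:
    # callbacks put JSON strings on a queue; a consumer drains it until done
    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.done = asyncio.Event()

    def emit(self, status: int, text: str) -> None:
        self.queue.put_nowait(json.dumps({"status": status, "text": text}, ensure_ascii=False))

    async def aiter(self):
        # yield queued events until done is set and the queue is drained
        while not (self.done.is_set() and self.queue.empty()):
            try:
                yield await asyncio.wait_for(self.queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue

async def demo():
    h = MiniHandler()
    h.emit(1, "llm_start")   # cf. AgentStatus.llm_start
    h.emit(2, "token")       # cf. AgentStatus.llm_new_token
    h.done.set()             # cf. on_agent_finish
    async for chunk in h.aiter():
        print(json.loads(chunk))

asyncio.run(demo())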
@@ -20,22 +20,23 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
 
 
-def create_models_from_config(configs, callbacks):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
         for model_name, params in model_configs.items():
-            callback = callbacks if params.get('callbacks', False) else None
+            callbacks = callbacks if params.get('callbacks', False) else None
             model_instance = get_ChatOpenAI(
                 model_name=model_name,
                 temperature=params.get('temperature', 0.5),
                 max_tokens=params.get('max_tokens', 1000),
-                callbacks=callback
+                callbacks=callbacks,
+                streaming=stream,
             )
             models[model_type] = model_instance
             prompt_name = params.get('prompt_name', 'default')
@@ -83,7 +84,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     elif "qwen" in models["action_model"].model_name.lower():
@@ -92,7 +93,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     else:
@@ -131,7 +132,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                       ]
                   ]
               ),
-              stream: bool = Body(False, description="流式输出"),
+              stream: bool = Body(True, description="流式输出"),
               model_config: Dict = Body({}, description="LLM 模型配置"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
@@ -142,11 +143,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         conversation_id=conversation_id
     ) if conversation_id else None
 
-    callback = CustomAsyncIteratorCallbackHandler()
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
     callbacks = [callback]
 
     # select the models from the config
-    models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+    models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
 
     # select the tools from the config
     tools = [tool for tool in all_tools if tool.name in tool_config]
@@ -164,56 +165,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         # Execute Chain
 
         task = asyncio.create_task(
-            wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
-        if stream:
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.start:
-                    continue
-                elif data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
-                elif data["status"] == Status.agent_finish:
-                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
-                                     ensure_ascii=False)
-                else:
-                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
-        else:
-            text = ""
-            agent_finish = ""
-            tool_info = None
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                if data["status"] == Status.agent_finish:
-                    agent_finish = data["agent_finish"]
-                else:
-                    text += data["llm_token"]
-            if tool_info:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "agent_action": tool_info,
-                        "agent_finish": agent_finish,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
-            else:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
+            wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+        async for chunk in callback.aiter():
+            data = json.loads(chunk)
+            data["message_id"] = message_id
+            yield json.dumps(data, ensure_ascii=False)
         await task
 
     return EventSourceResponse(chat_iterator())
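With the simplified chat_iterator, every callback event is forwarded as-is (plus a message_id) through EventSourceResponse, so clients dispatch on the status field instead of per-key payloads. A hypothetical client-side sketch; the URL, port, and request payload are assumptions, and the status constants mirror the AgentStatus values defined above:

import json
import requests

AGENT_ACTION, AGENT_FINISH = 4, 5  # mirror AgentStatus.agent_action / .agent_finish

resp = requests.post("http://127.0.0.1:7861/chat/agent_chat",  # hypothetical endpoint
                     json={"query": "hello", "stream": True},
                     stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue  # skip SSE comments, event: fields, and keep-alives
    data = json.loads(line[len("data:"):].strip())
    if data["status"] == AGENT_ACTION:
        print("tool call:", data["tool_name"], data["tool_input"])
    elif data["status"] == AGENT_FINISH:
        print("final:", data["text"])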
@@ -2,19 +2,31 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))
 
+import asyncio
 from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 from langchain import globals
 
-globals.set_debug(True)
-globals.set_verbose(True)
+# globals.set_debug(True)
+# globals.set_verbose(True)
 
 
-qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=False)
-executor = initialize_qwen_agent(tools=all_tools, llm=qwen_model)
+async def main():
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
+    qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
+    executor = initialize_qwen_agent(tools=all_tools,
+                                     llm=qwen_model,
+                                     callbacks=[callback],
+                                     )
 
-# ret = executor.invoke("苏州今天冷吗")
-ret = executor.invoke("从知识库samples中查询chatchat项目简介")
-# ret = executor.invoke("chatchat项目主要issue有哪些")
-print(ret)
+    # ret = executor.invoke("苏州今天冷吗")
+    ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
+    async for chunk in callback.aiter():
+        print(chunk)
+    # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
+    # ret = executor.invoke("chatchat项目主要issue有哪些")
+    print(ret)
+
+asyncio.run(main())
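In the rewritten test, the agent runs via asyncio.create_task while callback.aiter() drains the event stream; note that the task object itself is printed but never awaited, so an exception raised inside the chain would go unnoticed. One possible pattern (a sketch, not part of the commit) that streams and still surfaces errors:

import asyncio

async def run_and_stream(executor, callback, query: str):
    # start the agent in the background; its events arrive via the callback queue
    task = asyncio.create_task(executor.ainvoke({"input": query}))
    async for chunk in callback.aiter():
        print(chunk)
    return await task  # re-raises anything that failed inside the chain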
@@ -11,6 +11,7 @@ import os
 import re
 import time
 from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
+from server.callback_handler.agent_callback_handler import AgentStatus
 import uuid
 from typing import List, Dict
 
@@ -271,10 +272,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                 unsafe_allow_html=True)
 
-        chat_box.ai_say("正在思考...")
+        chat_box.ai_say(["正在思考...",
+                         Markdown(title="tool call", in_expander=True, expanded=True, state="running"),
+                         Markdown()])
         text = ""
         message_id = ""
-        element_index = 0
+        tool_called = False
 
         for d in api.chat_chat(query=prompt,
                                metadata=files_upload,
@@ -283,33 +286,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                conversation_id=conversation_id,
                                tool_config=selected_tool_configs,
                                ):
-            try:
-                d = json.loads(d)
-            except:
-                pass
             message_id = d.get("message_id", "")
             metadata = {
                 "message_id": message_id,
             }
-            if error_msg := check_error_msg(d):
-                st.error(error_msg)
-            if chunk := d.get("agent_action"):
-                chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
-                element_index = 1
+            if not tool_called:  # avoid repeating output after a tool call: LLM output before and after the call goes to separate elements
+                element_index = 0
+            else:
+                element_index = -1
+            if d["status"] == AgentStatus.error:
+                st.error(d["text"])
+            elif d["status"] == AgentStatus.agent_action:
                 formatted_data = {
-                    "action": chunk["tool_name"],
-                    "action_input": chunk["tool_input"]
+                    "action": d["tool_name"],
+                    "action_input": d["tool_input"]
                 }
                 formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("text"):
-                text += chunk
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("agent_finish"):
-                element_index = 0
-                text = chunk
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+                chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
+                tool_called = True
+                text = ""
+            elif d["status"] == AgentStatus.llm_new_token:
+                text += d["text"]
+                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.llm_end:
+                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.agent_finish:
+                chat_box.update_msg(element_index=1, state="complete", expanded=False)
+                chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
 
         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file:
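The UI change pre-allocates a three-element assistant message: element 0 streams pre-tool LLM text, element 1 is the collapsible "tool call" panel, and element -1 receives the final answer. A condensed sketch of that layout, reusing the streamlit-chatbox calls that appear in the diff (arguments trimmed; this would need to run inside a Streamlit app):

from streamlit_chatbox import ChatBox, Markdown

chat_box = ChatBox()
chat_box.ai_say([
    "Thinking...",                                                   # element 0: pre-tool tokens
    Markdown(title="tool call", in_expander=True, state="running"),  # element 1: tool panel
    Markdown(),                                                      # element -1: final answer
])
chat_box.update_msg("streamed text so far", element_index=0, streaming=True)
chat_box.update_msg("Input Params: {...}", element_index=1)
chat_box.update_msg(element_index=1, state="complete", expanded=False)
chat_box.update_msg("final answer", element_index=-1, streaming=False)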