mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 22:33:24 +08:00
修复 callback 无效的问题
This commit is contained in:
parent
7257521e10
commit
6f155aec1f
@ -18,12 +18,8 @@ def agents_registry(
|
|||||||
callbacks: List[BaseCallbackHandler] = [],
|
callbacks: List[BaseCallbackHandler] = [],
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
verbose: bool = False):
|
verbose: bool = False):
|
||||||
if prompt is not None:
|
|
||||||
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
|
|
||||||
else:
|
|
||||||
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
|
|
||||||
|
|
||||||
# llm.callbacks = callbacks
|
# llm.callbacks = callbacks
|
||||||
|
llm.streaming = False # qwen agent not support streaming
|
||||||
|
|
||||||
# Write any optimized method here.
|
# Write any optimized method here.
|
||||||
if "glm3" in llm.model_name.lower():
|
if "glm3" in llm.model_name.lower():
|
||||||
@ -32,6 +28,10 @@ def agents_registry(
|
|||||||
elif "qwen" in llm.model_name.lower():
|
elif "qwen" in llm.model_name.lower():
|
||||||
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
|
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
|
||||||
else:
|
else:
|
||||||
|
if prompt is not None:
|
||||||
|
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
|
||||||
|
else:
|
||||||
|
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
|
||||||
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
|
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
|
||||||
|
|
||||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
|
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
|
||||||
|
|||||||
@ -168,8 +168,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
|||||||
"status": AgentStatus.agent_finish,
|
"status": AgentStatus.agent_finish,
|
||||||
"text": finish.return_values["output"],
|
"text": finish.return_values["output"],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.done.set()
|
|
||||||
self.queue.put_nowait(dumps(data))
|
self.queue.put_nowait(dumps(data))
|
||||||
|
|
||||||
|
|
||||||
|
async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||||
|
self.done.set()
|
||||||
self.out = True
|
self.out = True
|
||||||
|
|||||||
@ -21,11 +21,6 @@ from server.db.repository import add_message_to_db
|
|||||||
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
|
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
# from langchain.globals import set_debug, set_verbose
|
|
||||||
# set_debug(True)
|
|
||||||
# set_verbose(True)
|
|
||||||
|
|
||||||
|
|
||||||
def create_models_from_config(configs, callbacks, stream):
|
def create_models_from_config(configs, callbacks, stream):
|
||||||
if configs is None:
|
if configs is None:
|
||||||
configs = {}
|
configs = {}
|
||||||
@ -128,7 +123,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||||||
|
|
||||||
callback = AgentExecutorAsyncIteratorCallbackHandler()
|
callback = AgentExecutorAsyncIteratorCallbackHandler()
|
||||||
callbacks = [callback]
|
callbacks = [callback]
|
||||||
models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
|
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
|
||||||
tools = [tool for tool in all_tools if tool.name in tool_config]
|
tools = [tool for tool in all_tools if tool.name in tool_config]
|
||||||
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
|
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
|
||||||
full_chain = create_models_chains(prompts=prompts,
|
full_chain = create_models_chains(prompts=prompts,
|
||||||
@ -143,13 +138,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||||||
full_chain.ainvoke(
|
full_chain.ainvoke(
|
||||||
{
|
{
|
||||||
"input": query,
|
"input": query,
|
||||||
"chat_history": [
|
"chat_history": [],
|
||||||
HumanMessage(content="今天北京的温度是多少度"),
|
|
||||||
AIMessage(content="今天北京的温度是1度"),
|
|
||||||
],
|
|
||||||
|
|
||||||
}
|
}
|
||||||
),callback.done))
|
), callback.done))
|
||||||
|
|
||||||
async for chunk in callback.aiter():
|
async for chunk in callback.aiter():
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|||||||
@ -13,10 +13,10 @@ from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIte
|
|||||||
from langchain import globals
|
from langchain import globals
|
||||||
|
|
||||||
globals.set_debug(True)
|
globals.set_debug(True)
|
||||||
# globals.set_verbose(True)
|
globals.set_verbose(True)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def test1():
|
||||||
callback = AgentExecutorAsyncIteratorCallbackHandler()
|
callback = AgentExecutorAsyncIteratorCallbackHandler()
|
||||||
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
|
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
|
||||||
# qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
|
# qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
|
||||||
@ -38,4 +38,52 @@ async def main():
|
|||||||
# ret = executor.invoke("chatchat项目主要issue有哪些")
|
# ret = executor.invoke("chatchat项目主要issue有哪些")
|
||||||
await ret
|
await ret
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
async def test2():
|
||||||
|
from server.chat.chat import chat
|
||||||
|
|
||||||
|
mc={'preprocess_model': {
|
||||||
|
'qwen': {
|
||||||
|
'temperature': 0.4,
|
||||||
|
'max_tokens': 2048,
|
||||||
|
'history_len': 100,
|
||||||
|
'prompt_name': 'default',
|
||||||
|
'callbacks': False}
|
||||||
|
},
|
||||||
|
'llm_model': {
|
||||||
|
'qwen': {
|
||||||
|
'temperature': 0.9,
|
||||||
|
'max_tokens': 4096,
|
||||||
|
'history_len': 3,
|
||||||
|
'prompt_name': 'default',
|
||||||
|
'callbacks': True}
|
||||||
|
},
|
||||||
|
'action_model': {
|
||||||
|
'qwen': {
|
||||||
|
'temperature': 0.01,
|
||||||
|
'max_tokens': 4096,
|
||||||
|
'prompt_name': 'qwen',
|
||||||
|
'callbacks': True}
|
||||||
|
},
|
||||||
|
'postprocess_model': {
|
||||||
|
'qwen': {
|
||||||
|
'temperature': 0.01,
|
||||||
|
'max_tokens': 4096,
|
||||||
|
'prompt_name': 'default',
|
||||||
|
'callbacks': True}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tc={'weather_check': {'use': False, 'api-key': 'your key'}}
|
||||||
|
|
||||||
|
async for x in (await chat("苏州天气如何",{},
|
||||||
|
model_config=mc,
|
||||||
|
tool_config=tc,
|
||||||
|
conversation_id=None,
|
||||||
|
history_len=-1,
|
||||||
|
history=[],
|
||||||
|
stream=True)).body_iterator:
|
||||||
|
print(x)
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(test2())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user