liunux4odoo 9c525b7fa5
publish 0.2.10 (#2797)
New features:
- Improve OCR for PDF files by filtering out meaningless small images by @liunux4odoo #2525
- Support the Gemini online model by @yhfgyyf #2630
- Support the GLM4 online model by @zRzRzRzRzRzRzR
- Support HTTPS connections for Elasticsearch by @xldistance #2390
- Improve OCR recognition for PPT and DOC knowledge-base files by @596192804 #2013
- Update the Agent chat feature by @zRzRzRzRzRzRzR
- Fetch database connections from a connection pool when objects are created, instead of opening a new connection on every method call by @Lijia0 #2480
- Make ChatOpenAI check whether the token count exceeds the model's context length by @glide-the
- Update database runtime error reporting and project milestones by @zRzRzRzRzRzRzR #2659
- Update configuration files / documentation / dependencies by @imClumsyPanda @zRzRzRzRzRzRzR
- Add a Japanese-language README by @eltociear #2787

Fixes:
- PGVector vector store connection error after the langchain update by @HALIndex #2591
- Minimax model worker error by @xyhshen
- Elasticsearch vector search not working: create the vector index with explicit mappings by @MSZheng20 #2688
2024-01-26 06:58:49 +08:00


import json
import asyncio
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
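

# Chat endpoint that runs a tool-using agent and streams its intermediate tool
# calls and final answer back to the client as server-sent events.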
async def agent_chat(query: str = Body(..., description="User input", examples=["Fly into a rage from embarrassment"]),
                     history: List[History] = Body([],
                                                   description="Chat history",
                                                   examples=[[
                                                       {"role": "user", "content": "Please use the knowledge base tool to look up today's weather in Beijing"},
                                                       {"role": "assistant",
                                                        "content": "The weather query tool reports: Beijing is cloudy today, 10-14°C, northeast wind force 2, easy to catch a cold"}]]
                                                   ),
                     stream: bool = Body(False, description="Stream the output"),
                     model_name: str = Body(LLM_MODELS[0], description="Name of the LLM model."),
                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="Limit the number of tokens the LLM may generate; the default None means the model's maximum"),
                     prompt_name: str = Body("default",
                                             description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                     ):
    history = [History.from_data(h) for h in history]
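
    # The response body is produced lazily by this generator: one JSON message per
    # agent event, so tool calls and LLM tokens can be streamed to the client.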
    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
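        # Expose the available knowledge bases (name -> description) through the
        # shared model_container so the knowledge-base tools can look them up.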
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}

        if Agent_MODEL:
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
            model_container.MODEL = model_agent
        else:
            model_container.MODEL = model
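        # Build the agent prompt; CustomPromptTemplate renders the tool names and
        # descriptions plus intermediate_steps into the configured template.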
        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)
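        # ChatGLM3 and zhipu-api models use a dedicated structured agent; all
        # other models go through the generic single-action agent below.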
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
agent_executor = initialize_glm3_agent(
llm=model,
tools=tools,
callback_manager=None,
prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
verbose=True,
)
else:
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:", "Observation"],
allowed_tools=tool_names,
)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
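        # Run the agent as a background task; the callback handler forwards its
        # events to the iterator below. Task creation is retried until it succeeds.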
        while True:
            try:
                task = asyncio.create_task(wrap_done(
                    agent_executor.acall(query, callbacks=[callback], include_run_info=True),
                    callback.done))
                break
            except:
                pass
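        # Streaming mode: forward each agent event to the client as soon as it arrives.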
        if stream:
            async for chunk in callback.aiter():
                tools_use = []
                # Use server-sent-events to stream the response
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    tools_use.append("\n```\n")
                    tools_use.append("Tool name: " + data["tool_name"])
                    tools_use.append("Tool status: call failed")
                    tools_use.append("Error message: " + data["error"])
                    tools_use.append("Retrying from the start")
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.tool_finish:
                    tools_use.append("\n```\n")
                    tools_use.append("Tool name: " + data["tool_name"])
                    tools_use.append("Tool status: call succeeded")
                    tools_use.append("Tool input: " + data["input_str"])
                    tools_use.append("Tool output: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                if data["status"] == Status.error:
                    answer += "\n```\n"
                    answer += "Tool name: " + data["tool_name"] + "\n"
                    answer += "Tool status: call failed\n"
                    answer += "Error message: " + data["error"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "Tool name: " + data["tool_name"] + "\n"
                    answer += "Tool status: call succeeded\n"
                    answer += "Tool input: " + data["input_str"] + "\n"
                    answer += "Tool output: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]
            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task
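
    # EventSourceResponse drives the generator above and writes each yielded JSON
    # string to the client as a server-sent event.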
    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )
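
# A minimal sketch of calling this endpoint, assuming this router is mounted at
# /chat/agent_chat and the API server listens on the default port 7861 (adjust
# the path, host, and port for your deployment):
#
#   curl -N -X POST http://127.0.0.1:7861/chat/agent_chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is the weather in Beijing today?", "stream": true}'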