liunux4odoo 9ff7bef2c2
新功能:知识库管理界面支持查看、编辑、删除向量库文档 (#2471)
* 新功能:
- 知识库管理界面支持查看、编辑、删除向量库文档。暂不支持增加(aggrid添加新行比较麻烦,需要另外实现)
- 去除知识库管理界面中重建知识库和删除知识库按钮,建议用户到终端命令操作

修复:
- 所有与知识库名称、文件名称有关的数据库操作函数都改成大小写不敏感,所有路径统一为 posix 风格,避免因路径文本不一致导致数据重复和操作失效 (close #2232)

开发者:
- 添加 update_docs_by_id 函数与 API 接口。当前仅支持 FAISS,暂时未用到,未将来对知识库做更细致的修改做准备
- 统一 DocumentWithScore 与 DocumentWithVsId
- FAISS 返回的 Document.metadata 中包含 ID, 方便后续查找比对
- /knowledge_base/search_docs 接口支持 file_name, metadata 参数,可以据此检索文档

* fix bug
2023-12-26 13:44:36 +08:00

104 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
async def chat_iterator() -> AsyncIterable[str]:
nonlocal history, max_tokens
callback = AsyncIteratorCallbackHandler()
callbacks = [callback]
memory = None
# 负责保存llm response到message db
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)
callbacks.append(conversation_callback)
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=callbacks,
)
if history: # 优先使用前端传入的历史消息
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
prompt = get_prompt_template("llm_chat", "with_history")
chat_prompt = PromptTemplate.from_template(prompt)
# 根据conversation_id 获取message 列表进而拼凑 memory
memory = ConversationBufferDBMemory(conversation_id=conversation_id,
llm=model,
message_limit=history_len)
else:
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"input": query}),
callback.done),
)
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps(
{"text": token, "message_id": message_id},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps(
{"text": answer, "message_id": message_id},
ensure_ascii=False)
await task
return EventSourceResponse(chat_iterator())