feat: add db memory (#2046)

* feat: add db memory
* WEBUI: add multi-conversation support

---------

Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
qiankunli 2023-11-22 18:38:26 +08:00 committed by GitHub
parent 569209289b
commit 1c97673d41
16 changed files with 312 additions and 135 deletions

View File

@@ -9,9 +9,6 @@ import shutil
log_verbose = False
langchain.verbose = False
# Whether to save chat history
SAVE_CHAT_HISTORY = False
# The following usually does not need to be changed
# Log format

View File

@@ -0,0 +1,31 @@
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from server.db.repository import update_message
class ConversationCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, conversation_id: str, message_id: str, chat_type: str, query: str):
self.conversation_id = conversation_id
self.message_id = message_id
self.chat_type = chat_type
self.query = query
self.start_at = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
# If more information should be stored, the prompts would need to be persisted as well
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
answer = response.generations[0][0].text
update_message(self.message_id, answer)
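For reference, a minimal sketch (not part of the diff; the LLMResult construction and the ids are illustrative) of what the handler does when a generation finishes: it persists the final answer under the message id created before the chain ran.

from langchain.schema import Generation, LLMResult

handler = ConversationCallbackHandler(
    conversation_id="c1", message_id="m1", chat_type="llm_chat", query="hello")
# on_llm_end takes the first generation's text and writes it to the message table
handler.on_llm_end(LLMResult(generations=[[Generation(text="hi there")]]))
# equivalent to calling update_message("m1", "hi there")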

View File

@@ -1,6 +1,6 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODELS, TEMPERATURE, SAVE_CHAT_HISTORY
from configs import LLM_MODELS, TEMPERATURE
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -8,19 +8,24 @@ from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.db.repository import add_chat_history_to_db, update_chat_history
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
conversation_id: str = Body(None, description="对话框ID"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
@@ -28,26 +33,41 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
async def chat_iterator() -> AsyncIterable[str]:
nonlocal history
callback = AsyncIteratorCallbackHandler()
callbacks = [callback]
memory = None
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
# Saves the LLM response to the message table
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)
callbacks.append(conversation_callback)
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
callbacks=callbacks,
)
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
if conversation_id is None:
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
else:
# When using memory, the prompt must contain the variable named by memory.memory_key
prompt = get_prompt_template("llm_chat", "with_history")
chat_prompt = PromptTemplate.from_template(prompt)
# Build the memory from the message list fetched by conversation_id
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
@@ -55,30 +75,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callback.done),
)
answer = ""
chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
if stream:
async for token in callback.aiter():
answer += token
# Use server-sent-events to stream the response
yield json.dumps(
{"text": token, "chat_history_id": chat_history_id},
{"text": token, "message_id": message_id},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps(
{"text": answer, "chat_history_id": chat_history_id},
{"text": answer, "message_id": message_id},
ensure_ascii=False)
if SAVE_CHAT_HISTORY and len(chat_history_id) > 0:
# Additional info (e.g. the real prompt) could be added later
update_chat_history(chat_history_id, response=answer)
await task
return StreamingResponse(chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")
return StreamingResponse(chat_iterator(), media_type="text/event-stream")
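A hedged usage sketch (the mount path /chat/chat and the port are assumptions, not shown in this diff): supplying conversation_id makes the server rebuild history from the database instead of the client-supplied list, and each streamed chunk now carries message_id in place of chat_history_id.

import json
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",  # mount path assumed
    json={"query": "hello", "conversation_id": "c1", "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line:
        chunk = json.loads(line)
        print(chunk["text"], end="")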

View File

@@ -1,19 +1,18 @@
from fastapi import Body
from configs import logger, log_verbose
from server.utils import BaseResponse
from server.db.repository.chat_history_repository import feedback_chat_history_to_db
from server.db.repository import feedback_message_to_db
def chat_feedback(chat_history_id: str = Body("", max_length=32, description="聊天记录id"),
def chat_feedback(message_id: str = Body("", max_length=32, description="聊天记录id"),
score: int = Body(0, max=100, description="用户评分满分100越大表示评价越高"),
reason: str = Body("", description="用户评分理由,比如不符合事实等")
):
try:
feedback_chat_history_to_db(chat_history_id, score, reason)
feedback_message_to_db(message_id, score, reason)
except Exception as e:
msg = f"反馈聊天记录出错: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=200, msg=f"已反馈聊天记录 {chat_history_id}")
return BaseResponse(code=200, msg=f"已反馈聊天记录 {message_id}")

View File

@@ -0,0 +1,17 @@
from sqlalchemy import Column, Integer, String, DateTime, JSON, func
from server.db.base import Base
class ConversationModel(Base):
"""
Conversation model
"""
__tablename__ = 'conversation'
id = Column(String(32), primary_key=True, comment='对话框ID')
name = Column(String(50), comment='对话框名称')
# e.g. chat / agent_chat
chat_type = Column(String(50), comment='聊天类型')
create_time = Column(DateTime, default=func.now(), comment='创建时间')
def __repr__(self):
return f"<Conversation(id='{self.id}', name='{self.name}', chat_type='{self.chat_type}', create_time='{self.create_time}')>"

View File

@@ -3,13 +3,13 @@ from sqlalchemy import Column, Integer, String, DateTime, JSON, func
from server.db.base import Base
class ChatHistoryModel(Base):
class MessageModel(Base):
"""
Chat message model
"""
__tablename__ = 'chat_history'
# UUID generated by the frontend; with an auto-increment id, the id would have to be sent back to the frontend, which is awkward for streaming responses
__tablename__ = 'message'
id = Column(String(32), primary_key=True, comment='聊天记录ID')
conversation_id = Column(String(32), index=True, comment='对话框ID')
# e.g. chat / agent_chat
chat_type = Column(String(50), comment='聊天类型')
query = Column(String(4096), comment='用户问题')
@@ -22,4 +22,4 @@ class ChatHistoryModel(Base):
create_time = Column(DateTime, default=func.now(), comment='创建时间')
def __repr__(self):
return f"<ChatHistory(id='{self.id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"
return f"<message(id='{self.id}', conversation_id='{self.conversation_id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"

View File

@@ -1,3 +1,4 @@
from .chat_history_repository import *
from .conversation_repository import *
from .message_repository import *
from .knowledge_base_repository import *
from .knowledge_file_repository import *

View File

@@ -1,75 +0,0 @@
from server.db.session import with_session
from server.db.models.chat_history_model import ChatHistoryModel
import re
import uuid
from typing import Dict, List
def _convert_query(query: str) -> str:
p = re.sub(r"\s+", "%", query)
return f"%{p}%"
@with_session
def add_chat_history_to_db(session, chat_type, query, response="", chat_history_id=None, metadata: Dict = {}):
"""
Add a new chat record
"""
if not chat_history_id:
chat_history_id = uuid.uuid4().hex
ch = ChatHistoryModel(id=chat_history_id, chat_type=chat_type, query=query, response=response,
metadata=metadata)
session.add(ch)
session.commit()
return ch.id
@with_session
def update_chat_history(session, chat_history_id, response: str = None, metadata: Dict = None):
"""
Update an existing chat record
"""
ch = get_chat_history_by_id(chat_history_id)
if ch is not None:
if response is not None:
ch.response = response
if isinstance(metadata, dict):
ch.meta_data = metadata
session.add(ch)
return ch.id
@with_session
def feedback_chat_history_to_db(session, chat_history_id, feedback_score, feedback_reason):
"""
Save feedback for a chat record
"""
ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
if ch:
ch.feedback_score = feedback_score
ch.feedback_reason = feedback_reason
return ch.id
@with_session
def get_chat_history_by_id(session, chat_history_id) -> ChatHistoryModel:
"""
Look up a chat record by id
"""
ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
return ch
@with_session
def filter_chat_history(session, query=None, response=None, score=None, reason=None) -> List[ChatHistoryModel]:
ch = session.query(ChatHistoryModel)
if query is not None:
ch = ch.filter(ChatHistoryModel.query.ilike(_convert_query(query)))
if response is not None:
ch = ch.filter(ChatHistoryModel.response.ilike(_convert_query(response)))
if score is not None:
ch = ch.filter_by(feedback_score=score)
if reason is not None:
ch = ch.filter(ChatHistoryModel.feedback_reason.ilike(_convert_query(reason)))
return ch

View File

@@ -0,0 +1,16 @@
from server.db.session import with_session
import uuid
from server.db.models.conversation_model import ConversationModel
@with_session
def add_conversation_to_db(session, chat_type, name="", conversation_id=None):
"""
Create a new conversation
"""
if not conversation_id:
conversation_id = uuid.uuid4().hex
c = ConversationModel(id=conversation_id, chat_type=chat_type, name=name)
session.add(c)
return c.id

View File

@@ -0,0 +1,72 @@
from server.db.session import with_session
from typing import Dict, List
import uuid
from server.db.models.message_model import MessageModel
@with_session
def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None,
metadata: Dict = {}):
"""
Add a new chat message
"""
if not message_id:
message_id = uuid.uuid4().hex
m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response,
conversation_id=conversation_id,
meta_data=metadata)
session.add(m)
session.commit()
return m.id
@with_session
def update_message(session, message_id, response: str = None, metadata: Dict = None):
"""
Update an existing chat message
"""
m = get_message_by_id(message_id)
if m is not None:
if response is not None:
m.response = response
if isinstance(metadata, dict):
m.meta_data = metadata
session.add(m)
session.commit()
return m.id
@with_session
def get_message_by_id(session, message_id) -> MessageModel:
"""
Look up a chat message by id
"""
m = session.query(MessageModel).filter_by(id=message_id).first()
return m
@with_session
def feedback_message_to_db(session, message_id, feedback_score, feedback_reason):
"""
Save feedback for a chat message
"""
m = session.query(MessageModel).filter_by(id=message_id).first()
if m:
m.feedback_score = feedback_score
m.feedback_reason = feedback_reason
session.commit()
return m.id
@with_session
def filter_message(session, conversation_id: str, limit: int = 10):
messages = (session.query(MessageModel).filter_by(conversation_id=conversation_id).
# The user's latest query is inserted into the DB as well; skip that not-yet-answered message record
filter(MessageModel.response != '').
# return the most recent `limit` records
order_by(MessageModel.create_time.desc()).limit(limit).all())
# Returning List[MessageModel] directly fails once the session closes, so convert to plain dicts
data = []
for m in messages:
data.append({"query": m.query, "response": m.response})
return data
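A hypothetical round-trip (names and values are illustrative; the session argument is injected by the with_session decorator, so callers omit it):

from server.db.repository import add_conversation_to_db

conv_id = add_conversation_to_db(chat_type="llm_chat", name="demo")
add_message_to_db(conversation_id=conv_id, chat_type="llm_chat",
                  query="hi", response="hello!")
print(filter_message(conversation_id=conv_id, limit=10))
# -> [{"query": "hi", "response": "hello!"}]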

View File

@@ -9,7 +9,8 @@ from server.knowledge_base.utils import (
KnowledgeFile
)
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.chat_history_model import ChatHistoryModel
from server.db.models.conversation_model import ConversationModel
from server.db.models.message_model import MessageModel
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
from server.db.base import Base, engine
from server.db.session import session_scope
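Because the table set is derived from whatever models are imported when Base.metadata is populated, importing conversation_model and message_model above is enough for a standard SQLAlchemy create_all to pick up the two new tables (the call shown here is a sketch, not part of the diff):

# `conversation` and `message` are now registered on Base.metadata, so:
Base.metadata.create_all(bind=engine)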

View File

@@ -0,0 +1,73 @@
import logging
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from server.db.repository.message_repository import filter_message
from server.db.models.message_model import MessageModel
class ConversationBufferDBMemory(BaseChatMemory):
conversation_id: str
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
memory_key: str = "history"
max_token_limit: int = 2000
message_limit: int = 10
@property
def buffer(self) -> List[BaseMessage]:
"""String buffer of memory."""
# fetch limited messages desc, and return reversed
messages = filter_message(conversation_id=self.conversation_id, limit=self.message_limit)
# records come back newest-first; reverse into chronological order
messages = list(reversed(messages))
chat_messages: List[BaseMessage] = []
for message in messages:
chat_messages.append(HumanMessage(content=message["query"]))
chat_messages.append(AIMessage(content=message["response"]))
if not chat_messages:
return []
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages))
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages))
return chat_messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer: Any = self.buffer
if self.return_messages:
final_buffer: Any = buffer
else:
final_buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
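A minimal wiring sketch mirroring the chat endpoint above (get_ChatOpenAI, get_prompt_template and the "with_history" template come from this repo; the model name, temperature and conversation id are illustrative):

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from server.utils import get_ChatOpenAI, get_prompt_template

model = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.7, max_tokens=None)
prompt = PromptTemplate.from_template(get_prompt_template("llm_chat", "with_history"))
memory = ConversationBufferDBMemory(conversation_id="c1", llm=model)
# at run time, load_memory_variables() injects prior turns under the "history" key
chain = LLMChain(prompt=prompt, llm=model, memory=memory)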

File diff suppressed because one or more lines are too long

View File

@@ -13,6 +13,7 @@ from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
import logging
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -20,6 +21,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logging.exception(e)
# TODO: handle exception
msg = f"Caught exception: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',

View File

@@ -6,6 +6,7 @@ import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
@@ -47,8 +48,11 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
st.session_state.setdefault("file_chat_id", None)
default_model = api.get_default_llm_model()[0]
if not chat_box.chat_inited:
st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
@@ -57,6 +61,32 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.init_session()
with st.sidebar:
# multi-conversation support
cols = st.columns([3, 1])
conv_name = cols[0].text_input("会话名称")
with cols[1]:
if st.button("添加"):
if not conv_name or conv_name in st.session_state["conversation_ids"]:
st.error("请指定有效的会话名称")
else:
st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
st.session_state["cur_conv_name"] = conv_name
st.session_state["conv_name"] = ""
if st.button("删除"):
if not conv_name or conv_name not in st.session_state["conversation_ids"]:
st.error("请指定有效的会话名称")
else:
st.session_state["conversation_ids"].pop(conv_name, None)
st.session_state["cur_conv_name"] = ""
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话:", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
# TODO: bind the dialogue model to the conversation
def on_mode_change():
mode = st.session_state.dialogue_mode
@@ -210,12 +240,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
def on_feedback(
feedback,
chat_history_id: str = "",
message_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
api.chat_feedback(chat_history_id=chat_history_id,
api.chat_feedback(message_id=message_id,
score=score_int,
reason=reason)
st.session_state["need_rerun"] = True
@@ -231,9 +261,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
chat_history_id = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
@@ -243,16 +274,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
break
text += t.get("text", "")
chat_box.update_msg(text)
chat_history_id = t.get("chat_history_id", "")
message_id = t.get("message_id", "")
metadata = {
"chat_history_id": chat_history_id,
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata)  # update with the final string and remove the cursor
chat_box.show_feedback(**feedback_kwargs,
key=chat_history_id,
key=message_id,
on_submit=on_feedback,
kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1})
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):

View File

@@ -290,6 +290,7 @@ class ApiRequest:
def chat_chat(
self,
query: str,
conversation_id: str = None,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
@@ -303,6 +304,7 @@ class ApiRequest:
'''
data = {
"query": query,
"conversation_id": conversation_id,
"history": history,
"stream": stream,
"model_name": model,
@@ -978,7 +980,7 @@ class ApiRequest:
def chat_feedback(
self,
chat_history_id: str,
message_id: str,
score: int,
reason: str = "",
) -> int:
@@ -986,7 +988,7 @@ class ApiRequest:
Submit feedback on a chat answer
'''
data = {
"chat_history_id": chat_history_id,
"message_id": message_id,
"score": score,
"reason": reason,
}
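On the client side, a hedged end-to-end sketch (the ApiRequest import path and base URL are assumptions): generate one conversation id per dialogue, reuse it across turns, and feed the returned message_id into chat_feedback.

import uuid
from webui_pages.utils import ApiRequest  # import path assumed

api = ApiRequest(base_url="http://127.0.0.1:7861")
conv_id = uuid.uuid4().hex  # one id per dialogue, as the WebUI does

message_id = ""
for chunk in api.chat_chat("hello", conversation_id=conv_id):
    message_id = chunk.get("message_id", "")
api.chat_feedback(message_id=message_id, score=100, reason="accurate answer")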