mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
feat: add db memory (#2046)
* feat: add db memory * WEBUI 添加多会话功能 --------- Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com> Co-authored-by: liunux4odoo <liunux@qq.com>
This commit is contained in:
parent
569209289b
commit
1c97673d41
@ -9,9 +9,6 @@ import shutil
|
||||
log_verbose = False
|
||||
langchain.verbose = False
|
||||
|
||||
# 是否保存聊天记录
|
||||
SAVE_CHAT_HISTORY = False
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# 日志格式
|
||||
|
||||
31
server/callback_handler/conversation_callback_handler.py
Normal file
31
server/callback_handler/conversation_callback_handler.py
Normal file
@ -0,0 +1,31 @@
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from server.db.repository import update_message
|
||||
|
||||
|
||||
class ConversationCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_id: str, message_id: str, chat_type: str, query: str):
|
||||
self.conversation_id = conversation_id
|
||||
self.message_id = message_id
|
||||
self.chat_type = chat_type
|
||||
self.query = query
|
||||
self.start_at = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
# 如果想存更多信息,则prompts 也需要持久化
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
answer = response.generations[0][0].text
|
||||
update_message(self.message_id, answer)
|
||||
@ -1,6 +1,6 @@
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, SAVE_CHAT_HISTORY
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
@ -8,19 +8,24 @@ from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
from server.chat.utils import History
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.utils import get_prompt_template
|
||||
from server.db.repository import add_chat_history_to_db, update_chat_history
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
|
||||
|
||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
conversation_id: str = Body(None, description="对话框ID"),
|
||||
history: Union[int, List[History]] = Body([],
|
||||
description="历史对话,设为一个整数可以从数据库中读取历史消息",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
@ -28,26 +33,41 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def chat_iterator(query: str,
|
||||
history: List[History] = [],
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal history
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
memory = None
|
||||
|
||||
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
|
||||
# 负责保存llm response到message db
|
||||
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
|
||||
chat_type="llm_chat",
|
||||
query=query)
|
||||
callbacks.append(conversation_callback)
|
||||
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
if conversation_id is None:
|
||||
history = [History.from_data(h) for h in history]
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
else:
|
||||
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
|
||||
prompt = get_prompt_template("llm_chat", "with_history")
|
||||
chat_prompt = PromptTemplate.from_template(prompt)
|
||||
# 根据conversation_id 获取message 列表进而拼凑 memory
|
||||
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
task = asyncio.create_task(wrap_done(
|
||||
@ -55,30 +75,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
callback.done),
|
||||
)
|
||||
|
||||
answer = ""
|
||||
chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
# Use server-sent-events to stream the response
|
||||
yield json.dumps(
|
||||
{"text": token, "chat_history_id": chat_history_id},
|
||||
{"text": token, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield json.dumps(
|
||||
{"text": answer, "chat_history_id": chat_history_id},
|
||||
{"text": answer, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
|
||||
if SAVE_CHAT_HISTORY and len(chat_history_id) > 0:
|
||||
# 后续可以加入一些其他信息,比如真实的prompt等
|
||||
update_chat_history(chat_history_id, response=answer)
|
||||
await task
|
||||
|
||||
return StreamingResponse(chat_iterator(query=query,
|
||||
history=history,
|
||||
model_name=model_name,
|
||||
prompt_name=prompt_name),
|
||||
media_type="text/event-stream")
|
||||
return StreamingResponse(chat_iterator(), media_type="text/event-stream")
|
||||
|
||||
@ -1,19 +1,18 @@
|
||||
from fastapi import Body
|
||||
from configs import logger, log_verbose
|
||||
from server.utils import BaseResponse
|
||||
from server.db.repository.chat_history_repository import feedback_chat_history_to_db
|
||||
from server.db.repository import feedback_message_to_db
|
||||
|
||||
|
||||
def chat_feedback(chat_history_id: str = Body("", max_length=32, description="聊天记录id"),
|
||||
def chat_feedback(message_id: str = Body("", max_length=32, description="聊天记录id"),
|
||||
score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"),
|
||||
reason: str = Body("", description="用户评分理由,比如不符合事实等")
|
||||
):
|
||||
try:
|
||||
feedback_chat_history_to_db(chat_history_id, score, reason)
|
||||
feedback_message_to_db(message_id, score, reason)
|
||||
except Exception as e:
|
||||
msg = f"反馈聊天记录出错: {e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=200, msg=f"已反馈聊天记录 {chat_history_id}")
|
||||
return BaseResponse(code=200, msg=f"已反馈聊天记录 {message_id}")
|
||||
|
||||
17
server/db/models/conversation_model.py
Normal file
17
server/db/models/conversation_model.py
Normal file
@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, JSON, func
|
||||
from server.db.base import Base
|
||||
|
||||
|
||||
class ConversationModel(Base):
|
||||
"""
|
||||
聊天记录模型
|
||||
"""
|
||||
__tablename__ = 'conversation'
|
||||
id = Column(String(32), primary_key=True, comment='对话框ID')
|
||||
name = Column(String(50), comment='对话框名称')
|
||||
# chat/agent_chat等
|
||||
chat_type = Column(String(50), comment='聊天类型')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Conversation(id='{self.id}', name='{self.name}', chat_type='{self.chat_type}', create_time='{self.create_time}')>"
|
||||
@ -3,13 +3,13 @@ from sqlalchemy import Column, Integer, String, DateTime, JSON, func
|
||||
from server.db.base import Base
|
||||
|
||||
|
||||
class ChatHistoryModel(Base):
|
||||
class MessageModel(Base):
|
||||
"""
|
||||
聊天记录模型
|
||||
"""
|
||||
__tablename__ = 'chat_history'
|
||||
# 由前端生成的uuid,如果是自增的话,则需要将id 传给前端,这在流式返回里有点麻烦
|
||||
__tablename__ = 'message'
|
||||
id = Column(String(32), primary_key=True, comment='聊天记录ID')
|
||||
conversation_id = Column(String(32), index=True, comment='对话框ID')
|
||||
# chat/agent_chat等
|
||||
chat_type = Column(String(50), comment='聊天类型')
|
||||
query = Column(String(4096), comment='用户问题')
|
||||
@ -22,4 +22,4 @@ class ChatHistoryModel(Base):
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatHistory(id='{self.id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"
|
||||
return f"<message(id='{self.id}', conversation_id='{self.conversation_id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"
|
||||
@ -1,3 +1,4 @@
|
||||
from .chat_history_repository import *
|
||||
from .conversation_repository import *
|
||||
from .message_repository import *
|
||||
from .knowledge_base_repository import *
|
||||
from .knowledge_file_repository import *
|
||||
@ -1,75 +0,0 @@
|
||||
from server.db.session import with_session
|
||||
from server.db.models.chat_history_model import ChatHistoryModel
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def _convert_query(query: str) -> str:
|
||||
p = re.sub(r"\s+", "%", query)
|
||||
return f"%{p}%"
|
||||
|
||||
|
||||
@with_session
|
||||
def add_chat_history_to_db(session, chat_type, query, response="", chat_history_id=None, metadata: Dict = {}):
|
||||
"""
|
||||
新增聊天记录
|
||||
"""
|
||||
if not chat_history_id:
|
||||
chat_history_id = uuid.uuid4().hex
|
||||
ch = ChatHistoryModel(id=chat_history_id, chat_type=chat_type, query=query, response=response,
|
||||
metadata=metadata)
|
||||
session.add(ch)
|
||||
session.commit()
|
||||
return ch.id
|
||||
|
||||
|
||||
@with_session
|
||||
def update_chat_history(session, chat_history_id, response: str = None, metadata: Dict = None):
|
||||
"""
|
||||
更新已有的聊天记录
|
||||
"""
|
||||
ch = get_chat_history_by_id(chat_history_id)
|
||||
if ch is not None:
|
||||
if response is not None:
|
||||
ch.response = response
|
||||
if isinstance(metadata, dict):
|
||||
ch.meta_data = metadata
|
||||
session.add(ch)
|
||||
return ch.id
|
||||
|
||||
|
||||
@with_session
|
||||
def feedback_chat_history_to_db(session, chat_history_id, feedback_score, feedback_reason):
|
||||
"""
|
||||
反馈聊天记录
|
||||
"""
|
||||
ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
|
||||
if ch:
|
||||
ch.feedback_score = feedback_score
|
||||
ch.feedback_reason = feedback_reason
|
||||
return ch.id
|
||||
|
||||
|
||||
@with_session
|
||||
def get_chat_history_by_id(session, chat_history_id) -> ChatHistoryModel:
|
||||
"""
|
||||
查询聊天记录
|
||||
"""
|
||||
ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
|
||||
return ch
|
||||
|
||||
|
||||
@with_session
|
||||
def filter_chat_history(session, query=None, response=None, score=None, reason=None) -> List[ChatHistoryModel]:
|
||||
ch =session.query(ChatHistoryModel)
|
||||
if query is not None:
|
||||
ch = ch.filter(ChatHistoryModel.query.ilike(_convert_query(query)))
|
||||
if response is not None:
|
||||
ch = ch.filter(ChatHistoryModel.response.ilike(_convert_query(response)))
|
||||
if score is not None:
|
||||
ch = ch.filter_by(feedback_score=score)
|
||||
if reason is not None:
|
||||
ch = ch.filter(ChatHistoryModel.feedback_reason.ilike(_convert_query(reason)))
|
||||
|
||||
return ch
|
||||
16
server/db/repository/conversation_repository.py
Normal file
16
server/db/repository/conversation_repository.py
Normal file
@ -0,0 +1,16 @@
|
||||
from server.db.session import with_session
|
||||
import uuid
|
||||
from server.db.models.conversation_model import ConversationModel
|
||||
|
||||
|
||||
@with_session
|
||||
def add_conversation_to_db(session, chat_type, name="", conversation_id=None):
|
||||
"""
|
||||
新增聊天记录
|
||||
"""
|
||||
if not conversation_id:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
c = ConversationModel(id=conversation_id, chat_type=chat_type, name=name)
|
||||
|
||||
session.add(c)
|
||||
return c.id
|
||||
72
server/db/repository/message_repository.py
Normal file
72
server/db/repository/message_repository.py
Normal file
@ -0,0 +1,72 @@
|
||||
from server.db.session import with_session
|
||||
from typing import Dict, List
|
||||
import uuid
|
||||
from server.db.models.message_model import MessageModel
|
||||
|
||||
|
||||
@with_session
|
||||
def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None,
|
||||
metadata: Dict = {}):
|
||||
"""
|
||||
新增聊天记录
|
||||
"""
|
||||
if not message_id:
|
||||
message_id = uuid.uuid4().hex
|
||||
m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response,
|
||||
conversation_id=conversation_id,
|
||||
meta_data=metadata)
|
||||
session.add(m)
|
||||
session.commit()
|
||||
return m.id
|
||||
|
||||
|
||||
@with_session
|
||||
def update_message(session, message_id, response: str = None, metadata: Dict = None):
|
||||
"""
|
||||
更新已有的聊天记录
|
||||
"""
|
||||
m = get_message_by_id(message_id)
|
||||
if m is not None:
|
||||
if response is not None:
|
||||
m.response = response
|
||||
if isinstance(metadata, dict):
|
||||
m.meta_data = metadata
|
||||
session.add(m)
|
||||
session.commit()
|
||||
return m.id
|
||||
|
||||
|
||||
@with_session
|
||||
def get_message_by_id(session, message_id) -> MessageModel:
|
||||
"""
|
||||
查询聊天记录
|
||||
"""
|
||||
m = session.query(MessageModel).filter_by(id=message_id).first()
|
||||
return m
|
||||
|
||||
|
||||
@with_session
|
||||
def feedback_message_to_db(session, message_id, feedback_score, feedback_reason):
|
||||
"""
|
||||
反馈聊天记录
|
||||
"""
|
||||
m = session.query(MessageModel).filter_by(id=message_id).first()
|
||||
if m:
|
||||
m.feedback_score = feedback_score
|
||||
m.feedback_reason = feedback_reason
|
||||
session.commit()
|
||||
return m.id
|
||||
|
||||
|
||||
@with_session
|
||||
def filter_message(session, conversation_id: str, limit: int = 10):
|
||||
messages = (session.query(MessageModel).filter_by(conversation_id=conversation_id).
|
||||
# 用户最新的query 也会插入到db,忽略这个message record
|
||||
filter(MessageModel.response != '').
|
||||
# 返回最近的limit 条记录
|
||||
order_by(MessageModel.create_time.desc()).limit(limit).all())
|
||||
# 直接返回 List[MessageModel] 报错
|
||||
data = []
|
||||
for m in messages:
|
||||
data.append({"query": m.query, "response": m.response})
|
||||
return data
|
||||
@ -9,7 +9,8 @@ from server.knowledge_base.utils import (
|
||||
KnowledgeFile
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.models.chat_history_model import ChatHistoryModel
|
||||
from server.db.models.conversation_model import ConversationModel
|
||||
from server.db.models.message_model import MessageModel
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||
from server.db.base import Base, engine
|
||||
from server.db.session import session_scope
|
||||
|
||||
73
server/memory/conversation_db_buffer_memory.py
Normal file
73
server/memory/conversation_db_buffer_memory.py
Normal file
@ -0,0 +1,73 @@
|
||||
import logging
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from server.db.repository.message_repository import filter_message
|
||||
from server.db.models.message_model import MessageModel
|
||||
|
||||
|
||||
class ConversationBufferDBMemory(BaseChatMemory):
|
||||
conversation_id: str
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
# fetch limited messages desc, and return reversed
|
||||
|
||||
messages = filter_message(conversation_id=self.conversation_id, limit=self.message_limit)
|
||||
# 返回的记录按时间倒序,转为正序
|
||||
messages = list(reversed(messages))
|
||||
chat_messages: List[BaseMessage] = []
|
||||
for message in messages:
|
||||
chat_messages.append(HumanMessage(content=message["query"]))
|
||||
chat_messages.append(AIMessage(content=message["response"]))
|
||||
|
||||
if not chat_messages:
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages))
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages))
|
||||
|
||||
return chat_messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer: Any = self.buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
File diff suppressed because one or more lines are too long
@ -13,6 +13,7 @@ from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
|
||||
import httpx
|
||||
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
|
||||
import logging
|
||||
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
@ -20,6 +21,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
try:
|
||||
await fn
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# TODO: handle exception
|
||||
msg = f"Caught exception: {e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
|
||||
@ -6,6 +6,7 @@ import os
|
||||
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
|
||||
from server.knowledge_base.utils import LOADER_DICT
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
@ -47,8 +48,11 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||
|
||||
|
||||
def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
st.session_state.setdefault("conversation_ids", {})
|
||||
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
|
||||
st.session_state.setdefault("file_chat_id", None)
|
||||
default_model = api.get_default_llm_model()[0]
|
||||
|
||||
if not chat_box.chat_inited:
|
||||
st.toast(
|
||||
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
|
||||
@ -57,6 +61,32 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
chat_box.init_session()
|
||||
|
||||
with st.sidebar:
|
||||
# 多会话
|
||||
cols = st.columns([3, 1])
|
||||
conv_name = cols[0].text_input("会话名称")
|
||||
with cols[1]:
|
||||
if st.button("添加"):
|
||||
if not conv_name or conv_name in st.session_state["conversation_ids"]:
|
||||
st.error("请指定有效的会话名称")
|
||||
else:
|
||||
st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
|
||||
st.session_state["cur_conv_name"] = conv_name
|
||||
st.session_state["conv_name"] = ""
|
||||
if st.button("删除"):
|
||||
if not conv_name or conv_name not in st.session_state["conversation_ids"]:
|
||||
st.error("请指定有效的会话名称")
|
||||
else:
|
||||
st.session_state["conversation_ids"].pop(conv_name, None)
|
||||
st.session_state["cur_conv_name"] = ""
|
||||
|
||||
conv_names = list(st.session_state["conversation_ids"].keys())
|
||||
index = 0
|
||||
if st.session_state.get("cur_conv_name") in conv_names:
|
||||
index = conv_names.index(st.session_state.get("cur_conv_name"))
|
||||
conversation_name = st.selectbox("当前会话:", conv_names, index=index)
|
||||
chat_box.use_chat_name(conversation_name)
|
||||
conversation_id = st.session_state["conversation_ids"][conversation_name]
|
||||
|
||||
# TODO: 对话模型与会话绑定
|
||||
def on_mode_change():
|
||||
mode = st.session_state.dialogue_mode
|
||||
@ -210,12 +240,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
|
||||
def on_feedback(
|
||||
feedback,
|
||||
chat_history_id: str = "",
|
||||
message_id: str = "",
|
||||
history_index: int = -1,
|
||||
):
|
||||
reason = feedback["text"]
|
||||
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
|
||||
api.chat_feedback(chat_history_id=chat_history_id,
|
||||
api.chat_feedback(message_id=message_id,
|
||||
score=score_int,
|
||||
reason=reason)
|
||||
st.session_state["need_rerun"] = True
|
||||
@ -231,9 +261,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
chat_history_id = ""
|
||||
message_id = ""
|
||||
r = api.chat_chat(prompt,
|
||||
history=history,
|
||||
conversation_id=conversation_id,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature)
|
||||
@ -243,16 +274,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
break
|
||||
text += t.get("text", "")
|
||||
chat_box.update_msg(text)
|
||||
chat_history_id = t.get("chat_history_id", "")
|
||||
message_id = t.get("message_id", "")
|
||||
|
||||
metadata = {
|
||||
"chat_history_id": chat_history_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
|
||||
chat_box.show_feedback(**feedback_kwargs,
|
||||
key=chat_history_id,
|
||||
key=message_id,
|
||||
on_submit=on_feedback,
|
||||
kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1})
|
||||
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
|
||||
|
||||
@ -290,6 +290,7 @@ class ApiRequest:
|
||||
def chat_chat(
|
||||
self,
|
||||
query: str,
|
||||
conversation_id: str = None,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
model: str = LLM_MODELS[0],
|
||||
@ -303,6 +304,7 @@ class ApiRequest:
|
||||
'''
|
||||
data = {
|
||||
"query": query,
|
||||
"conversation_id": conversation_id,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"model_name": model,
|
||||
@ -978,7 +980,7 @@ class ApiRequest:
|
||||
|
||||
def chat_feedback(
|
||||
self,
|
||||
chat_history_id: str,
|
||||
message_id: str,
|
||||
score: int,
|
||||
reason: str = "",
|
||||
) -> int:
|
||||
@ -986,7 +988,7 @@ class ApiRequest:
|
||||
反馈对话评价
|
||||
'''
|
||||
data = {
|
||||
"chat_history_id": chat_history_id,
|
||||
"message_id": message_id,
|
||||
"score": score,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user