qiankunli fa906b33a8
添加对话评分与历史消息保存功能 (#1940)
* 新功能:
- WEBUI 添加对话评分功能
- 增加 /chat/feedback 接口,用于接收对话评分
- /chat/chat 接口返回值由 str 改为 {"text":str, "chat_history_id": str}
- init_database.py 添加 --create-tables --clear-tables 参数

依赖:
- streamlit-chatbox==1.1.11

开发者:
- ChatHistoryModel 的 id 字段支持自动生成
- SAVE_CHAT_HISTORY 改到 basic_config.py

* 修复:点击反馈后页面未刷新

---------

Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
2023-11-03 11:31:45 +08:00

85 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.utils import get_prompt_template
from server.db.repository.chat_history_repository import add_chat_history_to_db, update_chat_history
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
               history: List[History] = Body([],
                                             description="历史对话",
                                             examples=[[
                                                 {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                 {"role": "assistant", "content": "虎头虎脑"}]]
                                             ),
               stream: bool = Body(False, description="流式输出"),
               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
               max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
               # top_p (nucleus sampling) is currently disabled; per the original note it
               # should not be set together with temperature.
               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
               ):
    """Answer ``query`` with the LLM and return the result as a streamed response.

    The body of the response is produced by an async generator. Every chunk is
    a JSON object with two keys: ``"text"`` and ``"chat_history_id"``.

    Args:
        query: The user's input.
        history: Earlier turns of the conversation; each item is converted
            with ``History.from_data``.
        stream: When true, one chunk is emitted per generated token. When
            false, a single chunk carrying the whole answer is emitted once
            generation has finished.
        model_name: Name of the LLM to use.
        temperature: Sampling temperature passed to the model.
        max_tokens: Token limit passed to the model (``None`` is forwarded
            unchanged).
        prompt_name: Name of the "llm_chat" prompt template to load.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    history = [History.from_data(h) for h in history]

    async def chat_iterator(query: str,
                            history: Optional[List[History]] = None,
                            model_name: str = LLM_MODEL,
                            prompt_name: str = prompt_name,
                            ) -> AsyncIterable[str]:
        """Yield JSON-encoded chunks of the model's answer to ``query``."""
        # A fresh list per call: a literal ``[]`` default would be shared
        # between invocations.
        history = [] if history is None else history

        callback = AsyncIteratorCallbackHandler()
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        prompt_template = get_prompt_template("llm_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}),
            callback.done),
        )

        try:
            # The record is created up front so that every chunk can carry
            # its id back to the caller.
            chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)

            tokens: List[str] = []
            async for token in callback.aiter():
                tokens.append(token)
                if stream:
                    # Use server-sent-events to stream the response
                    yield json.dumps(
                        {"text": token, "chat_history_id": chat_history_id},
                        ensure_ascii=False)
            answer = "".join(tokens)

            if not stream:
                # Non-streaming mode: emit the complete answer exactly once.
                yield json.dumps(
                    {"text": answer, "chat_history_id": chat_history_id},
                    ensure_ascii=False)

            # Truthiness instead of len(): also tolerates a None id.
            if SAVE_CHAT_HISTORY and chat_history_id:
                # Additional information, such as the actual prompt, could be
                # stored here later.
                update_chat_history(chat_history_id, response=answer)

            await task
        finally:
            # If the consumer stopped iterating early (or an error escaped),
            # the background task has not been awaited; cancel it so it is
            # not left running unattended.
            if not task.done():
                task.cancel()

    return StreamingResponse(chat_iterator(query=query,
                                           history=history,
                                           model_name=model_name,
                                           prompt_name=prompt_name),
                             media_type="text/event-stream")