mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-25 00:05:58 +08:00
* Beijing hackathon update. Knowledge base support: Zilliz database supported. Agent support: the following tool calls are supported: 1. internet Agent calls 2. knowledge base Agent calls 3. travel assistant tool (not yet uploaded). Knowledge base updates: 1. knowledge base descriptions, used for Agent selection 2. UI support for knowledge base descriptions. Prompt selection: 1. UI and template support for switching prompt templates
101 lines
5.5 KiB
Python
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs


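# FastAPI handler: answers a user query with retrieval-augmented generation
# over the selected knowledge base, optionally streaming the LLM output as
# server-sent events.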
async def knowledge_base_chat(query: str = Body(..., description="User input", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of matched vectors"),
                              score_threshold: float = Body(SCORE_THRESHOLD,
                                                            description="Knowledge base relevance threshold, in the range 0-1. A smaller score means higher relevance; 1 is equivalent to no filtering. Around 0.5 is recommended.",
                                                            ge=0, le=1),
                              history: List[History] = Body([],
                                                            description="Chat history",
                                                            examples=[[
                                                                {"role": "user",
                                                                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                                {"role": "assistant",
                                                                 "content": "虎头虎脑"}]]
                                                            ),
                              stream: bool = Body(False, description="Stream the output"),
                              model_name: str = Body(LLM_MODEL, description="LLM model name."),
                              temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                              max_tokens: int = Body(1024, description="Limit on the number of tokens the LLM may generate; currently defaults to 1024"),
                              # TODO: after the fastchat upgrade, default this to None and automatically use the maximum the LLM supports.
                              prompt_name: str = Body("default", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                              ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    history = [History.from_data(h) for h in history]

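    # Inner async generator: closes over stream, temperature, and max_tokens
    # from the enclosing request scope and yields the response as JSON chunks.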
    async def knowledge_base_chat_iterator(query: str,
                                           top_k: int,
                                           history: Optional[List[History]],
                                           model_name: str = LLM_MODEL,
                                           prompt_name: str = prompt_name,
                                           ) -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
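        # Retrieve the top_k most relevant document chunks and join them into
        # a single context string for the prompt.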
        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
        context = "\n".join([doc.page_content for doc in docs])

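        # The "knowledge_base_chat" prompt template is expected to expose
        # "context" and "question" input variables, which chain.acall() fills
        # in below.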
        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

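        # Build a markdown-formatted citation list, linking each source chunk
        # to the knowledge base's document download endpoint.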
        source_documents = []
        for inum, doc in enumerate(docs):
            filename = os.path.split(doc.metadata["source"])[-1]
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            url = "/knowledge_base/download_doc?" + parameters
            text = f"""Source [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)
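        # Two response modes: stream tokens one by one as SSE chunks, or
        # collect the full answer and yield a single JSON payload.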
        if stream:
            async for token in callback.aiter():
                # Use server-sent events to stream the response
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)

        await task

    return StreamingResponse(knowledge_base_chat_iterator(query=query,
                                                          top_k=top_k,
                                                          history=history,
                                                          model_name=model_name,
                                                          prompt_name=prompt_name),
                             media_type="text/event-stream")
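
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal client-side call to this endpoint. It assumes the API server is
# reachable at http://127.0.0.1:7861 and the handler is mounted at
# /chat/knowledge_base_chat (both assumptions about the deployment; adjust to
# match yours), and that a knowledge base named "samples" exists.
if __name__ == "__main__":
    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/knowledge_base_chat",  # hypothetical route
        json={
            "query": "你好",
            "knowledge_base_name": "samples",
            "stream": True,
        },
        stream=True,
    )
    # While streaming, each chunk is a JSON object like {"answer": token};
    # the final chunk carries {"docs": [...]} with the citation list.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk)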