Langchain-Chatchat/server/chat/search_engine_chat.py
liunux4odoo b4c68ddd05
优化在线 API ,支持 completion 和 embedding,简化在线 API 开发方式 (#1886)
* 优化在线 API ,支持 completion 和 embedding,简化在线 API 开发方式

新功能
- 智谱AI、Minimax、千帆、千问 4 个在线模型支持 embeddings(不通过Fastchat,后续会单独提供相关api接口)
- 在线模型自动检测传入参数,在传入非 messages 格式的 prompt 时,自动转换为 completion 形式,以支持 completion 接口

开发者:
- 重构ApiModelWorker:
  - 所有在线 API 请求封装到 do_chat 方法:自动传入参数 ApiChatParams,简化参数与配置项的获取;自动处理与fastchat的接口
  - 加强 API 请求错误处理,返回更有意义的信息
  - 改用 qianfan sdk 重写 qianfan-api
  - 将所有在线模型的测试用例统一在一起,简化测试用例编写

* Delete requirements_langflow.txt
2023-10-26 22:44:48 +08:00

200 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
import json
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Search Bing for *text* and return up to *result_len* result dicts.

    When either ``BING_SEARCH_URL`` or ``BING_SUBSCRIPTION_KEY`` is unset, a
    single placeholder result describing the missing configuration is
    returned instead of calling the API.

    Extra keyword arguments are accepted and ignored, so this function can be
    called with the same keywords as the other search-engine functions.
    """
    bing_is_configured = BING_SEARCH_URL and BING_SUBSCRIPTION_KEY
    if not bing_is_configured:
        missing_config_notice = {
            "snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
            "title": "env info is not found",
            "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html",
        }
        return [missing_config_notice]

    wrapper = BingSearchAPIWrapper(
        bing_subscription_key=BING_SUBSCRIPTION_KEY,
        bing_search_url=BING_SEARCH_URL,
    )
    return wrapper.results(text, result_len)
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    """Search DuckDuckGo for *text* and return up to *result_len* results.

    Extra keyword arguments are accepted and ignored, so this function can be
    called with the same keywords as the other search-engine functions.
    """
    return DuckDuckGoSearchAPIWrapper().results(text, result_len)
def metaphor_search(
    text: str,
    result_len: int = SEARCH_ENGINE_TOP_K,
    split_result: bool = False,
    chunk_size: int = 500,
    chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
    """Search Metaphor for *text* and return result dicts.

    Each returned dict has the keys ``"snippet"``, ``"link"`` and ``"title"``.

    Args:
        text: The query string.
        result_len: Number of results requested from Metaphor, and the
            maximum number of chunks kept when ``split_result`` is true.
        split_result: When true, every page extract is split into chunks and
            the chunks (not the whole pages) are returned.
        chunk_size: Chunk size handed to the text splitter.
        chunk_overlap: Chunk overlap handed to the text splitter.

    Returns:
        A list of result dicts; an empty list when ``METAPHOR_API_KEY`` is
        not configured.
    """
    # Check the configuration before importing the optional SDK, so that a
    # deployment without Metaphor configured does not need the package
    # installed just to get the "no results" answer.
    if not METAPHOR_API_KEY:
        return []

    from metaphor_python import Metaphor

    client = Metaphor(METAPHOR_API_KEY)
    search = client.search(text, num_results=result_len, use_autoprompt=True)
    contents = search.get_contents().contents
    # Convert each extract to Markdown in place.
    for page in contents:
        page.extract = markdownify(page.extract)

    if not split_result:
        return [
            {"snippet": page.extract, "link": page.url, "title": page.title}
            for page in contents
        ]

    # Metaphor returns long-form text, so split it into chunks and keep only
    # the chunks that best match the query.
    page_docs = [
        Document(
            page_content=page.extract,
            metadata={"link": page.url, "title": page.title},
        )
        for page in contents
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        ["\n\n", "\n", ".", " "],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = text_splitter.split_documents(page_docs)

    # More chunks than requested: score each chunk against the query with
    # normalized Levenshtein similarity and keep the top ``result_len``.
    if len(chunks) > result_len:
        scorer = NormalizedLevenshtein()
        for chunk in chunks:
            chunk.metadata["score"] = scorer.similarity(text, chunk.page_content)
        chunks.sort(key=lambda chunk: chunk.metadata["score"], reverse=True)
        chunks = chunks[:result_len]

    return [
        {
            "snippet": chunk.page_content,
            "link": chunk.metadata["link"],
            "title": chunk.metadata["title"],
        }
        for chunk in chunks
    ]
# Registry of supported search engines: engine name -> search function.
SEARCH_ENGINES = {
    "bing": bing_search,
    "duckduckgo": duckduckgo_search,
    "metaphor": metaphor_search,
}
def search_result2docs(search_results):
    """Convert search-engine result dicts into ``Document`` objects.

    For each result, ``"snippet"`` becomes the page content, ``"link"`` is
    stored as ``metadata["source"]`` and ``"title"`` as
    ``metadata["filename"]``. A missing key falls back to an empty string.
    """
    return [
        Document(
            page_content=hit.get("snippet", ""),
            metadata={
                "source": hit.get("link", ""),
                "filename": hit.get("title", ""),
            },
        )
        for hit in search_results
    ]
async def lookup_search_engine(
    query: str,
    search_engine_name: str,
    top_k: int = SEARCH_ENGINE_TOP_K,
    split_result: bool = False,
):
    """Run *query* through the named search engine and return documents.

    The engine function is looked up in ``SEARCH_ENGINES`` (an unknown name
    raises ``KeyError``) and executed via ``run_in_threadpool``; its results
    are then converted with ``search_result2docs``.
    """
    engine_fn = SEARCH_ENGINES[search_engine_name]
    raw_results = await run_in_threadpool(
        engine_fn,
        query,
        result_len=top_k,
        split_result=split_result,
    )
    return search_result2docs(raw_results)
async def search_engine_chat(
    query: str = Body(..., description="用户输入", examples=["你好"]),
    search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
    top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
    history: List[History] = Body(
        [],
        description="历史对话",
        examples=[[
            {"role": "user",
             "content": "我们来玩成语接龙,我先来,生龙活虎"},
            {"role": "assistant",
             "content": "虎头虎脑"}]],
    ),
    stream: bool = Body(False, description="流式输出"),
    model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
    split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎"),
):
    """Answer *query* with an LLM, using search-engine results as context.

    Returns a ``BaseResponse`` with code 404 when the engine name is not in
    ``SEARCH_ENGINES``, or when Bing is requested without
    ``BING_SUBSCRIPTION_KEY``. Otherwise returns a ``StreamingResponse``
    (media type ``text/event-stream``) whose chunks are JSON strings:

    * ``stream`` true: one ``{"answer": token}`` chunk per generated token,
      followed by a final ``{"docs": [...]}`` chunk.
    * ``stream`` false: a single ``{"answer": ..., "docs": [...]}`` chunk.
    """
    if search_engine_name not in SEARCH_ENGINES:
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

    if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
        return BaseResponse(code=404, msg=f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`")

    # Convert the incoming history entries with History.from_data.
    history = list(map(History.from_data, history))

    async def search_engine_chat_iterator(query: str,
                                          search_engine_name: str,
                                          top_k: int,
                                          history: Optional[List[History]],
                                          model_name: str = LLM_MODEL,
                                          prompt_name: str = prompt_name,
                                          ) -> AsyncIterable[str]:
        """Yield the JSON chunks that make up the response body."""
        token_handler = AsyncIteratorCallbackHandler()
        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[token_handler],
        )

        # Fetch search results and join their text into the prompt context.
        docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
        context = "\n".join(doc.page_content for doc in docs)

        # Prompt = converted history turns followed by the templated user turn.
        template_text = get_prompt_template("search_engine_chat", prompt_name)
        user_turn = History(role="user", content=template_text).to_msg_template(False)
        prompt_messages = [turn.to_msg_template() for turn in history]
        prompt_messages.append(user_turn)
        chat_prompt = ChatPromptTemplate.from_messages(prompt_messages)

        chain = LLMChain(prompt=chat_prompt, llm=llm)

        # Start generation in the background; tokens are consumed below
        # through the callback handler's async iterator.
        generation = chain.acall({"context": context, "question": query})
        task = asyncio.create_task(wrap_done(generation, token_handler.done))

        source_documents = [
            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
            for inum, doc in enumerate(docs)
        ]

        if stream:
            # Emit every token as its own JSON chunk, then the sources.
            async for token in token_handler.aiter():
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
        else:
            # Collect the full answer first and emit one combined chunk.
            pieces = [token async for token in token_handler.aiter()]
            payload = {"answer": "".join(pieces), "docs": source_documents}
            yield json.dumps(payload, ensure_ascii=False)

        await task

    chunk_stream = search_engine_chat_iterator(
        query=query,
        search_engine_name=search_engine_name,
        top_k=top_k,
        history=history,
        model_name=model_name,
        prompt_name=prompt_name,
    )
    return StreamingResponse(chunk_stream, media_type="text/event-stream")