Langchain-Chatchat/tests/api/test_stream_chat_api.py
liunux4odoo b3c7f8b072
修复webui中重建知识库以及对话界面UI错误 (#1615)
* 修复bug:webui点重建知识库时,如果存在不支持的文件会导致整个接口错误;migrate中没有导入CHUNK_SIZE

* 修复:webui对话界面的expander一直为running状态;简化历史消息获取方法
2023-09-28 15:12:03 +08:00

126 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
# Base URL of the API server, as returned by server.utils.api_address().
# Every test below appends its endpoint path to this prefix.
api_base_url = api_address()
def dump_input(d, title):
    """Pretty-print a request payload beneath a banner line.

    Args:
        d: Object to display; it is passed straight to ``pprint``.
        title: Text embedded in the banner. It is concatenated with plain
            strings, so it has to be a ``str``.
    """
    rule = "=" * 30
    print("\n")
    print(rule + title + " input " + rule)
    pprint(d)
def dump_output(r, title):
    """Echo a streamed response body to stdout beneath a banner line.

    Args:
        r: Response-like object exposing ``iter_content``; it is called with
            ``None`` as the chunk size and ``decode_unicode=True``.
        title: Text embedded in the banner; must be a ``str``.
    """
    rule = "=" * 30
    print("\n")
    print(rule + title + " output" + rule)
    pieces = r.iter_content(None, decode_unicode=True)
    for piece in pieces:
        # No newline and an explicit flush so each chunk shows up as it arrives.
        print(piece, end="", flush=True)
# HTTP headers sent with the JSON POST requests in the chat tests below.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
}
# Default chat request body shared by the tests: a query, a short two-turn
# history, streaming enabled, and a sampling temperature.
data = {
    "query": "请用100字左右的文字介绍自己",
    "history": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好,我是人工智能大模型"},
    ],
    "stream": True,
    "temperature": 0.7,
}
def test_chat_fastchat(api="/chat/fastchat"):
    """POST a messages-style chat request and expect HTTP 200.

    The request reuses the shared history and appends one extra user turn.
    The streamed reply is echoed to stdout before the status is checked.
    """
    endpoint = f"{api_base_url}{api}"
    follow_up = {"role": "user", "content": "推荐一部科幻电影"}
    payload = {
        "stream": True,
        "messages": data["history"] + [follow_up],
    }
    dump_input(payload, api)
    resp = requests.post(endpoint, headers=headers, json=payload, stream=True)
    dump_output(resp, api)
    assert resp.status_code == 200
def test_chat_chat(api="/chat/chat"):
    """POST the shared chat payload and expect HTTP 200.

    The payload and the streamed reply are both echoed to stdout.
    """
    endpoint = f"{api_base_url}{api}"
    dump_input(data, api)
    resp = requests.post(
        endpoint,
        headers=headers,
        json=data,
        stream=True,
    )
    dump_output(resp, api)
    assert resp.status_code == 200
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
    """Query the knowledge-base chat endpoint and check the streamed reply.

    Each streamed chunk is parsed as JSON and any "answer" text is printed as
    it arrives. After the stream ends, the last parsed object must carry a
    non-empty "docs" list and the HTTP status must be 200.
    """
    endpoint = f"{api_base_url}{api}"
    payload = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好,我是 ChatGLM"},
        ],
        "stream": True,
    }
    dump_input(payload, api)
    resp = requests.post(endpoint, headers=headers, json=payload, stream=True)

    rule = "=" * 30
    print("\n")
    print(rule + api + " output" + rule)

    # Starts out as the request body: if the stream yields nothing, the
    # "docs" assertion below inspects the request and fails.
    last = payload
    for raw in resp.iter_content(None, decode_unicode=True):
        last = json.loads(raw)
        if "answer" in last:
            print(last["answer"], end="", flush=True)

    assert "docs" in last and len(last["docs"]) > 0
    pprint(last["docs"])
    assert resp.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
    """Exercise the search-engine chat endpoint once per search engine.

    For every engine a fresh request body is derived from the shared ``data``
    fixture, so the fixture itself is never mutated or rebound. Rebinding it
    to a response object would corrupt the request sent for the next engine
    and leak into other tests.

    When Bing is selected but ``BING_SUBSCRIPTION_KEY`` is not configured,
    the reply is expected to be a single JSON error object. That case is
    verified and the loop moves on, because there is no answer stream to
    parse. Otherwise the streamed chunks are parsed as JSON, "answer" text is
    printed as it arrives, and the last chunk must carry a non-empty "docs"
    list.
    """
    url = f"{api_base_url}{api}"
    for se in ["bing", "duckduckgo"]:
        # Per-engine copy; key order matches the fixture with the engine
        # name appended last.
        payload = {
            **data,
            "query": "室温超导最新进展是什么样?",
            "search_engine_name": se,
        }
        dump_input(payload, api + f" by {se}")
        response = requests.post(url, json=payload, stream=True)

        if se == "bing" and not BING_SUBSCRIPTION_KEY:
            error = response.json()
            assert error["code"] == 404
            assert error["msg"] == "要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
            # The error body is not an answer stream; skip the stream checks.
            continue

        assert response.status_code == 200

        print("\n")
        print("=" * 30 + api + f" by {se} output" + "="*30)
        # Stays empty if the stream yields nothing, so the "docs" assertion
        # below fails instead of raising a NameError.
        chunk = {}
        for line in response.iter_content(None, decode_unicode=True):
            chunk = json.loads(line)
            if "answer" in chunk:
                print(chunk["answer"], end="", flush=True)
        assert "docs" in chunk and len(chunk["docs"]) > 0
        pprint(chunk["docs"])