# Release 0.2.6 (#1815), 2023-10-20
## 🛠 New Features

- Support for the Baichuan online model (@hzg0601 @liunux4odoo in #1623)
- Support for Azure OpenAI, Claude, and other models built into Langchain (@zRzRzRzRzRzRzR in #1808)
- Major updates to the Agent feature: more tools, switchable prompts, and knowledge base retrieval (@zRzRzRzRzRzRzR in #1626 #1666 #1785)
- Longer conversation history for 32k models (@zRzRzRzRzRzRzR in #1629 #1630)
- The `*_chat` endpoints now support a `max_tokens` parameter (see the sketch after this list) (@liunux4odoo in #1744)
- Separated the API and WebUI into an independent back end and front end (@liunux4odoo in #1772)
- Support for the Zilliz vector store (@zRzRzRzRzRzRzR in #1785)
- Support for the Metaphor search engine (@liunux4odoo in #1792)
- Support for p-tuning models (@hzg0601 in #1810)
- Updated and improved the documentation and Wiki (@imClumsyPanda @zRzRzRzRzRzRzR @glide-the in #1680 #1811)
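
A minimal sketch of the new `max_tokens` parameter and of the separated front end calling a remote API server through the request wrapper in this file; it assumes the file below is importable as `webui_pages.utils`, and the server address and model name are placeholders:

```python
# Assumes api.py is already running; address and model name are placeholders.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
for chunk in api.chat_chat(
    "你好",
    model="chatglm2-6b-32k",  # placeholder model name
    max_tokens=512,           # new in 0.2.6 (#1744)
):
    print(chunk, end="", flush=True)
```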

## 🐞 Bug Fixes

- Fixed bge-* models producing match scores greater than 1 (@zRzRzRzRzRzRzR in #1652)
- Fixed an error when the system proxy is empty (@glide-the in #1654)
- Fixed the `d == self.d` assert error when rebuilding a knowledge base (@liunux4odoo in #1766)
- Fixed incorrect conversation history messages (@liunux4odoo in #1801)
- Fixed a bug where OpenAI models could not be called (@zRzRzRzRzRzRzR in #1808)
- Fixed chat errors on Windows when BIND_HOST=0.0.0.0 (@hzg0601 in #1810)

# This file wraps requests to api.py so it can be reused by different web UIs.
# ApiRequest and AsyncApiRequest provide synchronous / asynchronous calling styles.
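# A minimal usage sketch (comments only; assumes the API service started by api.py
# is reachable at the default address):
#
#   api = ApiRequest()                        # synchronous wrapper
#   for chunk in api.chat_chat("hello"):
#       print(chunk, end="", flush=True)
#
#   aapi = AsyncApiRequest()                  # asynchronous wrapper
#   data = await aapi.list_knowledge_bases()  # awaited inside an async context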
from typing import *
from pathlib import Path
# The configs imported here are those on the requesting machine (e.g. the WebUI host),
# mainly used to set front-end defaults. In a distributed deployment they may differ
# from the configs on the server.
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
LLM_MODEL,
TEMPERATURE,
SCORE_THRESHOLD,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)
import httpx
from server.chat.openai_chat import OpenAiChatMsgIn
import contextlib
import json
import os
from io import BytesIO
from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
from pprint import pprint
set_httpx_config()
class ApiRequest:
'''
    Wrapper around api.py calls (synchronous mode) to simplify making API requests.
'''
def __init__(
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
):
self.base_url = base_url
self.timeout = timeout
self._use_async = False
self._client = None
@property
def client(self):
if self._client is None or self._client.is_closed:
self._client = get_httpx_client(base_url=self.base_url,
use_async=self._use_async,
timeout=self.timeout)
return self._client
def get(
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
if stream:
return self.client.stream("GET", url, params=params, **kwargs)
else:
return self.client.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def post(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
if stream:
return self.client.stream("POST", url, data=data, json=json, **kwargs)
else:
return self.client.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def delete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
if stream:
return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return self.client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def _httpx_stream2generator(
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
):
'''
        Convert the GeneratorContextManager returned by httpx.stream into a plain generator.
'''
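        # Usage note (sketch): callers iterate the returned generator directly.
        # With as_json=False each chunk is a str; with as_json=True each chunk is a
        # parsed dict, or {"code": 500, "msg": ...} when a connection error occurs.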
async def ret_async(response, as_json):
try:
async with response as r:
async for chunk in r.aiter_text(None):
                        if not chunk:  # the fastchat api yields empty bytes at the start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg)
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({e}"
logger.error(msg)
yield {"code": 500, "msg": msg}
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
def ret_sync(response, as_json):
try:
with response as r:
for chunk in r.iter_text(None):
                        if not chunk:  # the fastchat api yields empty bytes at the start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg)
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({e}"
logger.error(msg)
yield {"code": 500, "msg": msg}
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
if self._use_async:
return ret_async(response, as_json)
else:
return ret_sync(response, as_json)
def _get_response_value(
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
):
'''
        Process the response returned by a synchronous or asynchronous request.
        `as_json`: return the parsed JSON body
        `value_func`: optional callable applied to the response (or its JSON) to customize the return value
'''
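        # Example (mirroring how it is used elsewhere in this file):
        #     self._get_response_value(response, as_json=True,
        #                              value_func=lambda r: r.get("data", []))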
def to_json(r):
try:
return r.json()
except Exception as e:
msg = "API未能返回正确的JSON。" + str(e)
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg}
if value_func is None:
value_func = (lambda r: r)
async def ret_async(response):
if as_json:
return value_func(to_json(await response))
else:
return value_func(await response)
if self._use_async:
return ret_async(response)
else:
if as_json:
return value_func(to_json(response))
else:
return value_func(response)
    # Server information
def get_server_configs(self, **kwargs):
response = self.post("/server/configs", **kwargs)
        return self._get_response_value(response, as_json=True)
    # Chat-related operations
def chat_fastchat(
self,
messages: List[Dict],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
**kwargs: Any,
):
'''
        Corresponds to the api.py /chat/fastchat endpoint.
'''
msg = OpenAiChatMsgIn(**{
"messages": messages,
"stream": stream,
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs,
})
data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:")
pprint(data)
response = self.post(
"/chat/fastchat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response)
def chat_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
**kwargs,
):
'''
        Corresponds to the api.py /chat/chat endpoint.  # TODO: consider whether to return JSON
'''
data = {
"query": query,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response)
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
        Corresponds to the api.py /chat/agent_chat endpoint.
'''
data = {
"query": query,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
        Corresponds to the api.py /chat/knowledge_base_chat endpoint.
'''
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
response = self.post(
"/chat/knowledge_base_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
        Corresponds to the api.py /chat/search_engine_chat endpoint.
'''
data = {
"query": query,
"search_engine_name": search_engine_name,
"top_k": top_k,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
response = self.post(
"/chat/search_engine_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)
    # Knowledge-base-related operations
def list_knowledge_bases(
self,
):
'''
        Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint.
'''
response = self.get("/knowledge_base/list_knowledge_bases")
return self._get_response_value(response,
as_json=True,
value_func=lambda r: r.get("data", []))
def create_knowledge_base(
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
):
'''
        Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint.
'''
data = {
"knowledge_base_name": knowledge_base_name,
"vector_store_type": vector_store_type,
"embed_model": embed_model,
}
response = self.post(
"/knowledge_base/create_knowledge_base",
json=data,
)
return self._get_response_value(response, as_json=True)
def delete_knowledge_base(
self,
knowledge_base_name: str,
):
'''
        Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint.
'''
response = self.post(
"/knowledge_base/delete_knowledge_base",
json=f"{knowledge_base_name}",
)
return self._get_response_value(response, as_json=True)
def list_kb_docs(
self,
knowledge_base_name: str,
):
'''
        Corresponds to the api.py /knowledge_base/list_files endpoint.
'''
response = self.get(
"/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name}
)
return self._get_response_value(response,
as_json=True,
value_func=lambda r: r.get("data", []))
def search_kb_docs(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
) -> List:
'''
        Corresponds to the api.py /knowledge_base/search_docs endpoint.
'''
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
}
response = self.post(
"/knowledge_base/search_docs",
json=data,
)
return self._get_response_value(response, as_json=True)
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
        Corresponds to the api.py /knowledge_base/upload_docs endpoint.
'''
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/upload_docs",
data=data,
files=[("files", (filename, file)) for filename, file in files],
)
return self._get_response_value(response, as_json=True)
def delete_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
):
'''
        Corresponds to the api.py /knowledge_base/delete_docs endpoint.
'''
data = {
"knowledge_base_name": knowledge_base_name,
"file_names": file_names,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
response = self.post(
"/knowledge_base/delete_docs",
json=data,
)
return self._get_response_value(response, as_json=True)
    def update_kb_info(self, knowledge_base_name, kb_info):
'''
        Corresponds to the api.py /knowledge_base/update_info endpoint.
'''
data = {
"knowledge_base_name": knowledge_base_name,
"kb_info": kb_info,
}
response = self.post(
"/knowledge_base/update_info",
json=data,
)
return self._get_response_value(response, as_json=True)
def update_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
        Corresponds to the api.py /knowledge_base/update_docs endpoint.
'''
data = {
"knowledge_base_name": knowledge_base_name,
"file_names": file_names,
"override_custom_docs": override_custom_docs,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/update_docs",
json=data,
)
return self._get_response_value(response, as_json=True)
def recreate_vector_store(
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
        Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint.
'''
data = {
"knowledge_base_name": knowledge_base_name,
"allow_empty_kb": allow_empty_kb,
"vs_type": vs_type,
"embed_model": embed_model,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
}
response = self.post(
"/knowledge_base/recreate_vector_store",
json=data,
stream=True,
timeout=None,
)
return self._httpx_stream2generator(response, as_json=True)
    # LLM-model-related operations
def list_running_models(
self,
controller_address: str = None,
):
'''
        Get the list of models currently running in FastChat.
'''
data = {
"controller_address": controller_address,
}
response = self.post(
"/llm_model/list_running_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
def list_config_models(self) -> Dict[str, List[str]]:
'''
        Get the models configured in the server's configs; the return format is {"type": [model_name1, model_name2, ...], ...}.
'''
response = self.post(
"/llm_model/list_config_models",
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
def get_model_config(
self,
model_name: str = None,
) -> Dict:
'''
        Get a model's configuration from the server.
'''
data={
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
def list_search_engines(self) -> List[str]:
'''
        Get the search engines supported by the server.
'''
response = self.post(
"/server/list_search_engines",
)
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
        Stop an LLM model.
        Note: because of how FastChat is implemented, this actually stops the model_worker hosting the model.
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/stop",
json=data,
)
return self._get_response_value(response, as_json=True)
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
        Ask the FastChat controller to switch the LLM model.
'''
if not model_name or not new_model_name:
return {
"code": 500,
"msg": f"未指定模型名称"
}
def ret_sync():
running_models = self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "无需切换"
}
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models.get("local", {}):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/change",
json=data,
)
return self._get_response_value(response, as_json=True)
async def ret_async():
running_models = await self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "无需切换"
}
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = await self.list_config_models()
if new_model_name not in config_models.get("local", {}):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/change",
json=data,
)
return self._get_response_value(response, as_json=True)
if self._use_async:
return ret_async()
else:
return ret_sync()
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
super().__init__(base_url, timeout)
self._use_async = True
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
    Return the error message if an error occurred while requesting the API.
'''
if isinstance(data, dict):
if key in data:
return data[key]
if "code" in data and data["code"] != 200:
return data["msg"]
return ""
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
'''
    Return the success message if the API request succeeded.
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""
if __name__ == "__main__":
api = ApiRequest()
aapi = AsyncApiRequest()
# print(api.chat_fastchat(
# messages=[{"role": "user", "content": "hello"}]
# ))
# with api.chat_chat("你好") as r:
# for t in r.iter_text(None):
# print(t)
# r = api.chat_chat("你好", no_remote_api=True)
# for t in r:
# print(t)
# r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True)
# for t in r:
# print(t)
# print(api.list_knowledge_bases())
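    # A few more sketches of the wrappers defined above; the knowledge base name,
    # file path and model names below are placeholders, not values guaranteed to
    # exist in a given deployment:
    # print(api.create_knowledge_base("samples"))
    # print(api.upload_kb_docs(["./README.md"], knowledge_base_name="samples"))
    # print(api.list_running_models())
    # print(api.change_llm_model("chatglm2-6b", "chatglm2-6b-32k"))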