diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example
index a857e80b..40165b96 100644
--- a/configs/kb_config.py.example
+++ b/configs/kb_config.py.example
@@ -1,5 +1,7 @@
import os
+# Default knowledge base to use
+DEFAULT_KNOWLEDGE_BASE = "samples"
# Default vector store type. Options: faiss, milvus (offline) & zilliz (online), pg.
DEFAULT_VS_TYPE = "faiss"
@@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3
# Relevance score threshold for knowledge base matching, in the range 0-1. A smaller SCORE means higher relevance; a value of 1 is equivalent to no filtering. Around 0.5 is recommended.
SCORE_THRESHOLD = 1
+# Default search engine. Options: bing, duckduckgo, metaphor
+DEFAULT_SEARCH_ENGINE = "duckduckgo"
+
# Number of search engine results to match
SEARCH_ENGINE_TOP_K = 3
@@ -36,6 +41,10 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# it is because the server is behind a firewall; you need to ask the administrator to add it to the whitelist. If it is a company server, do not count on it.
BING_SUBSCRIPTION_KEY = ""
+# metaphor search requires an API KEY
+METAPHOR_API_KEY = ""
+
+
# Whether to enable Chinese title enhancement, plus its related settings
# Adds a title-detection step to decide which texts are titles and marks them in the metadata;
# then concatenates each text with the title one level above it to enrich the text information.
@@ -47,7 +56,6 @@ KB_INFO = {
"知识库名称": "知识库介绍",
"samples": "关于本项目issue的解答",
}
-
# Normally the following does not need to be changed
# Default storage path for knowledge bases
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 0111fc5d..78a10e9a 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -44,7 +44,7 @@ MODEL_PATH = {
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
- "baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
+ "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
@@ -112,7 +112,8 @@ TEMPERATURE = 0.7
# TOP_P = 0.95 # this parameter is not yet supported by ChatOpenAI
-ONLINE_LLM_MODEL = {
+LANGCHAIN_LLM_MODEL = {
+    # Models supported directly by Langchain that do not need to go through the Fschat wrapper.
    # If calling chatgpt raises: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
    # then you need to change the urllib3 version to 1.25.11
@@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = {
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
    # you need to add proxy access (normally running proxy software may fail to intercept this traffic); set the openai_proxy option or the OPENAI_PROXY environment variable
    # e.g.: "openai_proxy": 'http://127.0.0.1:4780'
- "gpt-3.5-turbo": {
+
+    # The names of these entries must not be changed
+ "Azure-OpenAI": {
+ "deployment_name": "your Azure deployment name",
+ "model_version": "0701",
+ "openai_api_type": "azure",
+ "api_base_url": "https://your Azure point.azure.com",
+ "api_version": "2023-07-01-preview",
+ "api_key": "your Azure api key",
+ "openai_proxy": "",
+ },
+ "OpenAI": {
+ "model_name": "your openai model name(such as gpt-4)",
"api_base_url": "https://api.openai.com/v1",
"api_key": "your OPENAI_API_KEY",
- "openai_proxy": "your OPENAI_PROXY",
+ "openai_proxy": "",
},
+ "Anthropic": {
+ "model_name": "your claude model name(such as claude2-100k)",
+ "api_key":"your ANTHROPIC_API_KEY",
+ }
+}
+ONLINE_LLM_MODEL = {
    # Online models. Set a different port for each online API in server_config
    # For registration details and API key acquisition, visit http://open.bigmodel.cn
"zhipu-api": {
diff --git a/init_database.py b/init_database.py
index 9e807a8e..c2cd1d49 100644
--- a/init_database.py
+++ b/init_database.py
@@ -1,3 +1,5 @@
+import sys
+sys.path.append(".")
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
import nltk
diff --git a/requirements.txt b/requirements.txt
index 76312473..3da9c845 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,15 @@
-langchain>=0.0.314
+langchain>=0.0.319
langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
-openai
+fschat[model_worker]==0.2.31
+xformers>=0.0.22.post4
+openai>=0.28.1
sentence_transformers
transformers>=4.34
torch>=2.0.1  # 2.1 recommended
torchvision
torchaudio
-fastapi>=0.103.2
-nltk~=3.8.1
+fastapi>=0.104
+nltk>=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
@@ -43,7 +44,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
diff --git a/requirements_api.txt b/requirements_api.txt
index af4e7e08..b8a7f6d4 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,13 +1,14 @@
-langchain==0.0.313
-langchain-experimental==0.0.30
-fschat[model_worker]==0.2.30
-openai
+langchain>=0.0.319
+langchain-experimental>=0.0.30
+fschat[model_worker]==0.2.31
+xformers>=0.0.22.post4
+openai>=0.28.1
sentence_transformers>=2.2.2
transformers>=4.34
-torch>=2.0.1
+torch>=2.1
torchvision
torchaudio
-fastapi>=0.103.1
+fastapi>=0.104
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
diff --git a/requirements_webui.txt b/requirements_webui.txt
index 9caf085a..a36fec49 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -1,11 +1,11 @@
numpy~=1.24.4
pandas~=2.0.3
-streamlit>=1.26.0
+streamlit>=1.27.2
streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-antd-components>=0.2.3
+streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
-nltk
+httpx>=0.25.0
+nltk>=3.8.1
watchdog
websockets
diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py
index 3a82b9c7..49ce9730 100644
--- a/server/agent/callbacks.py
+++ b/server/agent/callbacks.py
@@ -97,8 +97,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
-
- async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any,
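+    # accept the full argument set that langchain's callback manager passes to on_chat_model_start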
+ async def on_chat_model_start(
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List],
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
) -> None:
self.cur_tool.update(
status=Status.start,
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
index 22469c6b..fdac6e21 100644
--- a/server/agent/custom_template.py
+++ b/server/agent/custom_template.py
@@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from server.agent import model_container
-begin = False
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
@@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
# Check if agent should finish
- support_agent = ["gpt","Qwen","qwen-api","baichuan-api"]
+        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
self.begin = False
stop_words = ["Observation:"]
diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py
index 8bb5cac6..7031b71b 100644
--- a/server/agent/tools/__init__.py
+++ b/server/agent/tools/__init__.py
@@ -2,7 +2,6 @@
from .search_knowledge_simple import knowledge_search_simple
from .search_all_knowledge_once import knowledge_search_once
from .search_all_knowledge_more import knowledge_search_more
-from .travel_assistant import travel_assistant
from .calculate import calculate
from .translator import translate
from .weather import weathercheck
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index c78add68..5a71478b 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-    # TODO: after the fastchat update, set the default to None so the LLM's supported maximum is used automatically.
+ max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 3ec68558..4402185a 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-    # TODO: after the fastchat update, set the default to None so the LLM's supported maximum is used automatically.
+ max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index c39b147e..19ca871f 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-    # TODO: after the fastchat update, set the default to None so the LLM's supported maximum is used automatically.
+ max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index 4a46ddd9..7efb0a8b 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
messages: List[OpenAiMessage]
temperature: float = 0.7
n: int = 1
- max_tokens: int = 1024
+ max_tokens: int = None
stop: List[str] = []
stream: bool = False
presence_penalty: int = 0
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 83ed65e4..e1ccaa48 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,6 +1,7 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
- LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
+ LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
+ TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
@@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
import json
@@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
return search.results(text, result_len)
+def metaphor_search(
+ text: str,
+ result_len: int = SEARCH_ENGINE_TOP_K,
+ splitter_name: str = "SpacyTextSplitter",
+ chunk_size: int = 500,
+ chunk_overlap: int = OVERLAP_SIZE,
+) -> List[Dict]:
+ from metaphor_python import Metaphor
+ from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
+ from server.knowledge_base.utils import make_text_splitter
+
+ if not METAPHOR_API_KEY:
+ return []
+
+ client = Metaphor(METAPHOR_API_KEY)
+ search = client.search(text, num_results=result_len, use_autoprompt=True)
+ contents = search.get_contents().contents
+
+    # metaphor returns long texts, which need to be split before retrieval
+ docs = [Document(page_content=x.extract,
+ metadata={"link": x.url, "title": x.title})
+ for x in contents]
+ text_splitter = make_text_splitter(splitter_name=splitter_name,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap)
+ splitted_docs = text_splitter.split_documents(docs)
+
+    # put the split documents into a temporary vector store and re-select the TOP_K documents
+ if len(splitted_docs) > result_len:
+ vs = memo_faiss_pool.new_vector_store()
+ vs.add_documents(splitted_docs)
+ splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
+
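+    # convert back to the snippet/link/title dicts returned by the other search functions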
+ docs = [{"snippet": x.page_content,
+ "link": x.metadata["link"],
+ "title": x.metadata["title"]}
+ for x in splitted_docs]
+ return docs
+
+
SEARCH_ENGINES = {"bing": bing_search,
"duckduckgo": duckduckgo_search,
+ "metaphor": metaphor_search,
}
@@ -72,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-    # TODO: after the fastchat update, set the default to None so the LLM's supported maximum is used automatically.
+ max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index 801e4a6b..c00ac4f6 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -140,7 +140,7 @@ if __name__ == "__main__":
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
- docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
+ docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 02212c88..c73d0219 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,7 +1,5 @@
import os
-
from transformers import AutoTokenizer
-
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
diff --git a/server/llm_api.py b/server/llm_api.py
index dc9ddced..453f1473 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,5 +1,5 @@
from fastapi import Body
-from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
@@ -16,7 +16,7 @@ def list_running_models(
with get_httpx_client() as client:
r = client.post(controller_address + "/list_models")
models = r.json()["models"]
- data = {m: get_model_worker_config(m) for m in models}
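+        # get_model_config returns the merged worker config with sensitive keys stripped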
+ data = {m: get_model_config(m).data for m in models}
return BaseResponse(data=data)
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
@@ -52,7 +52,6 @@ def get_model_config(
    Get the merged configuration items for an LLM model
'''
config = get_model_worker_config(model_name=model_name)
-
    # remove sensitive information from the ONLINE_MODEL config
del_keys = set(["worker_class"])
for k in config:
diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py
index e1dce6a0..c4e090e8 100644
--- a/server/model_workers/SparkApi.py
+++ b/server/model_workers/SparkApi.py
@@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
"chat": {
"domain": domain,
"random_threshold": 0.5,
- "max_tokens": 2048,
+ "max_tokens": None,
"auditing": "default",
"temperature": temperature,
}
diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py
index 879a51ec..1c6a6f1d 100644
--- a/server/model_workers/baichuan.py
+++ b/server/model_workers/baichuan.py
@@ -1,15 +1,15 @@
# import os
# import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-import requests
import json
import time
import hashlib
from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
import sys
import json
-from typing import List, Literal
+from typing import List, Literal, Dict
from configs import TEMPERATURE
@@ -20,29 +20,29 @@ def calculate_md5(input_string):
return encrypted
-def do_request():
- url = "https://api.baichuan-ai.com/v1/stream/chat"
- api_key = ""
- secret_key = ""
+def request_baichuan_api(
+ messages: List[Dict[str, str]],
+ api_key: str = None,
+ secret_key: str = None,
+ version: str = "Baichuan2-53B",
+ temperature: float = TEMPERATURE,
+ model_name: str = "baichuan-api",
+):
+ config = get_model_worker_config(model_name)
+ api_key = api_key or config.get("api_key")
+ secret_key = secret_key or config.get("secret_key")
+ version = version or config.get("version")
+ url = "https://api.baichuan-ai.com/v1/stream/chat"
data = {
- "model": "Baichuan2-53B",
- "messages": [
- {
- "role": "user",
- "content": "世界第一高峰是"
- }
- ],
- "parameters": {
- "temperature": 0.1,
- "top_k": 10
- }
+ "model": version,
+ "messages": messages,
+ "parameters": {"temperature": temperature}
}
json_data = json.dumps(data)
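+    # the API authenticates each request with an MD5 signature over secret_key + request body + timestamp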
time_stamp = int(time.time())
signature = calculate_md5(secret_key + json_data + str(time_stamp))
-
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key,
@@ -52,18 +52,17 @@ def do_request():
"X-BC-Sign-Algo": "MD5",
}
- response = requests.post(url, data=json_data, headers=headers)
-
- if response.status_code == 200:
- print("请求成功!")
- print("响应header:", response.headers)
- print("响应body:", response.text)
- else:
- print("请求失败,状态码:", response.status_code)
+ with get_httpx_client() as client:
+ with client.stream("POST", url, headers=headers, json=data) as response:
+ for line in response.iter_lines():
+ if not line.strip():
+ continue
+ resp = json.loads(line)
+ yield resp
class BaiChuanWorker(ApiModelWorker):
- BASE_URL = "https://api.baichuan-ai.com/v1/chat"
+ BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
SUPPORT_MODELS = ["Baichuan2-53B"]
def __init__(
@@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker):
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params):
- data = {
- "model": self.version,
- "messages": [
- {
- "role": "user",
- "content": params["prompt"]
- }
- ],
- "parameters": {
- "temperature": params.get("temperature",TEMPERATURE),
- "top_k": params.get("top_k",1)
- }
- }
+ super().generate_stream_gate(params)
- json_data = json.dumps(data)
- time_stamp = int(time.time())
- signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
- headers = {
- "Content-Type": "application/json",
- "Authorization": "Bearer " + self.api_key,
- "X-BC-Request-Id": "your requestId",
- "X-BC-Timestamp": str(time_stamp),
- "X-BC-Signature": signature,
- "X-BC-Sign-Algo": "MD5",
- }
+ messages = self.prompt_to_messages(params["prompt"])
- response = requests.post(self.BASE_URL, data=json_data, headers=headers)
+ text = ""
+ for resp in request_baichuan_api(messages=messages,
+ api_key=self.api_key,
+ secret_key=self.secret_key,
+ version=self.version,
+ temperature=params.get("temperature")):
+ if resp["code"] == 0:
+ text += resp["data"]["messages"][-1]["content"]
+ yield json.dumps(
+ {
+ "error_code": resp["code"],
+ "text": text
+ },
+ ensure_ascii=False
+ ).encode() + b"\0"
+ else:
+ yield json.dumps(
+ {
+ "error_code": resp["code"],
+ "text": resp["msg"]
+ },
+ ensure_ascii=False
+ ).encode() + b"\0"
- if response.status_code == 200:
- resp = eval(response.text)
- yield json.dumps(
- {
- "error_code": resp["code"],
- "text": resp["data"]["messages"][-1]["content"]
- },
- ensure_ascii=False
- ).encode() + b"\0"
- else:
- yield json.dumps(
- {
- "error_code": resp["code"],
- "text": resp["msg"]
- },
- ensure_ascii=False
- ).encode() + b"\0"
-
-
-
def get_embeddings(self, params):
        # TODO: support embeddings
print("embedding")
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index ea141046..2b39bd6b 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,13 +1,13 @@
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
-from fastchat.serve.model_worker import BaseModelWorker
+from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
-import threading
+import asyncio
from typing import Dict, List
@@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
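+        # limit how many requests this API worker handles concurrently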
+ self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.init_heart_beat()
def count_token(self, params):
@@ -62,15 +63,6 @@ class ApiModelWorker(BaseModelWorker):
print("embedding")
print(params)
- # workaround to make program exit with Ctrl+c
- # it should be deleted after pr is merged by fastchat
- def init_heart_beat(self):
- self.register_to_controller()
- self.heart_beat_thread = threading.Thread(
- target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
- )
- self.heart_beat_thread.start()
-
# help methods
def get_config(self):
from server.utils import get_model_worker_config
diff --git a/server/utils.py b/server/utils.py
index 2f6dfc49..5dff5f1b 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,12 +5,11 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
- MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
- logger, log_verbose,
+ MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -40,19 +39,64 @@ def get_ChatOpenAI(
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
- config = get_model_worker_config(model_name)
- model = ChatOpenAI(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- openai_api_key=config.get("api_key", "EMPTY"),
- openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- openai_proxy=config.get("openai_proxy"),
- **kwargs
- )
+    ## The following models are supported natively by Langchain and do not go through the Fschat wrapper
+ config_models = list_config_llm_models()
+ if model_name in config_models.get("langchain", {}):
+ config = config_models["langchain"][model_name]
+ if model_name == "Azure-OpenAI":
+ model = AzureChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ deployment_name=config.get("deployment_name"),
+ model_version=config.get("model_version"),
+ openai_api_type=config.get("openai_api_type"),
+ openai_api_base=config.get("api_base_url"),
+ openai_api_version=config.get("api_version"),
+ openai_api_key=config.get("api_key"),
+ openai_proxy=config.get("openai_proxy"),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+
+ elif model_name == "OpenAI":
+ model = ChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ model_name=config.get("model_name"),
+ openai_api_base=config.get("api_base_url"),
+ openai_api_key=config.get("api_key"),
+ openai_proxy=config.get("openai_proxy"),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ elif model_name == "Anthropic":
+ model = ChatAnthropic(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ model_name=config.get("model_name"),
+ anthropic_api_key=config.get("api_key"),
+
+ )
+        ## TODO: support other models natively supported by Langchain
+ else:
+        ## Models not natively supported by Langchain go through the Fschat wrapper
+ config = get_model_worker_config(model_name)
+ model = ChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ openai_api_key=config.get("api_key", "EMPTY"),
+ openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ openai_proxy=config.get("openai_proxy"),
+ **kwargs
+ )
+
return model
@@ -249,8 +293,9 @@ def MakeFastAPIOffline(
redoc_favicon_url=favicon,
)
+    # Get model information from model_config
+
-# Get model information from model_config
def list_embed_models() -> List[str]:
'''
get names of configured embedding models
@@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]:
workers = list(FSCHAT_MODEL_WORKERS)
if LLM_MODEL not in workers:
workers.insert(0, LLM_MODEL)
-
return {
"local": MODEL_PATH["llm_model"],
+ "langchain": LANGCHAIN_LLM_MODEL,
"online": ONLINE_LLM_MODEL,
"worker": workers,
}
@@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
return str(path)
return path_str # THUDM/chatglm06b
+    # Get service information from server_config
+
-# Get service information from server_config
def get_model_worker_config(model_name: str = None) -> dict:
'''
    Load the configuration items for a model worker.
@@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
    # Online model APIs
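+    # models configured in LANGCHAIN_LLM_MODEL are called through langchain directly and use no fschat worker class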
+ if model_name in LANGCHAIN_LLM_MODEL:
+ config["langchain_model"] = True
+ config["worker_class"] = ""
+
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
@@ -389,7 +439,7 @@ def webui_address() -> str:
return f"http://{host}:{port}"
-def get_prompt_template(type:str,name: str) -> Optional[str]:
+def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
    Load template content from prompt_config
    type: one of "llm_chat", "agent_chat", "knowledge_base_chat", "search_engine_chat"; new features should be added here.
@@ -459,8 +509,9 @@ def set_httpx_config(
import urllib.request
urllib.request.getproxies = _get_proxies
+    # Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM do not need torch installed
+
-# Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM do not need torch installed
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
import torch
@@ -568,6 +619,8 @@ def get_server_configs() -> Dict:
    Get the raw configuration items from configs, for use by the frontend
'''
from configs.kb_config import (
+ DEFAULT_KNOWLEDGE_BASE,
+ DEFAULT_SEARCH_ENGINE,
DEFAULT_VS_TYPE,
CHUNK_SIZE,
OVERLAP_SIZE,
diff --git a/startup.py b/startup.py
index 878e7eca..4bb34496 100644
--- a/startup.py
+++ b/startup.py
@@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
controller_address:
worker_address:
-
+    For models supported by Langchain:
+ langchain_model:True
+        fschat will not be used
    For online_api:
online_api:True
worker_class: `provider`
@@ -78,31 +80,34 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
"""
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
- from fastchat.serve.model_worker import worker_id, logger
import argparse
- logger.setLevel(log_level)
parser = argparse.ArgumentParser()
args = parser.parse_args([])
for k, v in kwargs.items():
setattr(args, k, v)
-
+    if worker_class := kwargs.get("langchain_model"):  # models supported by Langchain need no further setup here
+ from fastchat.serve.base_model_worker import app
+ worker = ""
    # Online model APIs
- if worker_class := kwargs.get("worker_class"):
- from fastchat.serve.model_worker import app
+ elif worker_class := kwargs.get("worker_class"):
+ from fastchat.serve.base_model_worker import app
+
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
- sys.modules["fastchat.serve.model_worker"].worker = worker
+ # sys.modules["fastchat.serve.base_model_worker"].worker = worker
+ sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
    # Local models
else:
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker
- from fastchat.serve.vllm_worker import VLLMWorker,app
+ from fastchat.serve.vllm_worker import VLLMWorker, app
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+
    args.tokenizer = args.model_path  # add here if the tokenizer is not the same as model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code= True
@@ -126,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.engine_use_ray = False
args.disable_log_requests = False
-    # parameters required since vllm 0.2.0
-    args.max_model_len = 8192  # maximum sequence length the model can handle; set it according to your model
+    # parameters required since vllm 0.2.0, but they are not needed here
+ args.max_model_len = None
args.revision = None
args.quantization = None
args.max_log_len = None
@@ -155,10 +160,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
conv_template = args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
- sys.modules["fastchat.serve.vllm_worker"].worker = worker
+ # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+ sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
else:
- from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+ from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+
    args.gpus = "0"  # GPU ids; with multiple GPUs this can be set to "0,1,2,3"
args.max_gpu_memory = "22GiB"
    args.num_gpus = 1  # the model worker is split with model parallelism; set the number of GPUs here
@@ -221,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-
- sys.modules["fastchat.serve.model_worker"].worker = worker
+ # sys.modules["fastchat.serve.model_worker"].worker = worker
+ sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
diff --git a/tests/online_api/test_baichuan.py b/tests/online_api/test_baichuan.py
new file mode 100644
index 00000000..536466ee
--- /dev/null
+++ b/tests/online_api/test_baichuan.py
@@ -0,0 +1,16 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.baichuan import request_baichuan_api
+from pprint import pprint
+
+
+def test_baichuan():
+ messages = [{"role": "user", "content": "hello"}]
+
+ for x in request_baichuan_api(messages):
+ print(type(x))
+ pprint(x)
+ assert x["code"] == 0
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 776d5e60..85d6cb40 100644
--- a/webui.py
+++ b/webui.py
@@ -21,13 +21,6 @@ if __name__ == "__main__":
}
)
- if not chat_box.chat_inited:
- running_models = api.list_running_models()
- st.toast(
- f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
- f"当前运行中的模型`{running_models}`, 您可以开始提问了."
- )
-
pages = {
"对话": {
"icon": "chat",
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index adc60293..8628aac8 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -3,7 +3,8 @@ from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
import os
-from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES
+from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+ DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
from typing import List, Dict
chat_box = ChatBox(
@@ -40,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
    Returns a tuple of (model_name, is_local_model)
'''
running_models = api.list_running_models()
-
if not running_models:
return "", False
@@ -50,12 +50,17 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
-
- return running_models[0], False
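+    # running_models is a dict keyed by model name, so take its first key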
+ return list(running_models)[0], False
def dialogue_page(api: ApiRequest):
- chat_box.init_session()
+ if not chat_box.chat_inited:
+ default_model = get_default_llm_model(api)[0]
+ st.toast(
+ f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
+ f"当前运行的模型`{default_model}`, 您可以开始提问了."
+ )
+ chat_box.init_session()
with st.sidebar:
        # TODO: bind the dialogue model to the session
@@ -74,16 +79,17 @@ def dialogue_page(api: ApiRequest):
"搜索引擎问答",
"自定义Agent问答",
],
- index=3,
+ index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
def on_llm_change():
- config = api.get_model_config(llm_model)
-        if not config.get("online_api"): # only a local model_worker can switch models
- st.session_state["prev_llm_model"] = llm_model
- st.session_state["cur_llm_model"] = st.session_state.llm_model
+ if llm_model:
+ config = api.get_model_config(llm_model)
+            if not config.get("online_api"):  # only a local model_worker can switch models
+ st.session_state["prev_llm_model"] = llm_model
+ st.session_state["cur_llm_model"] = st.session_state.llm_model
def llm_model_format_func(x):
if x in running_models:
@@ -91,16 +97,18 @@ def dialogue_page(api: ApiRequest):
return x
running_models = list(api.list_running_models())
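+    # langchain-native models are treated as always available, even without a running fschat worker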
+ running_models += LANGCHAIN_LLM_MODEL.keys()
available_models = []
config_models = api.list_config_models()
    worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
for m in worker_models:
if m not in running_models and m != "default":
available_models.append(m)
-    for k, v in config_models.get("online", {}).items(): # list models in ONLINE_MODELS that are accessed directly (e.g. GPT)
+    for k, v in config_models.get("online", {}).items():  # list models in ONLINE_MODELS that are accessed directly
if not v.get("provider") and k not in running_models:
- print(k, v)
available_models.append(k)
+    for k, v in config_models.get("langchain", {}).items():  # list models supported through LANGCHAIN_LLM_MODEL
+ available_models.append(k)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
llm_model = st.selectbox("选择LLM模型:",
@@ -111,7 +119,8 @@ def dialogue_page(api: ApiRequest):
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
- and not api.get_model_config(llm_model).get("online_api")
+ and not llm_model in config_models.get("online", {})
+ and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
@@ -156,9 +165,13 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases()
+ index = 0
+ if DEFAULT_KNOWLEDGE_BASE in kb_list:
+ index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
selected_kb = st.selectbox(
"请选择知识库:",
kb_list,
+ index=index,
on_change=on_kb_change,
key="selected_kb",
)
@@ -167,11 +180,15 @@ def dialogue_page(api: ApiRequest):
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
+ if DEFAULT_SEARCH_ENGINE in search_engine_list:
+ index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
+ else:
+ index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox(
label="请选择搜索引擎",
options=search_engine_list,
- index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
+ index=index,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
@@ -210,9 +227,9 @@ def dialogue_page(api: ApiRequest):
])
text = ""
ans = ""
-    support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
+    support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
if not any(agent in llm_model for agent in support_agent):
- ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
+ ans += "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)
for d in api.agent_chat(prompt,
history=history,
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 7b9e161c..8190dba1 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -245,7 +245,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
- max_tokens: int = 1024,
+ max_tokens: int = None,
**kwargs: Any,
):
'''
@@ -278,7 +278,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
- max_tokens: int = 1024,
+ max_tokens: int = None,
prompt_name: str = "default",
**kwargs,
):
@@ -308,7 +308,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
- max_tokens: int = 1024,
+ max_tokens: int = None,
prompt_name: str = "default",
):
'''
@@ -340,7 +340,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
- max_tokens: int = 1024,
+ max_tokens: int = None,
prompt_name: str = "default",
):
'''
@@ -378,7 +378,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
- max_tokens: int = 1024,
+ max_tokens: int = None,
prompt_name: str = "default",
):
'''