hzg0601 2023-10-20 19:22:36 +08:00
commit 6e9acfc1af
26 changed files with 321 additions and 192 deletions

View File

@@ -1,5 +1,7 @@
 import os
+# Default knowledge base to use
+DEFAULT_KNOWLEDGE_BASE = "samples"
 # Default vector store type. Options: faiss, milvus (offline) & zilliz (online), pg.
 DEFAULT_VS_TYPE = "faiss"
@@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3
 # Relevance threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 is equivalent to no filtering. Around 0.5 is recommended.
 SCORE_THRESHOLD = 1
+# Default search engine. Options: bing, duckduckgo, metaphor
+DEFAULT_SEARCH_ENGINE = "duckduckgo"
 # Number of search engine results to match
 SEARCH_ENGINE_TOP_K = 3
@@ -36,6 +41,10 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
 # it is because the server has a firewall; contact the administrator to be whitelisted (if it is a company server, don't count on it)
 BING_SUBSCRIPTION_KEY = ""
+# metaphor search requires a KEY
+METAPHOR_API_KEY = ""
 # Whether to enable Chinese title enhancement, and its related settings:
 # add a title-detection step to decide which texts are titles and mark them in the metadata,
 # then merge each text with the title one level above it to enrich the text information.
@@ -47,7 +56,6 @@ KB_INFO = {
     "知识库名称": "知识库介绍",
     "samples": "关于本项目issue的解答",
 }
-
 # Normally the following does not need to be changed
 # Default storage path for the knowledge base
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
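The two new defaults added above are consumed later in this commit (in get_server_configs and in the webui dialogue page). A minimal sketch of reading them, using the configs.kb_config module path that appears further down in this diff:

from configs.kb_config import DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE

print(DEFAULT_KNOWLEDGE_BASE)  # "samples"
print(DEFAULT_SEARCH_ENGINE)   # "duckduckgo"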

View File

@@ -44,7 +44,7 @@ MODEL_PATH = {
         "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
-        "baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
+        "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
         "baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
         "baichuan-7b": "baichuan-inc/Baichuan-7B",
@@ -112,7 +112,8 @@ TEMPERATURE = 0.7
 # TOP_P = 0.95  # ChatOpenAI does not support this parameter yet
-ONLINE_LLM_MODEL = {
+LANGCHAIN_LLM_MODEL = {
+    # Models supported natively by Langchain; they do not need to go through the Fschat wrapper.
     # If calling chatgpt raises urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
     # Max retries exceeded with url: /v1/chat/completions
     # then change the urllib3 version to 1.25.11
@@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = {
     # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
     # Proxy access is required (a normally running proxy tool may fail to intercept these requests); set openai_proxy in the config or use the OPENAI_PROXY environment variable
     # e.g.: "openai_proxy": 'http://127.0.0.1:4780'
-    "gpt-3.5-turbo": {
+    # The names of these config entries must not be changed
+    "Azure-OpenAI": {
+        "deployment_name": "your Azure deployment name",
+        "model_version": "0701",
+        "openai_api_type": "azure",
+        "api_base_url": "https://your Azure point.azure.com",
+        "api_version": "2023-07-01-preview",
+        "api_key": "your Azure api key",
+        "openai_proxy": "",
+    },
+    "OpenAI": {
+        "model_name": "your openai model name(such as gpt-4)",
         "api_base_url": "https://api.openai.com/v1",
         "api_key": "your OPENAI_API_KEY",
-        "openai_proxy": "your OPENAI_PROXY",
+        "openai_proxy": "",
     },
+    "Anthropic": {
+        "model_name": "your claude model name(such as claude2-100k)",
+        "api_key": "your ANTHROPIC_API_KEY",
+    }
+}
+
+ONLINE_LLM_MODEL = {
     # Online models. Please set a different port for each online API in server_config
     # For registration and to obtain an api key, go to http://open.bigmodel.cn
     "zhipu-api": {

View File

@@ -1,3 +1,5 @@
+import sys
+sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
 from configs.model_config import NLTK_DATA_PATH
 import nltk

View File

@@ -1,14 +1,15 @@
-langchain>=0.0.314
+langchain>=0.0.319
 langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
 sentence_transformers
 transformers>=4.34
 torch>=2.0.1  # 2.1 recommended
 torchvision
 torchaudio
-fastapi>=0.103.2
+fastapi>=0.104
-nltk~=3.8.1
+nltk>=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11
@@ -43,7 +44,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 watchdog

View File

@@ -1,13 +1,14 @@
-langchain==0.0.313
+langchain>=0.0.319
-langchain-experimental==0.0.30
+langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
 sentence_transformers>=2.2.2
 transformers>=4.34
-torch>=2.0.1
+torch>=2.1
 torchvision
 torchaudio
-fastapi>=0.103.1
+fastapi>=0.104
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0

View File

@@ -1,11 +1,11 @@
 numpy~=1.24.4
 pandas~=2.0.3
-streamlit>=1.26.0
+streamlit>=1.27.2
 streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
+streamlit-antd-components>=0.2.3
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
+httpx>=0.25.0
-nltk
+nltk>=3.8.1
 watchdog
 websockets

View File

@@ -97,8 +97,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             llm_token="",
         )
         self.queue.put_nowait(dumps(self.cur_tool))

-    async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any,
+    async def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
     ) -> None:
         self.cur_tool.update(
             status=Status.start,
View File

@@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate
 from typing import List
 from langchain.schema import AgentAction, AgentFinish
 from server.agent import model_container
-begin = False
 class CustomPromptTemplate(StringPromptTemplate):
     # The template to use
     template: str
@@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
         # Check if agent should finish
-        support_agent = ["gpt","Qwen","qwen-api","baichuan-api"]
+        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
         if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]

View File

@@ -2,7 +2,6 @@
 from .search_knowledge_simple import knowledge_search_simple
 from .search_all_knowledge_once import knowledge_search_once
 from .search_all_knowledge_more import knowledge_search_more
-from .travel_assistant import travel_assistant
 from .calculate import calculate
 from .translator import translate
 from .weather import weathercheck

View File

@@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                     max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-                     # TODO: after fastchat is updated, change the default to None so the model's maximum is used automatically.
+                     max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                      prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):

View File

@@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               stream: bool = Body(False, description="流式输出"),
               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-              max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-              # TODO: after fastchat is updated, change the default to None so the model's maximum is used automatically.
+              max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
               # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
               ):

View File

@@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                              max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-                              # TODO: after fastchat is updated, change the default to None so the model's maximum is used automatically.
+                              max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

View File

@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
     messages: List[OpenAiMessage]
     temperature: float = 0.7
     n: int = 1
-    max_tokens: int = 1024
+    max_tokens: int = None
     stop: List[str] = []
     stream: bool = False
     presence_penalty: int = 0

View File

@@ -1,6 +1,7 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
+                     TEXT_SPLITTER_NAME, OVERLAP_SIZE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
@@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List, Optional, Dict
 from server.chat.utils import History
 from langchain.docstore.document import Document
 import json
@@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)


+def metaphor_search(
+        text: str,
+        result_len: int = SEARCH_ENGINE_TOP_K,
+        splitter_name: str = "SpacyTextSplitter",
+        chunk_size: int = 500,
+        chunk_overlap: int = OVERLAP_SIZE,
+) -> List[Dict]:
+    from metaphor_python import Metaphor
+    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
+    from server.knowledge_base.utils import make_text_splitter
+
+    if not METAPHOR_API_KEY:
+        return []
+
+    client = Metaphor(METAPHOR_API_KEY)
+    search = client.search(text, num_results=result_len, use_autoprompt=True)
+    contents = search.get_contents().contents
+
+    # metaphor returns long texts, so split them and retrieve again
+    docs = [Document(page_content=x.extract,
+                     metadata={"link": x.url, "title": x.title})
+            for x in contents]
+    text_splitter = make_text_splitter(splitter_name=splitter_name,
+                                       chunk_size=chunk_size,
+                                       chunk_overlap=chunk_overlap)
+    splitted_docs = text_splitter.split_documents(docs)
+
+    # put the split documents into a temporary vector store and re-select the TOP_K documents
+    if len(splitted_docs) > result_len:
+        vs = memo_faiss_pool.new_vector_store()
+        vs.add_documents(splitted_docs)
+        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
+
+    docs = [{"snippet": x.page_content,
+             "link": x.metadata["link"],
+             "title": x.metadata["title"]}
+            for x in splitted_docs]
+    return docs
+
+
 SEARCH_ENGINES = {"bing": bing_search,
                   "duckduckgo": duckduckgo_search,
+                  "metaphor": metaphor_search,
                   }
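A small usage sketch of the new engine, assuming METAPHOR_API_KEY is set in configs, the metaphor_python package is installed, and this function lives in server/chat/search_engine_chat.py (the diff does not name the file); the query string is only an example:

from server.chat.search_engine_chat import metaphor_search

# Each result is a dict with "snippet", "link" and "title" keys, as built above.
for doc in metaphor_search("retrieval augmented generation", result_len=3):
    print(doc["title"], doc["link"])
    print(doc["snippet"][:200])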
@@ -72,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-                             # TODO: after fastchat is updated, change the default to None so the model's maximum is used automatically.
+                             max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                              prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():

View File

@@ -140,7 +140,7 @@ if __name__ == "__main__":
             ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
             pprint(ids)
         elif r == 2:  # search docs
-            docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
+            docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
             pprint(docs)
         if r == 3:  # delete docs
             logger.warning(f"清除 {vs_name} by {name}")

View File

@@ -1,7 +1,5 @@
 import os
-
 from transformers import AutoTokenizer
-
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,

View File

@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
@@ -16,7 +16,7 @@ def list_running_models(
         with get_httpx_client() as client:
             r = client.post(controller_address + "/list_models")
             models = r.json()["models"]
-            data = {m: get_model_worker_config(m) for m in models}
+            data = {m: get_model_config(m).data for m in models}
             return BaseResponse(data=data)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
@@ -52,7 +52,6 @@ def get_model_config(
     Get the merged configuration of the LLM model.
     '''
     config = get_model_worker_config(model_name=model_name)
-
     # remove sensitive information from the ONLINE_MODEL config
     del_keys = set(["worker_class"])
     for k in config:

View File

@@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
         "chat": {
             "domain": domain,
             "random_threshold": 0.5,
-            "max_tokens": 2048,
+            "max_tokens": None,
             "auditing": "default",
             "temperature": temperature,
         }

View File

@@ -1,15 +1,15 @@
 # import os
 # import sys
 # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-import requests
 import json
 import time
 import hashlib
 from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config, get_httpx_client
 from fastchat import conversation as conv
 import sys
 import json
-from typing import List, Literal
+from typing import List, Literal, Dict
 from configs import TEMPERATURE
@@ -20,29 +20,29 @@ def calculate_md5(input_string):
     return encrypted


-def do_request():
-    url = "https://api.baichuan-ai.com/v1/stream/chat"
-    api_key = ""
-    secret_key = ""
+def request_baichuan_api(
+        messages: List[Dict[str, str]],
+        api_key: str = None,
+        secret_key: str = None,
+        version: str = "Baichuan2-53B",
+        temperature: float = TEMPERATURE,
+        model_name: str = "baichuan-api",
+):
+    config = get_model_worker_config(model_name)
+    api_key = api_key or config.get("api_key")
+    secret_key = secret_key or config.get("secret_key")
+    version = version or config.get("version")
+
+    url = "https://api.baichuan-ai.com/v1/stream/chat"
     data = {
-        "model": "Baichuan2-53B",
-        "messages": [
-            {
-                "role": "user",
-                "content": "世界第一高峰是"
-            }
-        ],
-        "parameters": {
-            "temperature": 0.1,
-            "top_k": 10
-        }
+        "model": version,
+        "messages": messages,
+        "parameters": {"temperature": temperature}
     }

     json_data = json.dumps(data)
     time_stamp = int(time.time())
     signature = calculate_md5(secret_key + json_data + str(time_stamp))
     headers = {
         "Content-Type": "application/json",
         "Authorization": "Bearer " + api_key,
@@ -52,18 +52,17 @@ def do_request():
         "X-BC-Sign-Algo": "MD5",
     }

-    response = requests.post(url, data=json_data, headers=headers)
-
-    if response.status_code == 200:
-        print("请求成功!")
-        print("响应header:", response.headers)
-        print("响应body:", response.text)
-    else:
-        print("请求失败,状态码:", response.status_code)
+    with get_httpx_client() as client:
+        with client.stream("POST", url, headers=headers, json=data) as response:
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                resp = json.loads(line)
+                yield resp


 class BaiChuanWorker(ApiModelWorker):
-    BASE_URL = "https://api.baichuan-ai.com/v1/chat"
+    BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
     SUPPORT_MODELS = ["Baichuan2-53B"]

     def __init__(
@@ -95,40 +94,22 @@ class BaiChuanWorker(ApiModelWorker):
         self.secret_key = config.get("secret_key")

     def generate_stream_gate(self, params):
-        data = {
-            "model": self.version,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": params["prompt"]
-                }
-            ],
-            "parameters": {
-                "temperature": params.get("temperature",TEMPERATURE),
-                "top_k": params.get("top_k",1)
-            }
-        }
-
-        json_data = json.dumps(data)
-        time_stamp = int(time.time())
-        signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": "Bearer " + self.api_key,
-            "X-BC-Request-Id": "your requestId",
-            "X-BC-Timestamp": str(time_stamp),
-            "X-BC-Signature": signature,
-            "X-BC-Sign-Algo": "MD5",
-        }
-
-        response = requests.post(self.BASE_URL, data=json_data, headers=headers)
-
-        if response.status_code == 200:
-            resp = eval(response.text)
+        super().generate_stream_gate(params)
+        messages = self.prompt_to_messages(params["prompt"])
+
+        text = ""
+        for resp in request_baichuan_api(messages=messages,
+                                         api_key=self.api_key,
+                                         secret_key=self.secret_key,
+                                         version=self.version,
+                                         temperature=params.get("temperature")):
+            if resp["code"] == 0:
+                text += resp["data"]["messages"][-1]["content"]
             yield json.dumps(
                 {
                     "error_code": resp["code"],
-                    "text": resp["data"]["messages"][-1]["content"]
+                    "text": text
                 },
                 ensure_ascii=False
             ).encode() + b"\0"
@@ -141,8 +122,6 @@
                 ensure_ascii=False
             ).encode() + b"\0"
-
-
     def get_embeddings(self, params):
         # TODO: support embeddings
         print("embedding")

View File

@@ -1,13 +1,13 @@
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
-from fastchat.serve.model_worker import BaseModelWorker
+from fastchat.serve.base_model_worker import BaseModelWorker
 import uuid
 import json
 import sys
 from pydantic import BaseModel
 import fastchat
-import threading
+import asyncio
 from typing import Dict, List
@@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
                          worker_addr=worker_addr,
                          **kwargs)
         self.context_len = context_len
+        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
         self.init_heart_beat()

     def count_token(self, params):
@@ -62,15 +63,6 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         print(params)

-    # workaround to make program exit with Ctrl+c
-    # it should be deleted after pr is merged by fastchat
-    def init_heart_beat(self):
-        self.register_to_controller()
-        self.heart_beat_thread = threading.Thread(
-            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
-        )
-        self.heart_beat_thread.start()
-
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
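The new asyncio.Semaphore(self.limit_worker_concurrency) follows the usual pattern for bounding concurrent requests in an async worker. A self-contained illustration of that pattern (not code from this repository):

import asyncio

async def handle_request(sem: asyncio.Semaphore, i: int) -> None:
    async with sem:                # at most 5 of these blocks run at the same time
        await asyncio.sleep(0.1)   # stand-in for one streaming generation call
        print(f"request {i} finished")

async def main() -> None:
    sem = asyncio.Semaphore(5)     # analogous to limit_worker_concurrency
    await asyncio.gather(*(handle_request(sem, i) for i in range(20)))

asyncio.run(main())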

View File

@@ -5,12 +5,11 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
-                     logger, log_verbose,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
                      FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -40,6 +39,50 @@ def get_ChatOpenAI(
         verbose: bool = True,
         **kwargs: Any,
 ) -> ChatOpenAI:
+    ## The following models are supported natively by Langchain and do not go through the Fschat wrapper
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureChatOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+        elif model_name == "OpenAI":
+            model = ChatOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+        elif model_name == "Anthropic":
+            model = ChatAnthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+            )
+        ## TODO: support other models natively supported by Langchain
+    else:
+        ## models not natively supported by Langchain go through the Fschat wrapper
         config = get_model_worker_config(model_name)
         model = ChatOpenAI(
             streaming=streaming,
@@ -53,6 +96,7 @@ def get_ChatOpenAI(
             openai_proxy=config.get("openai_proxy"),
             **kwargs
         )
+
     return model
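With this change, which client class get_ChatOpenAI returns depends only on whether the name appears in LANGCHAIN_LLM_MODEL. A hedged usage sketch — the full signature is not shown in this hunk, so the keyword arguments below are assumptions based on the fields used above, and "chatglm2-6b-32k" is taken from MODEL_PATH earlier in this commit:

from server.utils import get_ChatOpenAI

# "OpenAI" resolves via LANGCHAIN_LLM_MODEL and returns a plain ChatOpenAI client;
# a local model name such as "chatglm2-6b-32k" falls through to the fschat-backed path.
native = get_ChatOpenAI(model_name="OpenAI", temperature=0.7, max_tokens=None)
wrapped = get_ChatOpenAI(model_name="chatglm2-6b-32k", temperature=0.7, max_tokens=None)
print(type(native).__name__, type(wrapped).__name__)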
@@ -249,8 +293,9 @@ def MakeFastAPIOffline(
         redoc_favicon_url=favicon,
     )

+
 # Get model information from model_config
 def list_embed_models() -> List[str]:
     '''
     get names of configured embedding models
@@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]:
     workers = list(FSCHAT_MODEL_WORKERS)
     if LLM_MODEL not in workers:
         workers.insert(0, LLM_MODEL)
     return {
         "local": MODEL_PATH["llm_model"],
+        "langchain": LANGCHAIN_LLM_MODEL,
         "online": ONLINE_LLM_MODEL,
         "worker": workers,
     }
@@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
             return str(path)
     return path_str  # THUDM/chatglm06b

+
 # Get service information from server_config
 def get_model_worker_config(model_name: str = None) -> dict:
     '''
     Load the configuration items for the model worker.
@@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))

     # Online model APIs
+    if model_name in LANGCHAIN_LLM_MODEL:
+        config["langchain_model"] = True
+        config["worker_class"] = ""
+
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
@@ -459,8 +509,9 @@ def set_httpx_config(
     import urllib.request
     urllib.request.getproxies = _get_proxies

+
 # Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM do not need torch installed
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
         import torch
@@ -568,6 +619,8 @@ def get_server_configs() -> Dict:
     Get the raw configuration items from configs, for use by the front end.
     '''
     from configs.kb_config import (
+        DEFAULT_KNOWLEDGE_BASE,
+        DEFAULT_SEARCH_ENGINE,
         DEFAULT_VS_TYPE,
         CHUNK_SIZE,
         OVERLAP_SIZE,

View File

@@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     controller_address:
     worker_address:

+    For models supported natively by Langchain:
+        langchain_model: True
+        fschat is not used
     For online_api:
         online_api: True
         worker_class: `provider`
@@ -78,23 +80,25 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    logger.setLevel(log_level)

     parser = argparse.ArgumentParser()
     args = parser.parse_args([])

     for k, v in kwargs.items():
         setattr(args, k, v)

+    if worker_class := kwargs.get("langchain_model"):  # models supported by Langchain need no extra handling
+        from fastchat.serve.base_model_worker import app
+        worker = ""
     # Online model APIs
-    if worker_class := kwargs.get("worker_class"):
-        from fastchat.serve.model_worker import app
+    elif worker_class := kwargs.get("worker_class"):
+        from fastchat.serve.base_model_worker import app
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
-        sys.modules["fastchat.serve.model_worker"].worker = worker
+        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
+        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
     # Local models
     else:
         from configs.model_config import VLLM_MODEL_DICT
@@ -103,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             from fastchat.serve.vllm_worker import VLLMWorker, app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+
             args.tokenizer = args.model_path  # add it here if the tokenizer differs from model_path
             args.tokenizer_mode = 'auto'
             args.trust_remote_code = True
@@ -126,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             args.engine_use_ray = False
             args.disable_log_requests = False

-            # parameters required after vllm 0.2.0
-            args.max_model_len = 8192  # the maximum sequence length the model can handle; set it according to your model
+            # parameters required after vllm 0.2.0, but they are not needed here
+            args.max_model_len = None
             args.revision = None
             args.quantization = None
             args.max_log_len = None
@@ -155,10 +160,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 conv_template=args.conv_template,
             )
             sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)

         else:
-            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
             args.gpus = "0"  # GPU index; with multiple GPUs this can be set to "0,1,2,3"
             args.max_gpu_memory = "22GiB"
             args.num_gpus = 1  # model worker sharding is model parallelism; set the number of GPUs here
@@ -221,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             )
             sys.modules["fastchat.serve.model_worker"].args = args
             sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-            sys.modules["fastchat.serve.model_worker"].worker = worker
+            # sys.modules["fastchat.serve.model_worker"].worker = worker
+            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)

     MakeFastAPIOffline(app)
     app.title = f"FastChat LLM Server ({args.model_names[0]})"

View File

@@ -0,0 +1,16 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.baichuan import request_baichuan_api
+from pprint import pprint
+
+
+def test_qwen():
+    messages = [{"role": "user", "content": "hello"}]
+    for x in request_baichuan_api(messages):
+        print(type(x))
+        pprint(x)
+        assert x["code"] == 0

View File

@@ -21,13 +21,6 @@ if __name__ == "__main__":
         }
     )

-    if not chat_box.chat_inited:
-        running_models = api.list_running_models()
-        st.toast(
-            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
-            f"当前运行中的模型`{running_models}`, 您可以开始提问了."
-        )
-
     pages = {
         "对话": {
             "icon": "chat",

View File

@@ -3,7 +3,8 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES
+from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
 from typing import List, Dict

 chat_box = ChatBox(
@@ -40,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
     The return type is (model_name, is_local_model).
     '''
     running_models = api.list_running_models()
-
     if not running_models:
         return "", False
@@ -50,11 +50,16 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
     local_models = [k for k, v in running_models.items() if not v.get("online_api")]
     if local_models:
         return local_models[0], True
-    return running_models[0], False
+    return list(running_models)[0], False


 def dialogue_page(api: ApiRequest):
+    if not chat_box.chat_inited:
+        default_model = get_default_llm_model(api)[0]
+        st.toast(
+            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
+            f"当前运行的模型`{default_model}`, 您可以开始提问了."
+        )
     chat_box.init_session()
@@ -74,12 +79,13 @@ def dialogue_page(api: ApiRequest):
                 "搜索引擎问答",
                 "自定义Agent问答",
             ],
-            index=3,
+            index=0,
             on_change=on_mode_change,
             key="dialogue_mode",
         )

        def on_llm_change():
+            if llm_model:
                config = api.get_model_config(llm_model)
                if not config.get("online_api"):  # only a local model_worker can switch models
                    st.session_state["prev_llm_model"] = llm_model
@@ -91,15 +97,17 @@ def dialogue_page(api: ApiRequest):
             return x

         running_models = list(api.list_running_models())
+        running_models += LANGCHAIN_LLM_MODEL.keys()
         available_models = []
         config_models = api.list_config_models()
         worker_models = list(config_models.get("worker", {}))  # only list the models configured in FSCHAT_MODEL_WORKERS
         for m in worker_models:
             if m not in running_models and m != "default":
                 available_models.append(m)
-        for k, v in config_models.get("online", {}).items():  # list the models in ONLINE_MODELS that are accessed directly, such as GPT
+        for k, v in config_models.get("online", {}).items():  # list the models in ONLINE_MODELS that are accessed directly
             if not v.get("provider") and k not in running_models:
-                print(k, v)
                 available_models.append(k)
+        for k, v in config_models.get("langchain", {}).items():  # list the models supported via LANGCHAIN_LLM_MODEL
+            available_models.append(k)
         llm_models = running_models + available_models
         index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
@@ -111,7 +119,8 @@ def dialogue_page(api: ApiRequest):
             key="llm_model",
         )
         if (st.session_state.get("prev_llm_model") != llm_model
-                and not api.get_model_config(llm_model).get("online_api")
+                and not llm_model in config_models.get("online", {})
+                and not llm_model in config_models.get("langchain", {})
                 and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
@@ -156,9 +165,13 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "知识库问答":
             with st.expander("知识库配置", True):
                 kb_list = api.list_knowledge_bases()
+                index = 0
+                if DEFAULT_KNOWLEDGE_BASE in kb_list:
+                    index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
                 selected_kb = st.selectbox(
                     "请选择知识库:",
                     kb_list,
+                    index=index,
                     on_change=on_kb_change,
                     key="selected_kb",
                 )
@@ -167,11 +180,15 @@ def dialogue_page(api: ApiRequest):
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
+            if DEFAULT_SEARCH_ENGINE in search_engine_list:
+                index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
+            else:
+                index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
             with st.expander("搜索引擎配置", True):
                 search_engine = st.selectbox(
                     label="请选择搜索引擎",
                     options=search_engine_list,
-                    index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
+                    index=index,
                 )
                 se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
@@ -210,9 +227,9 @@ def dialogue_page(api: ApiRequest):
         ])
         text = ""
         ans = ""
-        support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
+        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agent use
         if not any(agent in llm_model for agent in support_agent):
-            ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐无法正常使用Agent功能</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验 </span> \n\n\n"
+            ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n"
             chat_box.update_msg(ans, element_index=0, streaming=False)
         for d in api.agent_chat(prompt,
                                 history=history,

View File

@@ -245,7 +245,7 @@ class ApiRequest:
            stream: bool = True,
            model: str = LLM_MODEL,
            temperature: float = TEMPERATURE,
-            max_tokens: int = 1024,
+            max_tokens: int = None,
            **kwargs: Any,
    ):
        '''
@@ -278,7 +278,7 @@ class ApiRequest:
            stream: bool = True,
            model: str = LLM_MODEL,
            temperature: float = TEMPERATURE,
-            max_tokens: int = 1024,
+            max_tokens: int = None,
            prompt_name: str = "default",
            **kwargs,
    ):
@@ -308,7 +308,7 @@ class ApiRequest:
            stream: bool = True,
            model: str = LLM_MODEL,
            temperature: float = TEMPERATURE,
-            max_tokens: int = 1024,
+            max_tokens: int = None,
            prompt_name: str = "default",
    ):
        '''
@@ -340,7 +340,7 @@ class ApiRequest:
            stream: bool = True,
            model: str = LLM_MODEL,
            temperature: float = TEMPERATURE,
-            max_tokens: int = 1024,
+            max_tokens: int = None,
            prompt_name: str = "default",
    ):
        '''
@@ -378,7 +378,7 @@ class ApiRequest:
            stream: bool = True,
            model: str = LLM_MODEL,
            temperature: float = TEMPERATURE,
-            max_tokens: int = 1024,
+            max_tokens: int = None,
            prompt_name: str = "default",
    ):
        '''