Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-09 00:25:46 +08:00)

commit 6e9acfc1af
Merge branch 'dev' of https://github.com/chatchat-space/Langchain-Chatchat into dev
@@ -1,5 +1,7 @@
 import os
 
+# Default knowledge base to use
+DEFAULT_KNOWLEDGE_BASE = "samples"
 
 # Default vector store type. Options: faiss, milvus (offline) & zilliz (online), pg.
 DEFAULT_VS_TYPE = "faiss"

@@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3
 # Relevance threshold for knowledge base matching, in the range 0-1; a smaller SCORE means higher relevance, 1 is equivalent to no filtering; around 0.5 is recommended
 SCORE_THRESHOLD = 1
 
+# Default search engine. Options: bing, duckduckgo, metaphor
+DEFAULT_SEARCH_ENGINE = "duckduckgo"
+
 # Number of search engine results to match
 SEARCH_ENGINE_TOP_K = 3
 

@@ -36,6 +41,10 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
 # it is because the server is behind a firewall; ask the administrator to whitelist it (if it is a company server, forget it, GG)
 BING_SUBSCRIPTION_KEY = ""
 
+# metaphor search requires a KEY
+METAPHOR_API_KEY = ""
+
+
 # Whether to enable Chinese title enhancement, plus its related configuration.
 # Extra title detection marks which texts are titles and records this in the metadata;
 # each text is then joined with its parent title to enrich the text content.

@@ -47,7 +56,6 @@ KB_INFO = {
     "知识库名称": "知识库介绍",
     "samples": "关于本项目issue的解答",
 }
-
 # Normally the settings below do not need to be changed
 # Default storage path for knowledge bases
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
@@ -44,7 +44,7 @@ MODEL_PATH = {
     "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
     "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
 
-    "baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
+    "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
     "baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
 
     "baichuan-7b": "baichuan-inc/Baichuan-7B",

@@ -112,7 +112,8 @@ TEMPERATURE = 0.7
 # TOP_P = 0.95  # not yet supported by ChatOpenAI
 
 
-ONLINE_LLM_MODEL = {
+LANGCHAIN_LLM_MODEL = {
+    # Models supported natively by Langchain that do not need the Fschat wrapper.
 # If calling chatgpt raises: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
 # Max retries exceeded with url: /v1/chat/completions
 # then change the urllib3 version to 1.25.11

@@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = {
 # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
 # Proxy access is required (an ordinarily enabled proxy tool may not intercept this); set openai_proxy in the config or use the OPENAI_PROXY environment variable
 # e.g.: "openai_proxy": 'http://127.0.0.1:4780'
-    "gpt-3.5-turbo": {
+    # The names of these config entries must not be changed
+    "Azure-OpenAI": {
+        "deployment_name": "your Azure deployment name",
+        "model_version": "0701",
+        "openai_api_type": "azure",
+        "api_base_url": "https://your Azure point.azure.com",
+        "api_version": "2023-07-01-preview",
+        "api_key": "your Azure api key",
+        "openai_proxy": "",
+    },
+    "OpenAI": {
+        "model_name": "your openai model name(such as gpt-4)",
         "api_base_url": "https://api.openai.com/v1",
         "api_key": "your OPENAI_API_KEY",
-        "openai_proxy": "your OPENAI_PROXY",
+        "openai_proxy": "",
     },
+    "Anthropic": {
+        "model_name": "your claude model name(such as claude2-100k)",
+        "api_key": "your ANTHROPIC_API_KEY",
+    }
+}
+
+ONLINE_LLM_MODEL = {
     # Online models. Set a different port for each online API in server_config
     # For registration and api key acquisition, go to http://open.bigmodel.cn
     "zhipu-api": {
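For orientation, a minimal sketch of how these new entries are consumed downstream (the key names are the contract; get_ChatOpenAI in server/utils.py, changed later in this commit, reads them verbatim, and the printed values are the placeholders shown above):

from configs import LANGCHAIN_LLM_MODEL

azure_cfg = LANGCHAIN_LLM_MODEL["Azure-OpenAI"]
print(azure_cfg["deployment_name"], azure_cfg["api_version"])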
@@ -1,3 +1,5 @@
+import sys
+sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
 from configs.model_config import NLTK_DATA_PATH
 import nltk
@@ -1,14 +1,15 @@
-langchain>=0.0.314
+langchain>=0.0.319
 langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
 sentence_transformers
 transformers>=4.34
 torch>=2.0.1  # 2.1 recommended
 torchvision
 torchaudio
-fastapi>=0.103.2
+fastapi>=0.104
-nltk~=3.8.1
+nltk>=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11

@@ -43,7 +44,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 watchdog
@@ -1,13 +1,14 @@
-langchain==0.0.313
+langchain>=0.0.319
-langchain-experimental==0.0.30
+langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
 sentence_transformers>=2.2.2
 transformers>=4.34
-torch>=2.0.1
+torch>=2.1
 torchvision
 torchaudio
-fastapi>=0.103.1
+fastapi>=0.104
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
@@ -1,11 +1,11 @@
 numpy~=1.24.4
 pandas~=2.0.3
-streamlit>=1.26.0
+streamlit>=1.27.2
 streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
+streamlit-antd-components>=0.2.3
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
+httpx>=0.25.0
-nltk
+nltk>=3.8.1
 watchdog
 websockets
@@ -97,8 +97,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             llm_token="",
         )
         self.queue.put_nowait(dumps(self.cur_tool))
 
-    async def on_chat_model_start(self, serialized: Dict[str, Any], **kwargs: Any,
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
    ) -> None:
        self.cur_tool.update(
            status=Status.start,
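The expanded signature mirrors langchain's AsyncCallbackHandler.on_chat_model_start hook, where run_id, parent_run_id, tags and metadata are keyword-only; the old one-line version also omitted the positional messages argument, so langchain releases that pass it explicitly would raise a TypeError before the handler body ran.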
@@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate
 from typing import List
 from langchain.schema import AgentAction, AgentFinish
 from server.agent import model_container
-begin = False
 class CustomPromptTemplate(StringPromptTemplate):
     # The template to use
     template: str

@@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
 
     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
         # Check if agent should finish
-        support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]
+        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agents
         if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]
@@ -2,7 +2,6 @@
 from .search_knowledge_simple import knowledge_search_simple
 from .search_all_knowledge_once import knowledge_search_once
 from .search_all_knowledge_more import knowledge_search_more
-from .travel_assistant import travel_assistant
 from .calculate import calculate
 from .translator import translate
 from .weather import weathercheck
@@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                     # TODO: once fastchat is updated, set the default to None to use the LLM's supported maximum automatically.
+                     max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                      prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):
@@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                stream: bool = Body(False, description="流式输出"),
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-               max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-               # TODO: once fastchat is updated, set the default to None to use the LLM's supported maximum automatically.
+               max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
@@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                              max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                              # TODO: once fastchat is updated, set the default to None to use the LLM's supported maximum automatically.
+                              max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
     messages: List[OpenAiMessage]
     temperature: float = 0.7
     n: int = 1
-    max_tokens: int = 1024
+    max_tokens: int = None
     stop: List[str] = []
     stream: bool = False
     presence_penalty: int = 0
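A small pydantic footnote on the max_tokens default: under the pydantic ~1.10 pinned in the requirements above, a field annotated as int with a default of None is implicitly treated as Optional[int], so the model still validates; pydantic v2 would require an explicit Optional annotation. A minimal sketch:

from pydantic import BaseModel

class Sketch(BaseModel):
    max_tokens: int = None  # pydantic v1 makes this Optional[int] implicitly

print(Sketch().max_tokens)  # None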
@@ -1,6 +1,7 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
+                     TEXT_SPLITTER_NAME, OVERLAP_SIZE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool

@@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List, Optional, Dict
 from server.chat.utils import History
 from langchain.docstore.document import Document
 import json

@@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)
 
 
+def metaphor_search(
+        text: str,
+        result_len: int = SEARCH_ENGINE_TOP_K,
+        splitter_name: str = "SpacyTextSplitter",
+        chunk_size: int = 500,
+        chunk_overlap: int = OVERLAP_SIZE,
+) -> List[Dict]:
+    from metaphor_python import Metaphor
+    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
+    from server.knowledge_base.utils import make_text_splitter
+
+    if not METAPHOR_API_KEY:
+        return []
+
+    client = Metaphor(METAPHOR_API_KEY)
+    search = client.search(text, num_results=result_len, use_autoprompt=True)
+    contents = search.get_contents().contents
+
+    # metaphor returns long texts; split them before retrieval
+    docs = [Document(page_content=x.extract,
+                     metadata={"link": x.url, "title": x.title})
+            for x in contents]
+    text_splitter = make_text_splitter(splitter_name=splitter_name,
+                                       chunk_size=chunk_size,
+                                       chunk_overlap=chunk_overlap)
+    splitted_docs = text_splitter.split_documents(docs)
+
+    # put the split documents into a temporary vector store and re-select the TOP_K documents
+    if len(splitted_docs) > result_len:
+        vs = memo_faiss_pool.new_vector_store()
+        vs.add_documents(splitted_docs)
+        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
+
+    docs = [{"snippet": x.page_content,
+             "link": x.metadata["link"],
+             "title": x.metadata["title"]}
+            for x in splitted_docs]
+    return docs
+
+
 SEARCH_ENGINES = {"bing": bing_search,
                   "duckduckgo": duckduckgo_search,
+                  "metaphor": metaphor_search,
                   }

@@ -72,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                             # TODO: once fastchat is updated, set the default to None to use the LLM's supported maximum automatically.
+                             max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                              prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
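Because metaphor_search is registered in SEARCH_ENGINES next to bing and duckduckgo, it can also be exercised directly. A minimal usage sketch, assuming METAPHOR_API_KEY is filled in configs and the metaphor_python package is installed (the query string is hypothetical):

from server.chat.search_engine_chat import metaphor_search

results = metaphor_search("what is a vector database", result_len=3)
for r in results:
    # each hit uses the same keys as the bing/duckduckgo results
    print(r["title"], r["link"])
    print(r["snippet"][:80])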
@@ -140,7 +140,7 @@ if __name__ == "__main__":
             ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
             pprint(ids)
         elif r == 2:  # search docs
-            docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
+            docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
             pprint(docs)
         if r == 3:  # delete docs
             logger.warning(f"清除 {vs_name} by {name}")
@@ -1,7 +1,5 @@
 import os
-
 from transformers import AutoTokenizer
-
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT, LANGCHAIN_LLM_MODEL
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
 

@@ -16,7 +16,7 @@ def list_running_models(
         with get_httpx_client() as client:
             r = client.post(controller_address + "/list_models")
             models = r.json()["models"]
-            data = {m: get_model_worker_config(m) for m in models}
+            data = {m: get_model_config(m).data for m in models}
             return BaseResponse(data=data)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',

@@ -52,7 +52,6 @@ def get_model_config(
     Get the (merged) LLM model configuration items
     '''
     config = get_model_worker_config(model_name=model_name)
-
     # remove sensitive information from the ONLINE_MODEL config
     del_keys = set(["worker_class"])
     for k in config:
@@ -65,7 +65,7 @@ def gen_params(appid, domain, question, temperature):
         "chat": {
             "domain": domain,
             "random_threshold": 0.5,
-            "max_tokens": 2048,
+            "max_tokens": None,
             "auditing": "default",
             "temperature": temperature,
         }
@@ -1,15 +1,15 @@
 # import os
 # import sys
 # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-import requests
 import json
 import time
 import hashlib
 from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config, get_httpx_client
 from fastchat import conversation as conv
 import sys
 import json
-from typing import List, Literal
+from typing import List, Literal, Dict
 from configs import TEMPERATURE

@@ -20,29 +20,29 @@ def calculate_md5(input_string):
     return encrypted
 
 
-def do_request():
-    url = "https://api.baichuan-ai.com/v1/stream/chat"
-    api_key = ""
-    secret_key = ""
+def request_baichuan_api(
+        messages: List[Dict[str, str]],
+        api_key: str = None,
+        secret_key: str = None,
+        version: str = "Baichuan2-53B",
+        temperature: float = TEMPERATURE,
+        model_name: str = "baichuan-api",
+):
+    config = get_model_worker_config(model_name)
+    api_key = api_key or config.get("api_key")
+    secret_key = secret_key or config.get("secret_key")
+    version = version or config.get("version")
+
+    url = "https://api.baichuan-ai.com/v1/stream/chat"
     data = {
-        "model": "Baichuan2-53B",
-        "messages": [
-            {
-                "role": "user",
-                "content": "世界第一高峰是"
-            }
-        ],
-        "parameters": {
-            "temperature": 0.1,
-            "top_k": 10
-        }
+        "model": version,
+        "messages": messages,
+        "parameters": {"temperature": temperature}
     }
 
     json_data = json.dumps(data)
     time_stamp = int(time.time())
     signature = calculate_md5(secret_key + json_data + str(time_stamp))
 
     headers = {
         "Content-Type": "application/json",
         "Authorization": "Bearer " + api_key,

@@ -52,18 +52,17 @@ def do_request():
         "X-BC-Sign-Algo": "MD5",
     }
 
-    response = requests.post(url, data=json_data, headers=headers)
-
-    if response.status_code == 200:
-        print("请求成功!")
-        print("响应header:", response.headers)
-        print("响应body:", response.text)
-    else:
-        print("请求失败,状态码:", response.status_code)
+    with get_httpx_client() as client:
+        with client.stream("POST", url, headers=headers, json=data) as response:
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                resp = json.loads(line)
+                yield resp
 
 
 class BaiChuanWorker(ApiModelWorker):
-    BASE_URL = "https://api.baichuan-ai.com/v1/chat"
+    BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
     SUPPORT_MODELS = ["Baichuan2-53B"]
 
     def __init__(

@@ -95,40 +94,22 @@ class BaiChuanWorker(ApiModelWorker):
         self.secret_key = config.get("secret_key")
 
     def generate_stream_gate(self, params):
-        data = {
-            "model": self.version,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": params["prompt"]
-                }
-            ],
-            "parameters": {
-                "temperature": params.get("temperature", TEMPERATURE),
-                "top_k": params.get("top_k", 1)
-            }
-        }
-
-        json_data = json.dumps(data)
-        time_stamp = int(time.time())
-        signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": "Bearer " + self.api_key,
-            "X-BC-Request-Id": "your requestId",
-            "X-BC-Timestamp": str(time_stamp),
-            "X-BC-Signature": signature,
-            "X-BC-Sign-Algo": "MD5",
-        }
-
-        response = requests.post(self.BASE_URL, data=json_data, headers=headers)
-
-        if response.status_code == 200:
-            resp = eval(response.text)
-            yield json.dumps(
-                {
-                    "error_code": resp["code"],
-                    "text": resp["data"]["messages"][-1]["content"]
-                },
-                ensure_ascii=False
-            ).encode() + b"\0"
+        super().generate_stream_gate(params)
+
+        messages = self.prompt_to_messages(params["prompt"])
+
+        text = ""
+        for resp in request_baichuan_api(messages=messages,
+                                         api_key=self.api_key,
+                                         secret_key=self.secret_key,
+                                         version=self.version,
+                                         temperature=params.get("temperature")):
+            if resp["code"] == 0:
+                text += resp["data"]["messages"][-1]["content"]
+            yield json.dumps(
+                {
+                    "error_code": resp["code"],
+                    "text": text
+                },
+                ensure_ascii=False
+            ).encode() + b"\0"

@@ -141,8 +122,6 @@ class BaiChuanWorker(ApiModelWorker):
                 ensure_ascii=False
             ).encode() + b"\0"
 
-
-
     def get_embeddings(self, params):
         # TODO: support embeddings
         print("embedding")
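The hunks above reference calculate_md5 without showing its body. For context, a sketch of the conventional implementation the signing scheme implies (an assumption, not the committed code): the value sent in the X-BC-Signature header is the MD5 hex digest of secret_key + json_data + time_stamp.

import hashlib

def calculate_md5(input_string: str) -> str:
    # hex digest of the concatenated secret_key + json_data + time_stamp (assumed implementation)
    md5 = hashlib.md5()
    md5.update(input_string.encode("utf-8"))
    return md5.hexdigest()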
@@ -1,13 +1,13 @@
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
-from fastchat.serve.model_worker import BaseModelWorker
+from fastchat.serve.base_model_worker import BaseModelWorker
 import uuid
 import json
 import sys
 from pydantic import BaseModel
 import fastchat
-import threading
+import asyncio
 from typing import Dict, List

@@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
                          worker_addr=worker_addr,
                          **kwargs)
         self.context_len = context_len
+        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
         self.init_heart_beat()
 
     def count_token(self, params):

@@ -62,15 +63,6 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         print(params)
 
-    # workaround to make program exit with Ctrl+c
-    # it should be deleted after pr is merged by fastchat
-    def init_heart_beat(self):
-        self.register_to_controller()
-        self.heart_beat_thread = threading.Thread(
-            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
-        )
-        self.heart_beat_thread.start()
-
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
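The added semaphore is what fastchat's base_model_worker machinery expects to find on a worker. Paraphrased from fastchat 0.2.x (fastchat.serve.base_model_worker), the request path acquires it before generating, capping concurrent generations per worker:

# paraphrased upstream helper, not part of this commit
def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()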
@@ -5,12 +5,11 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
-                     logger, log_verbose,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
                      FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union

@@ -40,6 +39,50 @@ def get_ChatOpenAI(
         verbose: bool = True,
         **kwargs: Any,
 ) -> ChatOpenAI:
+    ## The following models are supported natively by Langchain and do not go through Fschat
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureChatOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+
+        elif model_name == "OpenAI":
+            model = ChatOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+        elif model_name == "Anthropic":
+            model = ChatAnthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+            )
+        ## TODO: support the other models natively supported by Langchain
+    else:
+        ## Models not natively supported by Langchain go through the Fschat wrapper
         config = get_model_worker_config(model_name)
         model = ChatOpenAI(
             streaming=streaming,

@@ -53,6 +96,7 @@ def get_ChatOpenAI(
             openai_proxy=config.get("openai_proxy"),
             **kwargs
         )
+
     return model

@@ -249,8 +293,9 @@ def MakeFastAPIOffline(
         redoc_favicon_url=favicon,
     )
 
 
 # Get model information from model_config
 
+
 def list_embed_models() -> List[str]:
     '''
     get names of configured embedding models

@@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]:
     workers = list(FSCHAT_MODEL_WORKERS)
     if LLM_MODEL not in workers:
         workers.insert(0, LLM_MODEL)
 
     return {
         "local": MODEL_PATH["llm_model"],
+        "langchain": LANGCHAIN_LLM_MODEL,
         "online": ONLINE_LLM_MODEL,
         "worker": workers,
     }

@@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
             return str(path)
     return path_str  # THUDM/chatglm06b
 
 
 # Get service information from server_config
 
+
 def get_model_worker_config(model_name: str = None) -> dict:
     '''
     Load the model worker configuration items.

@@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
 
     # online model APIs
+    if model_name in LANGCHAIN_LLM_MODEL:
+        config["langchain_model"] = True
+        config["worker_class"] = ""
+
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):

@@ -459,8 +509,9 @@ def set_httpx_config(
     import urllib.request
     urllib.request.getproxies = _get_proxies
 
 
 # Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM do not need torch installed
 
+
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
         import torch

@@ -568,6 +619,8 @@ def get_server_configs() -> Dict:
     Get the raw configuration items from configs, for frontend use
     '''
     from configs.kb_config import (
+        DEFAULT_KNOWLEDGE_BASE,
+        DEFAULT_SEARCH_ENGINE,
         DEFAULT_VS_TYPE,
         CHUNK_SIZE,
         OVERLAP_SIZE,
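A minimal sketch of the new dispatch in use, assuming an "Azure-OpenAI" entry has been filled in LANGCHAIN_LLM_MODEL (the temperature and prompt values are illustrative):

from server.utils import get_ChatOpenAI

# the name is found under the "langchain" key of list_config_llm_models(),
# so this is routed to AzureChatOpenAI; any other name falls through to fschat
model = get_ChatOpenAI(model_name="Azure-OpenAI",
                       temperature=0.7,
                       max_tokens=None)  # None now means "use the model's maximum"
print(model.predict("Hello"))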
startup.py (33 changed lines)

@@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     controller_address:
     worker_address:
 
+    For models supported natively by Langchain:
+        langchain_model: True
+        fschat is not used
     For online_api:
         online_api: True
         worker_class: `provider`

@@ -78,23 +80,25 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    logger.setLevel(log_level)
 
     parser = argparse.ArgumentParser()
     args = parser.parse_args([])
 
     for k, v in kwargs.items():
         setattr(args, k, v)
+    if worker_class := kwargs.get("langchain_model"):  # nothing to do for Langchain-supported models
+        from fastchat.serve.base_model_worker import app
+        worker = ""
     # online model APIs
-    if worker_class := kwargs.get("worker_class"):
-        from fastchat.serve.model_worker import app
+    elif worker_class := kwargs.get("worker_class"):
+        from fastchat.serve.base_model_worker import app
 
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
-        sys.modules["fastchat.serve.model_worker"].worker = worker
+        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
+        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
     # local models
     else:
         from configs.model_config import VLLM_MODEL_DICT

@@ -103,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             from fastchat.serve.vllm_worker import VLLMWorker, app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+
             args.tokenizer = args.model_path  # add a setting here if the tokenizer differs from model_path
             args.tokenizer_mode = 'auto'
             args.trust_remote_code = True

@@ -126,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             args.engine_use_ray = False
             args.disable_log_requests = False
 
-            # parameters required after vllm 0.2.0
-            args.max_model_len = 8192  # maximum sequence length the model can handle; set according to your model
+            # parameters required after vllm 0.2.0, but not needed here
+            args.max_model_len = None
             args.revision = None
             args.quantization = None
             args.max_log_len = None

@@ -155,10 +160,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 conv_template=args.conv_template,
             )
             sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
 
         else:
-            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
 
             args.gpus = "0"  # GPU ids; with multiple GPUs this can be "0,1,2,3"
             args.max_gpu_memory = "22GiB"
             args.num_gpus = 1  # model workers split by model parallelism; set the number of GPUs here

@@ -221,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             )
             sys.modules["fastchat.serve.model_worker"].args = args
             sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-            sys.modules["fastchat.serve.model_worker"].worker = worker
+            # sys.modules["fastchat.serve.model_worker"].worker = worker
+            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
 
     MakeFastAPIOffline(app)
     app.title = f"FastChat LLM Server ({args.model_names[0]})"
tests/online_api/test_baichuan.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.baichuan import request_baichuan_api
+from pprint import pprint
+
+
+def test_baichuan():
+    messages = [{"role": "user", "content": "hello"}]
+
+    for x in request_baichuan_api(messages):
+        print(type(x))
+        pprint(x)
+        assert x["code"] == 0
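Note that this test hits the live Baichuan endpoint: api_key and secret_key must be filled in the baichuan-api entry of ONLINE_LLM_MODEL, and running it with pytest's -s flag shows the streamed chunks as they arrive.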
webui.py (7 changed lines)

@@ -21,13 +21,6 @@ if __name__ == "__main__":
         }
     )
 
-    if not chat_box.chat_inited:
-        running_models = api.list_running_models()
-        st.toast(
-            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
-            f"当前运行中的模型`{running_models}`, 您可以开始提问了."
-        )
-
     pages = {
         "对话": {
             "icon": "chat",
@@ -3,7 +3,8 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES
+from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, LANGCHAIN_LLM_MODEL)
 from typing import List, Dict
 
 chat_box = ChatBox(

@@ -40,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
     The return type is (model_name, is_local_model)
     '''
     running_models = api.list_running_models()
-
     if not running_models:
         return "", False
 

@@ -50,11 +50,16 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
     local_models = [k for k, v in running_models.items() if not v.get("online_api")]
     if local_models:
         return local_models[0], True
-    return running_models[0], False
+    return list(running_models)[0], False
 
 
 def dialogue_page(api: ApiRequest):
+    if not chat_box.chat_inited:
+        default_model = get_default_llm_model(api)[0]
+        st.toast(
+            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
+            f"当前运行的模型`{default_model}`, 您可以开始提问了."
+        )
     chat_box.init_session()
 
     with st.sidebar:

@@ -74,12 +79,13 @@ def dialogue_page(api: ApiRequest):
             "搜索引擎问答",
             "自定义Agent问答",
         ],
-        index=3,
+        index=0,
         on_change=on_mode_change,
         key="dialogue_mode",
     )
 
     def on_llm_change():
+        if llm_model:
             config = api.get_model_config(llm_model)
             if not config.get("online_api"):  # only local model_workers can switch models
                 st.session_state["prev_llm_model"] = llm_model

@@ -91,15 +97,17 @@ def dialogue_page(api: ApiRequest):
         return x
 
     running_models = list(api.list_running_models())
+    running_models += LANGCHAIN_LLM_MODEL.keys()
     available_models = []
     config_models = api.list_config_models()
     worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
     for m in worker_models:
         if m not in running_models and m != "default":
             available_models.append(m)
-    for k, v in config_models.get("online", {}).items():  # list the directly accessed models in ONLINE_MODELS (e.g. GPT)
+    for k, v in config_models.get("online", {}).items():  # list the directly accessed models in ONLINE_MODELS
         if not v.get("provider") and k not in running_models:
-            print(k, v)
+            available_models.append(k)
+    for k, v in config_models.get("langchain", {}).items():  # list the models supported via LANGCHAIN_LLM_MODEL
         available_models.append(k)
     llm_models = running_models + available_models
     index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))

@@ -111,7 +119,8 @@ def dialogue_page(api: ApiRequest):
         key="llm_model",
     )
     if (st.session_state.get("prev_llm_model") != llm_model
-            and not api.get_model_config(llm_model).get("online_api")
+            and not llm_model in config_models.get("online", {})
+            and not llm_model in config_models.get("langchain", {})
             and llm_model not in running_models):
         with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
             prev_model = st.session_state.get("prev_llm_model")

@@ -156,9 +165,13 @@ def dialogue_page(api: ApiRequest):
     if dialogue_mode == "知识库问答":
         with st.expander("知识库配置", True):
             kb_list = api.list_knowledge_bases()
+            index = 0
+            if DEFAULT_KNOWLEDGE_BASE in kb_list:
+                index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
             selected_kb = st.selectbox(
                 "请选择知识库:",
                 kb_list,
+                index=index,
                 on_change=on_kb_change,
                 key="selected_kb",
             )

@@ -167,11 +180,15 @@ def dialogue_page(api: ApiRequest):
 
     elif dialogue_mode == "搜索引擎问答":
         search_engine_list = api.list_search_engines()
+        if DEFAULT_SEARCH_ENGINE in search_engine_list:
+            index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
+        else:
+            index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
         with st.expander("搜索引擎配置", True):
             search_engine = st.selectbox(
                 label="请选择搜索引擎",
                 options=search_engine_list,
-                index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
+                index=index,
             )
         se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)

@@ -210,9 +227,9 @@ def dialogue_page(api: ApiRequest):
             ])
         text = ""
         ans = ""
-        support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agents
+        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support agents
         if not any(agent in llm_model for agent in support_agent):
-            ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
+            ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
             chat_box.update_msg(ans, element_index=0, streaming=False)
         for d in api.agent_chat(prompt,
                                 history=history,
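The one-line return fix in get_default_llm_model above is needed because list_running_models now returns a dict keyed by model name (see the earlier llm_api change) rather than a list; a sketch with a hypothetical payload:

running_models = {"chatglm2-6b": {"online_api": False}}  # hypothetical shape
# running_models[0] would raise KeyError: 0 on a dict
first = list(running_models)[0]  # iterates the keys: "chatglm2-6b"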
@@ -245,7 +245,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         **kwargs: Any,
     ):
         '''

@@ -278,7 +278,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
         **kwargs,
     ):

@@ -308,7 +308,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''

@@ -340,7 +340,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''

@@ -378,7 +378,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''