imClumsyPanda 2023-11-10 16:32:21 +08:00
commit 3e20e47328
40 changed files with 510 additions and 208 deletions

9
.gitignore vendored
View File

@ -2,7 +2,10 @@
*.log.*
*.bak
logs
/knowledge_base/
/knowledge_base/*
!/knowledge_base/samples
/knowledge_base/samples/vector_store
/configs/*.py
.vscode/
@ -167,3 +170,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.pytest_cache
.DS_Store

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "knowledge_base/samples/content/wiki"]
path = knowledge_base/samples/content/wiki
url = https://github.com/chatchat-space/Langchain-Chatchat.wiki.git

View File

@ -1,12 +1,12 @@
from server.utils import get_ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
from configs.model_config import LLM_MODELS, TEMPERATURE
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)
human_prompt = "{input}"

View File

@ -2,6 +2,7 @@ import logging
import os
import langchain
# Whether to show verbose logs
log_verbose = False
langchain.verbose = False

View File

@ -3,7 +3,7 @@ import os
# Default knowledge base to use
DEFAULT_KNOWLEDGE_BASE = "samples"
# Default vector store type. Options: faiss, milvus (offline) & zilliz (online), pg.
# Default vector store / full-text search engine type. Options: faiss, milvus (offline) & zilliz (online), pgvector; full-text search engine: es
DEFAULT_VS_TYPE = "faiss"
# Number of cached vector stores (for FAISS)
@ -56,7 +56,10 @@ KB_INFO = {
"知识库名称": "知识库介绍",
"samples": "关于本项目issue的解答",
}
# Normally you do not need to change anything below
# Default storage path for knowledge bases
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
@ -86,7 +89,15 @@ kbs_config = {
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
},
"es": {
"host": "127.0.0.1",
"port": "9200",
"index_name": "test_index",
"user": "",
"password": ""
}
}
# TextSplitter configuration. Do not modify it if you do not understand what it means.
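For reference, a minimal sketch of how the new `es` entry in `kbs_config` could be turned into a client connection; the helper name `get_es_client` is hypothetical, and the real wiring lives in server/knowledge_base/kb_service/es_kb_service.py further down in this commit:

```python
from elasticsearch import Elasticsearch

from configs import kbs_config  # exported from the configs package, as used by es_kb_service.py


def get_es_client() -> Elasticsearch:
    """Hypothetical helper: build an ES client from the `es` section of kbs_config."""
    cfg = kbs_config["es"]
    url = f"http://{cfg['host']}:{cfg['port']}"
    if cfg.get("user") and cfg.get("password"):
        return Elasticsearch(url, basic_auth=(cfg["user"], cfg["password"]))
    return Elasticsearch(url)
```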

View File

@ -1,96 +1,13 @@
import os
# You can specify an absolute path here to store all Embedding and LLM models in one place.
# Each model can be a standalone directory, or a second-level subdirectory under some directory
# Each model can be a standalone directory, or a second-level subdirectory under some directory.
# If a model directory name is identical to a key or value in MODEL_PATH, the program detects and loads it automatically; the paths in MODEL_PATH do not need to be modified.
MODEL_ROOT_PATH = ""
# Modify the values in the dict below to specify where local embedding models are stored. Three ways to set this:
# 1. Change the value to the absolute path of the model
# 2. Leave the value unchanged (taking text2vec as an example):
# 2.1 If any of the following subdirectories exists under {MODEL_ROOT_PATH}:
# - text2vec
# - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese
# 2.2 If none of the local paths above exists, the huggingface model is used
MODEL_PATH = {
"embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese",
"text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
"text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
"text2vec-multilingual": "shibing624/text2vec-base-multilingual",
"text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
"m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base",
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"text-embedding-ada-002": "your OPENAI_API_KEY",
},
# TODO: add all supported llm models
"llm_model": {
# Some of the models below have not been fully tested; support is only inferred from the model lists of fastchat and vllm
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",
"gpt2": "gpt2",
"gpt2-xl": "gpt2-xl",
"gpt-j-6b": "EleutherAI/gpt-j-6b",
"gpt4all-j": "nomic-ai/gpt4all-j",
"gpt-neox-20b": "EleutherAI/gpt-neox-20b",
"pythia-12b": "EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
},
}
# Name of the selected Embedding model
EMBEDDING_MODEL = "m3e-base" # you can try the latest SOTA embedding model: bge-large-zh-v1.5
EMBEDDING_MODEL = "m3e-base" # bge-large-zh
# Device the Embedding model runs on. "auto" detects the device automatically; it can also be set manually to one of "cuda", "mps", "cpu".
EMBEDDING_DEVICE = "auto"
@ -99,9 +16,11 @@ EMBEDDING_DEVICE = "auto"
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
# LLM name
LLM_MODEL = "chatglm2-6b"
# Name of the Agent LLM (optional; once specified, it pins the model used by the Chain after entering the Agent; if unspecified, LLM_MODEL is used)
# Names of the LLMs to run; this can include both local and online models.
# The first one is used as the default model for the API and the WebUI
LLM_MODELS = ["chatglm2-6b", "zhipu-api", "openai-api"]
# Name of the Agent LLM (optional; once specified, it pins the model used by the Chain after entering the Agent; if unspecified, LLM_MODELS[0] is used)
Agent_MODEL = None
# Device the LLM runs on. "auto" detects the device automatically; it can also be set manually to one of "cuda", "mps", "cpu".
@ -111,7 +30,6 @@ LLM_DEVICE = "auto"
HISTORY_LEN = 3
# Maximum context length supported by the model; if left empty, the model's default maximum is used, otherwise the user-specified maximum applies
MAX_TOKENS = None
# Common LLM dialogue parameters
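The list replaces the old single LLM_MODEL, and the first entry is what the API and WebUI fall back to. A minimal sketch of the consumption pattern, mirroring the chain and chat handlers changed elsewhere in this commit:

```python
from configs import LLM_MODELS, TEMPERATURE
from server.utils import get_ChatOpenAI

# The first configured model acts as the default; callers may still pass any
# other name from LLM_MODELS (or a running online model) explicitly.
model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)
```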
@ -195,8 +113,98 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"provider": "AzureWorker",
},
}
# Modify the values in the dict below to specify where local embedding models are stored. Three ways to set this:
# 1. Change the value to the absolute path of the model
# 2. Leave the value unchanged (taking text2vec as an example):
# 2.1 If any of the following subdirectories exists under {MODEL_ROOT_PATH}:
# - text2vec
# - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese
# 2.2 If none of the local paths above exists, the huggingface model is used
MODEL_PATH = {
"embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese",
"text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
"text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
"text2vec-multilingual": "shibing624/text2vec-base-multilingual",
"text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
"m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base",
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"text-embedding-ada-002": "your OPENAI_API_KEY",
},
"llm_model": {
# Some of the models below have not been fully tested; support is only inferred from the model lists of fastchat and vllm
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",
"gpt2": "gpt2",
"gpt2-xl": "gpt2-xl",
"gpt-j-6b": "EleutherAI/gpt-j-6b",
"gpt4all-j": "nomic-ai/gpt4all-j",
"gpt-neox-20b": "EleutherAI/gpt-neox-20b",
"pythia-12b": "EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8" # 确保已经安装了auto-gptq optimum flash-atten
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4" # 确保已经安装了auto-gptq optimum flash-atten
},
}
# Normally you do not need to change anything below
# Storage path for nltk models

View File

@ -31,8 +31,7 @@ FSCHAT_OPENAI_API = {
# fastchat model_worker server
# These models must be correctly configured in model_config.MODEL_PATH or ONLINE_MODEL.
# When launching startup.py, models can be specified with `--model-worker --model-name xxxx`; if not specified, LLM_MODEL is used
# Only models added here will appear in the WebUI's selectable model list; LLM_MODEL is added automatically
# When launching startup.py, models can be specified with `--model-name xxxx yyyy`; if not specified, LLM_MODELS is used
FSCHAT_MODEL_WORKERS = {
# Default configuration shared by all models; it can be overridden in each model's own configuration.
"default": {
@ -58,7 +57,7 @@ FSCHAT_MODEL_WORKERS = {
# "awq_ckpt": None,
# "awq_wbits": 16,
# "awq_groupsize": -1,
# "model_names": [LLM_MODEL],
# "model_names": LLM_MODELS,
# "conv_template": None,
# "limit_worker_concurrency": 5,
# "stream_interval": 2,
@ -96,30 +95,31 @@ FSCHAT_MODEL_WORKERS = {
# "device": "cpu",
# },
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口
#以下配置可以不用修改在model_config中设置启动的模型
"zhipu-api": {
"port": 21001,
},
# "minimax-api": {
# "port": 21002,
# },
# "xinghuo-api": {
# "port": 21003,
# },
# "qianfan-api": {
# "port": 21004,
# },
# "fangzhou-api": {
# "port": 21005,
# },
# "qwen-api": {
# "port": 21006,
# },
# "baichuan-api": {
# "port": 21007,
# },
# "azure-api": {
# "port": 21008,
# },
"minimax-api": {
"port": 21002,
},
"xinghuo-api": {
"port": 21003,
},
"qianfan-api": {
"port": 21004,
},
"fangzhou-api": {
"port": 21005,
},
"qwen-api": {
"port": 21006,
},
"baichuan-api": {
"port": 21007,
},
"azure-api": {
"port": 21008,
},
}
# fastchat multi model worker server
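The online API workers above are now enabled by default, but each entry only binds a port; whether a given API actually starts is still driven by the model names passed to startup.py (default: LLM_MODELS). A hedged, simplified sketch of that selection, condensed from the loop changed in startup.py later in this commit:

```python
from configs import LLM_MODELS
from configs.server_config import FSCHAT_MODEL_WORKERS
from server.utils import get_model_worker_config

# Simplified: an online API worker is launched only when its name is requested
# (here: listed in LLM_MODELS) and has a port entry in FSCHAT_MODEL_WORKERS.
to_start = [
    name for name in LLM_MODELS
    if name in FSCHAT_MODEL_WORKERS
    and get_model_worker_config(name).get("online_api")
]
print(to_start)  # which online API workers would be started with the current config
```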

29
docs/ES部署指南.md Normal file
View File

@ -0,0 +1,29 @@
# Data insertion, retrieval, deletion and update based on ES
```shell
author: 唐国梁Tommy
e-mail: flytang186@qq.com
If you run into any problems, feel free to contact me; the service runs without issues after deployment on my side.
```
## Step 1: Deploy ES with Docker
```shell
docker network create elastic
docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2
```
## Step 2: Deploy Kibana with Docker
**Note: keep the Kibana version identical to the ES version**
```shell
docker pull docker.elastic.co/kibana/kibana:{version}
docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version}
```
## Step 3: Core code
```shell
1. Core code path:
server/knowledge_base/kb_service/es_kb_service.py
2. The ES parameters (IP, PORT) need to be configured in configs/model_config.py
```
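As a quick sanity check after step 1, a hedged Python sketch that connects to the single-node instance started above (security disabled, default port 9200) and prints the version and cluster health; the values assume the exact docker command from the guide:

```python
from elasticsearch import Elasticsearch

# Connect to the single-node ES started by the docker command above
# (xpack security disabled, so no credentials are needed).
es = Elasticsearch("http://127.0.0.1:9200")
print(es.info()["version"]["number"])   # expected: 8.8.2
print(es.cluster.health()["status"])    # "green" or "yellow" for a single node
```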

@ -0,0 +1 @@
Subproject commit b705cf80e4150cb900c77b343f0f9c62ec9a0278

View File

@ -53,7 +53,7 @@ vllm>=0.2.0; sys_platform == "linux"
# WebUI requirements
streamlit>=1.26.0
streamlit~=1.27.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.11

View File

@ -41,7 +41,7 @@ dashscope>=1.10.0 # qwen
numpy~=1.24.4
pandas~=2.0.3
streamlit>=1.26.0
streamlit~=1.27.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.11

View File

@ -1,6 +1,6 @@
# WebUI requirements
streamlit>=1.26.0
streamlit~=1.27.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.11

View File

@ -5,7 +5,7 @@ from langchain.agents import AgentExecutor, LLMSingleActionAgent, initialize_age
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional, Dict
@ -26,7 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
"content": "使用天气查询工具查询到今天北京多云10-14摄氏度东北风2级易感冒"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
@ -38,7 +38,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
async def agent_chat_iterator(
query: str,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = CustomAsyncIteratorCallbackHandler()

View File

@ -1,6 +1,6 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY
from configs import LLM_MODELS, TEMPERATURE, SAVE_CHAT_HISTORY
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -22,7 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
@ -32,7 +32,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()

View File

@ -1,6 +1,6 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE
from configs import LLM_MODELS, TEMPERATURE
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -13,7 +13,7 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
stream: bool = Body(False, description="流式输出"),
echo: bool = Body(False, description="除了输出之外,还回显输入"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
@ -23,7 +23,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
#todo Since ApiModelWorker handles requests as chat by default and parses params["prompt"] into messages, corresponding handling is needed when using ApiModelWorker
async def completion_iterator(query: str,
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
echo: bool = echo,
) -> AsyncIterable[str]:

View File

@ -1,6 +1,6 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
@ -30,7 +30,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
@ -45,7 +45,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
async def knowledge_base_chat_iterator(query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()

View File

@ -1,7 +1,7 @@
from fastapi.responses import StreamingResponse
from typing import List, Optional
import openai
from configs import LLM_MODEL, logger, log_verbose
from configs import LLM_MODELS, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel
@ -12,7 +12,7 @@ class OpenAiMessage(BaseModel):
class OpenAiChatMsgIn(BaseModel):
model: str = LLM_MODEL
model: str = LLM_MODELS[0]
messages: List[OpenAiMessage]
temperature: float = 0.7
n: int = 1

View File

@ -1,7 +1,7 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from fastapi.responses import StreamingResponse
@ -126,7 +126,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
@ -144,7 +144,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()

View File

@ -51,6 +51,10 @@ def add_docs_to_db(session,
Add all Document info for a given file of a given knowledge base to the database
doc_infos format: [{"id": str, "metadata": dict}, ...]
'''
#! doc_infos can come in as None here; this needs further investigation
if doc_infos is None:
print("输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None")
return False
for d in doc_infos:
obj = FileDocModel(
kb_name=kb_name,
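For reference, a hedged example of the doc_infos shape this function expects; ids and metadata are illustrative only, real values come from the vector store:

```python
# Illustrative payload only.
doc_infos = [
    {"id": "doc-1", "metadata": {"source": "samples/test.txt"}},
    {"id": "doc-2", "metadata": {"source": "samples/test.txt"}},
]
```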

View File

@ -119,7 +119,7 @@ class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = device or embedding_device()
device = embedding_device()
key = (model, device)
if not self.get(key):
item = ThreadSafeObject(key, pool=self)

View File

@ -46,6 +46,7 @@ class SupportedVSType:
DEFAULT = 'default'
ZILLIZ = 'zilliz'
PG = 'pg'
ES = 'es'
class KBService(ABC):
@ -274,19 +275,20 @@ class KBServiceFactory:
from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type:
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type:
from server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
@staticmethod
def get_service_by_name(kb_name: str,
default_vs_type: SupportedVSType = SupportedVSType.FAISS,
default_embed_model: str = EMBEDDING_MODEL,
) -> KBService:
def get_service_by_name(kb_name: str) -> KBService:
_, vs_type, embed_model = load_kb_from_db(kb_name)
if vs_type is None: # faiss knowledge base not in db
vs_type = default_vs_type
if embed_model is None:
embed_model = default_embed_model
if _ is None: # kb not in db, just return None
return None
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@staticmethod
@ -331,6 +333,9 @@ def get_kb_details() -> List[Dict]:
def get_kb_file_details(kb_name: str) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is None:
return []
files_in_folder = list_files_from_folder(kb_name)
files_in_db = kb.list_files()
result = {}

View File

@ -0,0 +1,205 @@
#!/usr/bin/env python3
"""
File_Name: es_kb_service.py
Author: TangGuoLiang
Email: 896165277@qq.com
Created: 2023-09-05
"""
from typing import List
import os
import shutil
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch
from configs import logger
from configs import kbs_config
class ESKBService(KBService):
def do_init(self):
self.kb_path = self.get_kb_path(self.kb_name)
self.index_name = self.kb_path.split("/")[-1]
self.IP = kbs_config[self.vs_type()]['host']
self.PORT = kbs_config[self.vs_type()]['port']
self.user = kbs_config[self.vs_type()].get("user",'')
self.password = kbs_config[self.vs_type()].get("password",'')
self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)
try:
# ES python client connection (connection only)
if self.user != "" and self.password != "":
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}",
basic_auth=(self.user,self.password))
else:
logger.warning("ES未配置用户名和密码")
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}")
self.es_client_python.indices.create(index=self.index_name)
except ConnectionError:
logger.error("连接到 Elasticsearch 失败!")
except Exception as e:
logger.error(f"Error 发生 : {e}")
try:
# langchain ES connection and index creation
if self.user != "" and self.password != "":
self.db_init = ElasticsearchStore(
es_url=f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
query_field="context",
vector_query_field="dense_vector",
embedding=self.embeddings_model,
es_user=self.user,
es_password=self.password
)
else:
logger.warning("ES未配置用户名和密码")
self.db_init = ElasticsearchStore(
es_url=f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
query_field="context",
vector_query_field="dense_vector",
embedding=self.embeddings_model,
)
except ConnectionError:
print("### 连接到 Elasticsearch 失败!")
logger.error("### 连接到 Elasticsearch 失败!")
except Exception as e:
logger.error(f"Error 发生 : {e}")
@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
def do_create_kb(self):
if os.path.exists(self.doc_path):
if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
os.makedirs(os.path.join(self.kb_path, "vector_store"))
else:
logger.warning("directory `vector_store` already exists.")
def vs_type(self) -> str:
return SupportedVSType.ES
def _load_es(self, docs, embed_model):
# Write the docs into ES
try:
# Connect and write the documents in the same step
if self.user != "" and self.password != "":
self.db = ElasticsearchStore.from_documents(
documents=docs,
embedding=embed_model,
es_url= f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
distance_strategy="COSINE",
query_field="context",
vector_query_field="dense_vector",
verify_certs=False,
es_user=self.user,
es_password=self.password
)
else:
self.db = ElasticsearchStore.from_documents(
documents=docs,
embedding=embed_model,
es_url= f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
distance_strategy="COSINE",
query_field="context",
vector_query_field="dense_vector",
verify_certs=False)
except ConnectionError as ce:
print(ce)
print("连接到 Elasticsearch 失败!")
logger.error("连接到 Elasticsearch 失败!")
except Exception as e:
logger.error(f"Error 发生 : {e}")
print(e)
def do_search(self, query:str, top_k: int, score_threshold: float):
# Text similarity search
docs = self.db_init.similarity_search_with_score(query=query,
k=top_k)
return docs
def do_delete_doc(self, kb_file, **kwargs):
if self.es_client_python.indices.exists(index=self.index_name):
# Delete the file's documents from the vector store (the file name is a Keyword field)
query = {
"query": {
"term": {
"metadata.source.keyword": kb_file.filepath
}
}
}
# Note: set size explicitly; by default only 10 hits are returned.
search_results = self.es_client_python.search(body=query, size=50)
delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
if len(delete_list) == 0:
return None
else:
for doc_id in delete_list:
try:
self.es_client_python.delete(index=self.index_name,
id=doc_id,
refresh=True)
except Exception as e:
logger.error("ES Docs Delete Error!")
# self.db_init.delete(ids=delete_list)
#self.es_client_python.indices.refresh(index=self.index_name)
def do_add_doc(self, docs: List[Document], **kwargs):
'''Add files to the knowledge base'''
print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
print("*"*100)
self._load_es(docs=docs, embed_model=self.embeddings_model)
# Get id and source, format: [{"id": str, "metadata": dict}, ...]
print("写入数据成功.")
print("*"*100)
if self.es_client_python.indices.exists(index=self.index_name):
file_path = docs[0].metadata.get("source")
query = {
"query": {
"term": {
"metadata.source.keyword": file_path
}
}
}
search_results = self.es_client_python.search(body=query)
if len(search_results["hits"]["hits"]) == 0:
raise ValueError("召回元素个数为0")
info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]]
return info_docs
def do_clear_vs(self):
"""从知识库删除全部向量"""
if self.es_client_python.indices.exists(index=self.kb_name):
self.es_client_python.indices.delete(index=self.kb_name)
def do_drop_kb(self):
"""删除知识库"""
# self.kb_file: 知识库路径
if os.path.exists(self.kb_path):
shutil.rmtree(self.kb_path)
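A hedged usage sketch of the new service; the constructor call follows how KBServiceFactory instantiates it above, the kb name, document and query are made up, and a reachable ES instance (per the deployment guide) is assumed:

```python
from langchain.schema import Document
from server.knowledge_base.kb_service.es_kb_service import ESKBService

kb = ESKBService("samples")          # connects to ES and creates the index on init
kb.do_create_kb()                    # ensures the local vector_store directory exists
docs = [Document(page_content="Elasticsearch stores both text and dense vectors.",
                 metadata={"source": "samples/es_intro.txt"})]
kb.do_add_doc(docs)                  # writes the documents and returns their ids/metadata
hits = kb.do_search(query="How are vectors stored?", top_k=3, score_threshold=0.5)
for doc, score in hits:
    print(score, doc.metadata.get("source"))
```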

View File

@ -48,7 +48,10 @@ class FaissKBService(KBService):
def do_drop_kb(self):
self.clear_vs()
shutil.rmtree(self.kb_path)
try:
shutil.rmtree(self.kb_path)
except Exception:
...
def do_search(self,
query: str,
@ -90,8 +93,11 @@ class FaissKBService(KBService):
def do_clear_vs(self):
with kb_faiss_pool.atomic:
kb_faiss_pool.pop((self.kb_name, self.vector_name))
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
try:
shutil.rmtree(self.vs_path)
except Exception:
...
os.makedirs(self.vs_path, exist_ok=True)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):

View File

@ -157,7 +157,7 @@ def prune_db_docs(kb_names: List[str]):
"""
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb and kb.exists():
if kb is not None:
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
@ -175,7 +175,7 @@ def prune_folder_files(kb_names: List[str]):
"""
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb and kb.exists():
if kb is not None:
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))

View File

@ -7,7 +7,7 @@ from configs import (
logger,
log_verbose,
text_splitter_dict,
LLM_MODEL,
LLM_MODELS,
TEXT_SPLITTER_NAME,
)
import importlib
@ -57,7 +57,8 @@ def list_files_from_folder(kb_name: str):
for root, _, files in os.walk(doc_path):
tail = os.path.basename(root).lower()
if (tail.startswith("temp")
or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹
or tail.startswith("tmp")
or tail.startswith(".")): # 跳过 [temp, tmp, .] 开头的文件夹
continue
for file in files:
if file.startswith("~$"): # 跳过 ~$ 开头的文件
@ -192,7 +193,7 @@ def make_text_splitter(
splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
llm_model: str = LLM_MODEL,
llm_model: str = LLM_MODELS[0],
):
"""
Get a specific text splitter according to the parameters

View File

@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from copy import deepcopy
@ -65,7 +65,7 @@ def get_model_config(
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
@ -89,8 +89,8 @@ def stop_llm_model(
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''

View File

@ -59,16 +59,22 @@ class QwenWorker(ApiModelWorker):
import dashscope
params.load_config(self.model_names[0])
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=params.texts, # at most 25 texts per call
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {"code": resp["status_code"], "msg": resp.message}
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
return {"code": 200, "data": embeddings}
result = []
i = 0
while i < len(params.texts):
texts = params.texts[i:i+25]
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=texts, # at most 25 texts per call
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {"code": resp["status_code"], "msg": resp.message}
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings
i += 25
return {"code": 200, "data": result}
def get_embeddings(self, params):
# TODO: support embeddings
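The change above works around DashScope's 25-texts-per-call limit by slicing the input and concatenating the results; a hedged, generic sketch of the same chunking pattern (the helper name is made up):

```python
from typing import Callable, List


def embed_in_batches(texts: List[str],
                     embed_call: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 25) -> List[List[float]]:
    """Hypothetical helper: apply an embedding call that accepts at most
    `batch_size` texts per request and concatenate the per-batch results."""
    result: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        result.extend(embed_call(texts[i:i + batch_size]))
    return result
```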

View File

@ -4,7 +4,7 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
@ -345,8 +345,7 @@ def list_config_llm_models() -> Dict[str, Dict]:
return [(model_name, config_type), ...]
'''
workers = list(FSCHAT_MODEL_WORKERS)
if LLM_MODEL not in workers:
workers.insert(0, LLM_MODEL)
return {
"local": MODEL_PATH["llm_model"],
"online": ONLINE_LLM_MODEL,
@ -431,7 +430,7 @@ def fschat_controller_address() -> str:
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
@ -660,7 +659,7 @@ def get_server_configs() -> Dict:
TEXT_SPLITTER_NAME,
)
from configs.model_config import (
LLM_MODEL,
LLM_MODELS,
HISTORY_LEN,
TEMPERATURE,
)

View File

@ -22,7 +22,7 @@ from configs import (
LOG_PATH,
log_verbose,
logger,
LLM_MODEL,
LLM_MODELS,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
@ -359,7 +359,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
def run_model_worker(
model_name: str = LLM_MODEL,
model_name: str = LLM_MODELS[0],
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,
@ -496,7 +496,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
"specify --model-name if not using default LLM_MODEL",
"specify --model-name if not using default LLM_MODELS",
dest="model_worker",
)
parser.add_argument(
@ -504,7 +504,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-name",
type=str,
nargs="+",
default=[LLM_MODEL],
default=LLM_MODELS,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
@ -568,7 +568,7 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
models = [LLM_MODEL]
models = LLM_MODELS
if args and args.model_name:
models = args.model_name
@ -694,8 +694,8 @@ async def start_main_server():
processes["model_worker"][model_name] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if (config.get("online_api")
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):

View File

@ -1,7 +1,7 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from configs import LLM_MODEL, TEMPERATURE
from configs import LLM_MODELS, TEMPERATURE
from server.utils import get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.agents import LLMSingleActionAgent, AgentExecutor
@ -10,7 +10,7 @@ from langchain.memory import ConversationBufferWindowMemory
memory = ConversationBufferWindowMemory(k=5)
model = get_ChatOpenAI(
model_name=LLM_MODEL,
model_name=LLM_MODELS[0],
temperature=TEMPERATURE,
)
from server.agent.custom_template import CustomOutputParser, prompt

View File

@ -6,7 +6,6 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL
from server.utils import api_address, get_model_worker_config
from pprint import pprint

View File

@ -33,7 +33,7 @@ def test_recreate_vs():
folder2db([kb_name], "recreate_vs")
kb = KBServiceFactory.get_service_by_name(kb_name)
assert kb.exists()
assert kb and kb.exists()
files = kb.list_files()
print(files)

View File

@ -8,7 +8,7 @@ from pathlib import Path
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
LLM_MODEL,
LLM_MODELS,
TEMPERATURE,
SCORE_THRESHOLD,
CHUNK_SIZE,
@ -259,7 +259,7 @@ class ApiRequest:
self,
messages: List[Dict],
stream: bool = True,
model: str = LLM_MODEL,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
**kwargs: Any,
@ -291,7 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
@ -321,7 +321,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
@ -353,7 +353,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
@ -391,7 +391,7 @@ class ApiRequest:
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
@ -677,9 +677,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
def get_default_llm_model(self) -> Tuple[str, bool]:
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
Get the currently running LLM models from the server; if the locally configured LLM_MODEL is a local model and is among them, return it first
Get the currently running LLM models from the server.
local_first=True: prefer a running local model; otherwise prefer the order configured in LLM_MODELS
Return type: (model_name, is_local_model)
'''
def ret_sync():
@ -687,26 +688,42 @@ class ApiRequest:
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
model = ""
for m in LLM_MODELS:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
if local_first and not is_local:
continue
else:
model = m
break
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return list(running_models)[0], False
if not model: # none of the models configured in LLM_MODELS is among running_models
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
async def ret_async():
running_models = await self.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
model = ""
for m in LLM_MODELS:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
if local_first and not is_local:
continue
else:
model = m
break
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return list(running_models)[0], False
if not model: # none of the models configured in LLM_MODELS is among running_models
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
if self._use_async:
return ret_async()
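A brief hedged usage note for the new local_first switch; it assumes the API server started by startup.py is reachable at its configured address, and the return values depend on which workers are actually running:

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
model, is_local = api.get_default_llm_model()                # prefer a running local model
model_any, _ = api.get_default_llm_model(local_first=False)  # first running model in LLM_MODELS order
print(model, is_local, model_any)
```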