Rework the model configuration: all models are now accessed through OpenAI-compatible frameworks, and chatchat itself no longer loads any model.

Switch Embeddings models to the framework API instead of loading them manually; remove the custom Embeddings keyword code.
Update the requirements files, removing heavy dependencies such as torch and transformers.
Temporarily remove the loom integration.

Follow-up:
1. Optimize the directory structure.
2. Check whether any 0.2.10 content was overwritten during the merge.
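
The intended access pattern: any OpenAI-compatible server started via loom/xinference/oneapi/fastchat is registered in MODEL_PLATFORMS, and chatchat only talks to its HTTP API. A minimal sketch of that client-side usage, assuming a local xinference-style endpoint (the URL, key and model names below are placeholders, not values shipped with chatchat):

    # Illustrative only: chat and embeddings served by one OpenAI-compatible platform.
    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="EMPTY")  # assumed local server

    chat = client.chat.completions.create(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": "你好"}],
    )
    print(chat.choices[0].message.content)

    emb = client.embeddings.create(model="bge-large-zh-v1.5", input=["hello", "world"])
    print(len(emb.data[0].embedding))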
liunux4odoo 2024-02-08 11:30:39 +08:00
parent 988a0e6ad2
commit 5d422ca9a1
41 changed files with 757 additions and 1142 deletions

View File

@@ -1,8 +1,8 @@
 import logging
 import os
+from pathlib import Path
 import langchain
-import tempfile
-import shutil

 # 是否显示详细日志
@@ -11,6 +11,16 @@ langchain.verbose = False

 # 通常情况下不需要更改以下内容

+# 用户数据根目录
+DATA_PATH = (Path(__file__).absolute().parent.parent)  # / "data")
+if not os.path.exists(DATA_PATH):
+    os.mkdir(DATA_PATH)
+
+# nltk 模型存储路径
+NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
+import nltk
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+
 # 日志格式
 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
 logger = logging.getLogger()
@@ -19,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT)

 # 日志存储路径
-LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
+LOG_PATH = os.path.join(DATA_PATH, "logs")
 if not os.path.exists(LOG_PATH):
     os.mkdir(LOG_PATH)

 # 模型生成内容(图片、视频、音频等)保存位置
-MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
+MEDIA_PATH = os.path.join(DATA_PATH, "media")
 if not os.path.exists(MEDIA_PATH):
     os.mkdir(MEDIA_PATH)
     os.mkdir(os.path.join(MEDIA_PATH, "image"))
@@ -32,9 +42,6 @@ if not os.path.exists(MEDIA_PATH):
     os.mkdir(os.path.join(MEDIA_PATH, "video"))

 # 临时文件目录,主要用于文件对话
-BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
-if os.path.isdir(BASE_TEMP_DIR):
-    shutil.rmtree(BASE_TEMP_DIR)
-os.makedirs(BASE_TEMP_DIR, exist_ok=True)
-
-MEDIA_PATH = None
+BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
+if not os.path.exists(BASE_TEMP_DIR):
+    os.mkdir(BASE_TEMP_DIR)

View File

@@ -1,6 +1,6 @@
 import os

+# 默认使用的知识库
 DEFAULT_KNOWLEDGE_BASE = "samples"

 # 默认向量库/全文检索引擎类型。可选faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es
@@ -44,6 +44,7 @@ KB_INFO = {
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
 if not os.path.exists(KB_ROOT_PATH):
     os.mkdir(KB_ROOT_PATH)
+
 # 数据库默认存储路径。
 # 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")

View File

@@ -1,10 +1,25 @@
 import os

-MODEL_ROOT_PATH = ""
-EMBEDDING_MODEL = "bge-large-zh-v1.5"  # bge-large-zh
-EMBEDDING_DEVICE = "auto"
-EMBEDDING_KEYWORD_FILE = "keywords.txt"
-EMBEDDING_MODEL_OUTPUT_PATH = "output"
+# 默认选用的 LLM 名称
+DEFAULT_LLM_MODEL = "chatglm3-6b"
+
+# 默认选用的 Embedding 名称
+DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5"
+
+# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
+Agent_MODEL = None
+
+# 历史对话轮数
+HISTORY_LEN = 3
+
+# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
+MAX_TOKENS = None
+
+# LLM通用对话参数
+TEMPERATURE = 0.7
+# TOP_P = 0.95  # ChatOpenAI暂不支持该参数

 SUPPORT_AGENT_MODELS = [
     "chatglm3-6b",
@@ -12,97 +27,103 @@ SUPPORT_AGENT_MODELS = [
     "Qwen-14B-Chat",
     "Qwen-7B-Chat",
 ]

 LLM_MODEL_CONFIG = {
+    # 意图识别不需要输出,模型后台知道就行
     "preprocess_model": {
-        # "Mixtral-8x7B-v0.1": {
-        #     "temperature": 0.01,
-        #     "max_tokens": 5,
-        #     "prompt_name": "default",
-        #     "callbacks": False
-        # },
-        "chatglm3-6b": {
+        DEFAULT_LLM_MODEL: {
             "temperature": 0.05,
             "max_tokens": 4096,
-            "history_len": 100,
             "prompt_name": "default",
             "callbacks": False
         },
     },
     "llm_model": {
-        # "Mixtral-8x7B-v0.1": {
-        #     "temperature": 0.9,
-        #     "max_tokens": 4000,
-        #     "history_len": 5,
-        #     "prompt_name": "default",
-        #     "callbacks": True
-        # },
-        "chatglm3-6b": {
-            "temperature": 0.05,
+        DEFAULT_LLM_MODEL: {
+            "temperature": 0.9,
             "max_tokens": 4096,
-            "prompt_name": "default",
             "history_len": 10,
+            "prompt_name": "default",
             "callbacks": True
         },
     },
     "action_model": {
-        # "Qwen-14B-Chat": {
-        #     "temperature": 0.05,
-        #     "max_tokens": 4096,
-        #     "prompt_name": "qwen",
-        #     "callbacks": True
-        # },
-        "chatglm3-6b": {
-            "temperature": 0.05,
+        DEFAULT_LLM_MODEL: {
+            "temperature": 0.01,
             "max_tokens": 4096,
             "prompt_name": "ChatGLM3",
             "callbacks": True
         },
-        # "zhipu-api": {
-        #     "temperature": 0.01,
-        #     "max_tokens": 4096,
-        #     "prompt_name": "ChatGLM3",
-        #     "callbacks": True
-        # }
     },
     "postprocess_model": {
-        "zhipu-api": {
+        DEFAULT_LLM_MODEL: {
             "temperature": 0.01,
             "max_tokens": 4096,
             "prompt_name": "default",
             "callbacks": True
         }
     },
+    "image_model": {
+        "sd-turbo": {
+            "size": "256*256",
+        }
+    },
+    "multimodal_model": {
+        "qwen-vl": {}
+    },
 }

-MODEL_PATH = {
-    "embed_model": {
-        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-        "ernie-base": "nghuyong/ernie-3.0-base-zh",
-        "text2vec-base": "shibing624/text2vec-base-chinese",
-        "text2vec": "GanymedeNil/text2vec-large-chinese",
-        "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
-        "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
-        "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
-        "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
-        "m3e-small": "moka-ai/m3e-small",
-        "m3e-base": "moka-ai/m3e-base",
-        "m3e-large": "moka-ai/m3e-large",
-        "bge-small-zh": "BAAI/bge-small-zh",
-        "bge-base-zh": "BAAI/bge-base-zh",
-        "bge-large-zh": "BAAI/bge-large-zh",
-        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
-        "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
-        "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
-        "piccolo-base-zh": "sensenova/piccolo-base-zh",
-        "piccolo-large-zh": "sensenova/piccolo-large-zh",
-        "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
-        "text-embedding-ada-002": "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",
-    }
-}
-
-NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
-
-LOOM_CONFIG = "./loom.yaml"
-OPENAI_KEY = None
-OPENAI_PROXY = None
+# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
+MODEL_PLATFORMS = [
+    {
+        "platform_name": "openai-api",
+        "platform_type": "openai",
+        "llm_models": [
+            "gpt-3.5-turbo",
+        ],
+        "api_base_url": "https://api.openai.com/v1",
+        "api_key": "sk-",
+        "api_proxy": "",
+    },
+    {
+        "platform_name": "xinference",
+        "platform_type": "xinference",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+        "embed_models": [
+            "bge-large-zh-v1.5",
+        ],
+        "image_models": [
+            "sd-turbo",
+        ],
+        "multimodal_models": [
+            "qwen-vl",
+        ],
+        "api_base_url": "http://127.0.0.1:9997/v1",
+        "api_key": "EMPTY",
+    },
+    {
+        "platform_name": "oneapi",
+        "platform_type": "oneapi",
+        "api_key": "",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+    },
+    {
+        "platform_name": "loom",
+        "platform_type": "loom",
+        "api_key": "",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+    },
+]
+
+LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
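
For context, a model name is expected to be resolved against MODEL_PLATFORMS to find the platform that serves it. The helper below is a hypothetical sketch of that lookup (the real helpers such as get_ChatOpenAI / get_Embeddings live in server.utils and are not shown here):

    # Sketch only: find the configured platform that serves a given model.
    def get_platform_for_model(model_name: str, kind: str = "llm_models") -> dict:
        for platform in MODEL_PLATFORMS:
            if model_name in platform.get(kind, []):
                return platform
        raise ValueError(f"model {model_name} is not served by any configured platform")

    # e.g. point an OpenAI-compatible client at the matching platform:
    # cfg = get_platform_for_model(DEFAULT_EMBEDDING_MODEL, "embed_models")
    # client = openai.OpenAI(base_url=cfg["api_base_url"], api_key=cfg["api_key"])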

View File

@@ -98,7 +98,25 @@ PROMPT_TEMPLATES = {
             'Begin!\n\n'
             'Question: {input}\n\n'
             '{agent_scratchpad}\n\n',
+        "structured-chat-agent":
+            'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n'
+            '{tools}\n\n'
+            'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n'
+            'Valid "action" values: "Final Answer" or {tool_names}\n\n'
+            'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
+            '```\n{{\n  "action": $TOOL_NAME,\n  "action_input": $INPUT\n}}\n```\n\n'
+            'Follow this format:\n\n'
+            'Question: input question to answer\n'
+            'Thought: consider previous and subsequent steps\n'
+            'Action:\n```\n$JSON_BLOB\n```\n'
+            'Observation: action result\n'
+            '... (repeat Thought/Action/Observation N times)\n'
+            'Thought: I know what to respond\n'
+            'Action:\n```\n{{\n  "action": "Final Answer",\n  "action_input": "Final response to human"\n}}\n\n'
+            'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n'
+            '{input}\n\n'
+            '{agent_scratchpad}\n\n'
+            # '(reminder to respond in a JSON blob no matter what)'
     },
     "postprocess_model": {
         "default": "{{input}}",
@@ -130,7 +148,7 @@ TOOL_CONFIG = {
     "bing": {
         "result_len": 3,
         "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
-        "bing_key": "680a39347d7242c5bd2d7a9576a125b7",
+        "bing_key": "",
     },
     "metaphor": {
         "result_len": 3,
@@ -184,4 +202,8 @@ TOOL_CONFIG = {
         "device": "cuda:2"
     },
+    "text2images": {
+        "use": False,
+    },
 }

View File

@@ -1,5 +1,5 @@
 import sys
-from configs.model_config import LLM_DEVICE

 # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
 HTTPX_DEFAULT_TIMEOUT = 300.0
@@ -11,10 +11,11 @@ OPEN_CROSS_DOMAIN = True
 # 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
 DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"

 # webui.py server
 WEBUI_SERVER = {
     "host": DEFAULT_BIND_HOST,
-    "port": 7870,
+    "port": 8501,
 }

 # api.py server

View File

View File

@ -1,79 +0,0 @@
'''
该功能是为了将关键词加入到embedding模型中以便于在embedding模型中进行关键词的embedding
该功能的实现是通过修改embedding模型的tokenizer来实现的
该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效输出后的模型保存在原本模型
感谢@CharlesJu1和@charlesyju的贡献提出了想法和最基础的PR
保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳
'''
import sys
sys.path.append("..")
import os
import torch
from datetime import datetime
from configs import (
MODEL_PATH,
EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE,
)
from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated
@deprecated(
since="0.3.0",
message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words):
tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
input_ids = tokenizer_output['input_ids']
input_ids = input_ids[:, 1:-1]
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
keyword_embedding = torch.mean(keyword_embedding, 1)
return keyword_embedding
def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
key_words = []
with open(keyword_file, "r") as f:
for line in f:
key_words.append(line.strip())
st_model = SentenceTransformer(model_name)
key_words_len = len(key_words)
word_embedding_model = st_model._first_module()
bert_model = word_embedding_model.auto_model
tokenizer = word_embedding_model.tokenizer
key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
embedding_weight = bert_model.embeddings.word_embeddings.weight
embedding_weight_len = len(embedding_weight)
tokenizer.add_tokens(key_words)
bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
embedding_weight = bert_model.embeddings.word_embeddings.weight
with torch.no_grad():
embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
if output_model_path:
os.makedirs(output_model_path, exist_ok=True)
word_embedding_model.save(output_model_path)
safetensors_file = os.path.join(output_model_path, "model.safetensors")
metadata = {'format': 'pt'}
save_model(bert_model, safetensors_file, metadata)
print("save model to {}".format(output_model_path))
def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
keyword_file = os.path.join(path)
model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
model_parent_directory = os.path.dirname(model_name)
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path)

View File

@ -1,3 +0,0 @@
Langchain-Chatchat
数据科学与大数据技术
人工智能与先进计算

View File

@@ -2,9 +2,7 @@ import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                            folder2db, prune_db_docs, prune_folder_files)
-from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
-import nltk
-nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL
 from datetime import datetime
@@ -19,7 +17,7 @@ if __name__ == "__main__":
         action="store_true",
         help=('''
             recreate vector store.
-            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed.
+            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/DEFAULT_EMBEDDING_MODEL changed.
             '''
              )
     )
@@ -87,7 +85,7 @@ if __name__ == "__main__":
         "-e",
         "--embed-model",
         type=str,
-        default=EMBEDDING_MODEL,
+        default=DEFAULT_EMBEDDING_MODEL,
         help=("specify embeddings model.")
     )

View File

@ -1,57 +0,0 @@
# API requirements
langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai~=1.9.0
fastapi~=0.109.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.27.0.post1
starlette~=0.35.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.8
rapidocr_onnxruntime==1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
httpx_sse==0.4.0
llama-index==0.9.35
pyjwt==2.8.0
# jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2
# dashscope==1.13.6
# arxiv~=2.1.0
# youtube-search~=2.1.2
# duckduckgo-search~=3.9.9
# metaphor-python~=0.1.23
# volcengine>=1.0.119
# pymilvus==2.3.6
# psycopg2==2.9.9
# pgvector>=0.2.4
# chromadb==0.4.13
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files

View File

@ -1,42 +0,0 @@
langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat~=0.2.35
openai~=1.9.0
fastapi~=0.109.0
sse_starlette~=1.8.2
nltk~=3.8.1
uvicorn>=0.27.0.post1
starlette~=0.35.0
unstructured[all-docs]~=0.12.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy~=2.0.25
faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.16
rapidocr_onnxruntime~=1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
llama-index==0.9.35
pyjwt==2.8.0
httpx==0.26.0
httpx_sse==0.4.0
dashscope==1.13.6
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23
watchdog~=3.0.0
# volcengine>=1.0.119
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
# chromadb==0.4.13
# jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2

View File

@ -1,9 +0,0 @@
streamlit==1.30.0
streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
httpx_sse==0.4.0
watchdog>=3.0.0

View File

@@ -6,7 +6,7 @@ from typing import List
 import uuid
 from langchain.agents import tool
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import Field
 import openai
 from pydantic.fields import FieldInfo

View File

@@ -1,13 +1,9 @@
-import nltk
 import sys
 import os
-from server.knowledge_base.kb_doc_api import update_kb_endpoint

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs import VERSION, MEDIA_PATH
-from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
@@ -18,14 +14,11 @@ from starlette.responses import RedirectResponse
 from server.chat.chat import chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
-from server.embeddings.core.embeddings_api import embed_texts_endpoint
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
 from typing import List, Literal

-nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

 async def document():
     return RedirectResponse(url="/docs")
@@ -95,11 +88,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              summary="要求llm模型补全(通过LLMChain)",
              )(completion)

-    app.post("/other/embed_texts",
-             tags=["Other"],
-             summary="将文本向量化,支持本地模型和在线模型",
-             )(embed_texts_endpoint)
-
     # 媒体文件
     app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
@@ -109,8 +97,7 @@ def mount_knowledge_routes(app: FastAPI):
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                   update_docs, download_doc, recreate_vector_store,
-                                                  search_docs, DocumentWithVSId, update_info,
-                                                  update_docs_by_id,)
+                                                  search_docs, update_info)

     app.post("/chat/file_chat",
              tags=["Knowledge Base Management"],
@@ -146,13 +133,6 @@ def mount_knowledge_routes(app: FastAPI):
              summary="搜索知识库"
              )(search_docs)

-    app.post("/knowledge_base/update_docs_by_id",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="直接更新知识库文档"
-             )(update_docs_by_id)
-
     app.post("/knowledge_base/upload_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
@@ -171,12 +151,6 @@ def mount_knowledge_routes(app: FastAPI):
              summary="更新知识库介绍"
              )(update_info)

-    app.post("/knowledge_base/update_kb_endpoint",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="更新知识库在线api接入点配置"
-             )(update_kb_endpoint)
-
     app.post("/knowledge_base/update_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,

View File

@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterable, List, Union, Dict, Annotated
+from typing import AsyncIterable, List

 from fastapi import Body
 from fastapi.responses import StreamingResponse
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus

-def create_models_from_config(configs, openai_config, callbacks, stream):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
@@ -30,9 +30,6 @@ def create_models_from_config(configs, callbacks, stream):
     for model_name, params in model_configs.items():
         callbacks = callbacks if params.get('callbacks', False) else None
         model_instance = get_ChatOpenAI(
-            endpoint_host=openai_config.get('endpoint_host', None),
-            endpoint_host_key=openai_config.get('endpoint_host_key', None),
-            endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
             model_name=model_name,
             temperature=params.get('temperature', 0.5),
             max_tokens=params.get('max_tokens', 1000),
@@ -116,7 +113,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                ),
                stream: bool = Body(True, description="流式输出"),
                chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
-               openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
                tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
@@ -129,7 +125,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
-                                                    openai_config=openai_config, stream=stream)
+                                                    stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         full_chain = create_models_chains(prompts=prompts,

View File

@@ -1,6 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs import LLM_MODEL_CONFIG
+from sse_starlette.sse import EventSourceResponse
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -14,9 +13,6 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
-                     endpoint_host: str = Body(None, description="接入点地址"),
-                     endpoint_host_key: str = Body(None, description="接入点key"),
-                     endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                      model_name: str = Body(None, description="LLM 模型名称。"),
                      temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
@@ -27,9 +23,6 @@ async def completion(query: str = Body(..., description="用户输入", examples
     #TODO: 因ApiModelWorker 默认是按chat处理的会对params["prompt"] 解析为messages因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
-                                  endpoint_host: str,
-                                  endpoint_host_key: str,
-                                  endpoint_host_proxy: str,
                                   model_name: str = None,
                                   prompt_name: str = prompt_name,
                                   echo: bool = echo,
@@ -40,9 +33,6 @@ async def completion(query: str = Body(..., description="用户输入", examples
             max_tokens = None

         model = get_OpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -72,10 +62,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
         await task

-    return StreamingResponse(completion_iterator(query=query,
-                                                 endpoint_host=endpoint_host,
-                                                 endpoint_host_key=endpoint_host_key,
-                                                 endpoint_host_proxy=endpoint_host_proxy,
+    return EventSourceResponse(completion_iterator(query=query,
                                                    model_name=model_name,
                                                    prompt_name=prompt_name),
                                )

View File

@@ -1,8 +1,7 @@
 from fastapi import Body, File, Form, UploadFile
 from fastapi.responses import StreamingResponse
 from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.embeddings.adapter import load_temp_adapter_embeddings
-from server.utils import (wrap_done, get_ChatOpenAI,
+from server.utils import (wrap_done, get_ChatOpenAI, get_Embeddings,
                           BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
 from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
 from langchain.chains import LLMChain
@@ -57,9 +56,6 @@ def _parse_files_in_thread(

 def upload_temp_docs(
-    endpoint_host: str = Body(None, description="接入点地址"),
-    endpoint_host_key: str = Body(None, description="接入点key"),
-    endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
     files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
     prev_id: str = Form(None, description="前知识库ID"),
     chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
@@ -86,11 +82,7 @@ def upload_temp_docs(
         else:
             failed_files.append({file: msg})

-    with memo_faiss_pool.load_vector_store(kb_name=id,
-                                           endpoint_host=endpoint_host,
-                                           endpoint_host_key=endpoint_host_key,
-                                           endpoint_host_proxy=endpoint_host_proxy,
-                                           ).acquire() as vs:
+    with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs:
         vs.add_documents(documents)

     return BaseResponse(data={"id": id, "failed_files": failed_files})
@@ -110,9 +102,6 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                                           "content": "虎头虎脑"}]]
                                       ),
                     stream: bool = Body(False, description="流式输出"),
-                    endpoint_host: str = Body(None, description="接入点地址"),
-                    endpoint_host_key: str = Body(None, description="接入点key"),
-                    endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                     model_name: str = Body(None, description="LLM 模型名称。"),
                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
@@ -131,17 +120,12 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
             max_tokens = None

         model = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
             callbacks=[callback],
         )
-        embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
-                                                  endpoint_host_key=endpoint_host_key,
-                                                  endpoint_host_proxy=endpoint_host_proxy)
+        embed_func = get_Embeddings()
         embeddings = await embed_func.aembed_query(query)
         with memo_faiss_pool.acquire(knowledge_id) as vs:
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)

View File

@@ -12,9 +12,6 @@ class KnowledgeBaseModel(Base):
     kb_name = Column(String(50), comment='知识库名称')
     kb_info = Column(String(200), comment='知识库简介(用于Agent)')
     vs_type = Column(String(50), comment='向量库类型')
-    endpoint_host = Column(String(50), comment='接入点地址')
-    endpoint_host_key = Column(String(50), comment='接入点key')
-    endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
     embed_model = Column(String(50), comment='嵌入模型名称')
     file_count = Column(Integer, default=0, comment='文件数量')
     create_time = Column(DateTime, default=func.now(), comment='创建时间')

View File

@@ -3,22 +3,16 @@ from server.db.session import with_session

 @with_session
-def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
-                 endpoint_host_key: str = None, endpoint_host_proxy: str = None):
+def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
     # 创建知识库实例
     kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
     if not kb:
-        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
-                                endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
-                                endpoint_host_proxy=endpoint_host_proxy)
+        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
         session.add(kb)
     else:  # update kb with new vs_type and embed_model
         kb.kb_info = kb_info
         kb.vs_type = vs_type
         kb.embed_model = embed_model
-        kb.endpoint_host = endpoint_host
-        kb.endpoint_host_key = endpoint_host_key
-        kb.endpoint_host_proxy = endpoint_host_proxy
     return True
@@ -54,16 +48,6 @@ def delete_kb_from_db(session, kb_name):
     return True

-@with_session
-def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
-    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
-    if kb:
-        kb.endpoint_host = endpoint_host
-        kb.endpoint_host_key = endpoint_host_key
-        kb.endpoint_host_proxy = endpoint_host_proxy
-    return True
-
 @with_session
 def get_kb_detail(session, kb_name: str) -> dict:
     kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
@@ -72,9 +56,6 @@ def get_kb_detail(session, kb_name: str) -> dict:
         "kb_name": kb.kb_name,
         "kb_info": kb.kb_info,
         "vs_type": kb.vs_type,
-        "endpoint_host": kb.endpoint_host,
-        "endpoint_host_key": kb.endpoint_host_key,
-        "endpoint_host_proxy": kb.endpoint_host_proxy,
         "embed_model": kb.embed_model,
         "file_count": kb.file_count,
         "create_time": kb.create_time,

View File

@ -1,130 +0,0 @@
import numpy as np
from typing import List, Union, Dict
from langchain.embeddings.base import Embeddings
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_MODEL, KB_INFO)
from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
from server.utils import embedding_device
class EmbeddingsFunAdapter(Embeddings):
_endpoint_host: str
_endpoint_host_key: str
_endpoint_host_proxy: str
def __init__(self,
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
):
self._endpoint_host = endpoint_host
self._endpoint_host_key = endpoint_host_key
self._endpoint_host_proxy = endpoint_host_proxy
self.embed_model = embed_model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = embed_texts(texts=texts,
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=False).data
return self._normalize(embeddings=embeddings).tolist()
def embed_query(self, text: str) -> List[float]:
embeddings = embed_texts(texts=[text],
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=True).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = self._normalize(embeddings=query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = (await aembed_texts(texts=texts,
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=False)).data
return self._normalize(embeddings=embeddings).tolist()
async def aembed_query(self, text: str) -> List[float]:
embeddings = (await aembed_texts(texts=[text],
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=True)).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = self._normalize(embeddings=query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
@staticmethod
def _normalize(embeddings: List[List[float]]) -> np.ndarray:
'''
sklearn.preprocessing.normalize 的替代使用 L2避免安装 scipy, scikit-learn
#TODO 此处内容处理错误
'''
norm = np.linalg.norm(embeddings, axis=1)
norm = np.reshape(norm, (norm.shape[0], 1))
norm = np.tile(norm, (1, len(embeddings[0])))
return np.divide(embeddings, norm)
def load_kb_adapter_embeddings(
kb_name: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
"""
加载知识库配置的Embeddings模型
本地模型最终会通过load_embeddings加载
在线模型会在适配器中直接返回
:param kb_name:
:param embed_device:
:param default_embed_model:
:return:
"""
from server.db.repository.knowledge_base_repository import get_kb_detail
kb_detail = get_kb_detail(kb_name)
embed_model = kb_detail.get("embed_model", default_embed_model)
endpoint_host = kb_detail.get("endpoint_host", None)
endpoint_host_key = kb_detail.get("endpoint_host_key", None)
endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None)
return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model)
def load_temp_adapter_embeddings(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
"""
加载临时的Embeddings模型
本地模型最终会通过load_embeddings加载
在线模型会在适配器中直接返回
:param endpoint_host:
:param endpoint_host_key:
:param endpoint_host_proxy:
:param embed_device:
:param default_embed_model:
:return:
"""
return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=default_embed_model)

View File

@ -1,94 +0,0 @@
from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
def embed_texts(
texts: List[str],
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
TODO: 也许需要加入缓存机制减少 token 消耗
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=embeddings.embed_documents(texts))
# 使用在线API
if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy):
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model=embed_model,
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
chunk_size=CHUNK_SIZE)
return BaseResponse(data=embeddings.embed_documents(texts))
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
# 使用在线API
if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy):
return await run_in_threadpool(embed_texts,
texts=texts,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
'''
接入api对文本进行向量化返回 BaseResponse(data=List[List[float]])
'''
return embed_texts(texts=texts,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model, to_query=to_query)

View File

@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs import EMBEDDING_MODEL, logger, log_verbose
+from configs import DEFAULT_EMBEDDING_MODEL, logger, log_verbose
 from fastapi import Body
@@ -14,10 +14,7 @@ def list_kbs():
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
-              embed_model: str = Body(EMBEDDING_MODEL),
-              endpoint_host: str = Body(None, description="接入点地址"),
-              endpoint_host_key: str = Body(None, description="接入点key"),
-              endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
+              embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
@@ -31,7 +28,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
     kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
     try:
-        kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
+        kb.create_kb()
     except Exception as e:
         msg = f"创建知识库出错: {e}"
         logger.error(f'{e.__class__.__name__}: {msg}',

View File

@@ -1,9 +1,8 @@
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.faiss import FAISS
 import threading
-from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
+from configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -98,50 +97,3 @@ class CachePool:
         else:
             return cache
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
"""
本地Embeddings模型加载
:param model:
:param device:
:return:
"""
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = embedding_device()
key = (model, device)
if not self.get(key):
item = ThreadSafeObject(key, pool=self)
self.set(key, item)
with item.acquire(msg="初始化"):
self.atomic.release()
if 'bge-' in model:
from langchain.embeddings import HuggingFaceBgeEmbeddings
if 'zh' in model:
# for chinese model
query_instruction = "为这个句子生成表示以用于检索相关文章:"
elif 'en' in model:
# for english model
query_instruction = "Represent this sentence for searching relevant passages:"
else:
# maybe ReRanker or else, just use empty string instead
query_instruction = ""
embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
model_kwargs={'device': device},
query_instruction=query_instruction)
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:
self.atomic.release()
return self.get(key).obj
embeddings_pool = EmbeddingsPool(cache_num=1)

View File

@@ -1,7 +1,6 @@
 from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
-from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
 from server.knowledge_base.kb_cache.base import *
-# from server.utils import load_local_embeddings
+from server.utils import get_Embeddings
 from server.knowledge_base.utils import get_vs_path
 from langchain.vectorstores.faiss import FAISS
 from langchain.docstore.in_memory import InMemoryDocstore
@@ -53,13 +52,11 @@ class _FaissPool(CachePool):
     def new_vector_store(
         self,
         kb_name: str,
-        embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = embedding_device(),
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> FAISS:
         # create an empty vector store
-        embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
-                                                embed_device=embed_device, default_embed_model=embed_model)
+        embeddings = get_Embeddings(embed_model=embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
         ids = list(vector_store.docstore._dict.keys())
@@ -68,18 +65,11 @@ class _FaissPool(CachePool):
     def new_temp_vector_store(
         self,
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str,
-        embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = embedding_device(),
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> FAISS:
         # create an empty vector store
-        embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
-                                                  endpoint_host_key=endpoint_host_key,
-                                                  endpoint_host_proxy=endpoint_host_proxy,
-                                                  embed_device=embed_device, default_embed_model=embed_model)
+        embeddings = get_Embeddings(embed_model=embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
         ids = list(vector_store.docstore._dict.keys())
@@ -102,8 +92,7 @@ class KBFaissPool(_FaissPool):
             kb_name: str,
             vector_name: str = None,
             create: bool = True,
-            embed_model: str = EMBEDDING_MODEL,
-            embed_device: str = embedding_device(),
+            embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         vector_name = vector_name or embed_model
@@ -118,15 +107,13 @@ class KBFaissPool(_FaissPool):
             vs_path = get_vs_path(kb_name, vector_name)

             if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
-                                                        embed_device=embed_device, default_embed_model=embed_model)
+                embeddings = get_Embeddings(embed_model=embed_model)
                 vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
             elif create:
                 # create an empty vector store
                 if not os.path.exists(vs_path):
                     os.makedirs(vs_path)
-                vector_store = self.new_vector_store(kb_name=kb_name,
-                                                     embed_model=embed_model, embed_device=embed_device)
+                vector_store = self.new_vector_store(kb_name=kb_name, embed_model=embed_model)
                 vector_store.save_local(vs_path)
             else:
                 raise RuntimeError(f"knowledge base {kb_name} not exist.")
@@ -148,11 +135,7 @@ class MemoFaissPool(_FaissPool):
     def load_vector_store(
         self,
         kb_name: str,
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str,
-        embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = embedding_device(),
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         cache = self.get(kb_name)
@@ -163,10 +146,7 @@ class MemoFaissPool(_FaissPool):
                 self.atomic.release()
                 logger.info(f"loading vector store in '{kb_name}' to memory.")
                 # create an empty vector store
-                vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
-                                                          endpoint_host_key=endpoint_host_key,
-                                                          endpoint_host_proxy=endpoint_host_proxy,
-                                                          embed_model=embed_model, embed_device=embed_device)
+                vector_store = self.new_temp_vector_store(embed_model=embed_model)
                 item.obj = vector_store
                 item.finish_loading()
             else:

View File

@@ -1,7 +1,7 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                      logger, log_verbose, )
@@ -42,22 +42,6 @@ def search_docs(
     return data

-def update_docs_by_id(
-        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
-) -> BaseResponse:
-    '''
-    按照文档 ID 更新文档内容
-    '''
-    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
-    if kb is None:
-        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
-    if kb.update_doc_by_ids(docs=docs):
-        return BaseResponse(msg=f"文档更新成功")
-    else:
-        return BaseResponse(msg=f"文档更新失败")
-
 def list_files(
         knowledge_base_name: str
 ) -> ListResponse:
@@ -230,26 +214,6 @@ def update_info(
     return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})

-def update_kb_endpoint(
-        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
-):
-    if not validate_kb_name(knowledge_base_name):
-        return BaseResponse(code=403, msg="Don't attack me")
-
-    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
-    if kb is None:
-        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
-
-    kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy)
-
-    return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成",
-                        data={"endpoint_host": endpoint_host,
-                              "endpoint_host_key": endpoint_host_key,
-                              "endpoint_host_proxy": endpoint_host_proxy})
-
 def update_docs(
         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
         file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
@@ -366,10 +330,7 @@ def recreate_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
-        embed_model: str = Body(EMBEDDING_MODEL),
+        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
         chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
         chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
         zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
@@ -389,9 +350,7 @@ def recreate_vector_store(
         else:
             if kb.exists():
                 kb.clear_vs()
-            kb.create_kb(endpoint_host=endpoint_host,
-                         endpoint_host_key=endpoint_host_key,
-                         endpoint_host_proxy=endpoint_host_proxy)
+            kb.create_kb()
         files = list_files_from_folder(knowledge_base_name)
         kb_files = [(file, knowledge_base_name) for file in files]
         i = 0

View File

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
-    load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db,
+    load_kb_from_db, get_kb_detail,
 )
 from server.db.repository.knowledge_file_repository import (
     add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
@@ -15,7 +15,7 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     EMBEDDING_MODEL, KB_INFO)
+                     DEFAULT_EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
@@ -40,7 +40,7 @@ class KBService(ABC):
     def __init__(self,
                  knowledge_base_name: str,
-                 embed_model: str = EMBEDDING_MODEL,
+                 embed_model: str = DEFAULT_EMBEDDING_MODEL,
                  ):
         self.kb_name = knowledge_base_name
         self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
@@ -58,20 +58,14 @@ class KBService(ABC):
         '''
         pass

-    def create_kb(self,
-                  endpoint_host: str = None,
-                  endpoint_host_key: str = None,
-                  endpoint_host_proxy: str = None):
+    def create_kb(self):
         """
         创建知识库
         """
         if not os.path.exists(self.doc_path):
             os.makedirs(self.doc_path)
-        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model,
-                              endpoint_host=endpoint_host,
-                              endpoint_host_key=endpoint_host_key,
-                              endpoint_host_proxy=endpoint_host_proxy)
+        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
         if status:
             self.do_create_kb()
@@ -144,16 +138,6 @@ class KBService(ABC):
         status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
         return status

-    def update_kb_endpoint(self,
-                           endpoint_host: str = None,
-                           endpoint_host_key: str = None,
-                           endpoint_host_proxy: str = None):
-        """
-        更新知识库在线api接入点配置
-        """
-        status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy)
-        return status
-
     def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         使用content中的文件更新向量库
@@ -297,7 +281,7 @@ class KBServiceFactory:
     @staticmethod
     def get_service(kb_name: str,
                     vector_store_type: Union[str, SupportedVSType],
-                    embed_model: str = EMBEDDING_MODEL,
+                    embed_model: str = DEFAULT_EMBEDDING_MODEL,
                     ) -> KBService:
         if isinstance(vector_store_type, str):
             vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

View File

@ -1,13 +1,11 @@
from typing import List from typing import List
import os import os
import shutil import shutil
from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings from server.utils import get_Embeddings
from elasticsearch import Elasticsearch,BadRequestError from elasticsearch import Elasticsearch,BadRequestError
from configs import logger from configs import logger
from configs import kbs_config from configs import kbs_config
@ -22,7 +20,7 @@ class ESKBService(KBService):
self.user = kbs_config[self.vs_type()].get("user",'') self.user = kbs_config[self.vs_type()].get("user",'')
self.password = kbs_config[self.vs_type()].get("password",'') self.password = kbs_config[self.vs_type()].get("password",'')
self.dims_length = kbs_config[self.vs_type()].get("dims_length",None) self.dims_length = kbs_config[self.vs_type()].get("dims_length",None)
self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE) self.embeddings_model = get_Embeddings(self.embed_model)
try: try:
# ES python客户端连接仅连接 # ES python客户端连接仅连接
if self.user != "" and self.password != "": if self.user != "" and self.password != "":

View File

@ -1,7 +1,7 @@
from typing import List from typing import List
from configs import ( from configs import (
EMBEDDING_MODEL, DEFAULT_EMBEDDING_MODEL,
KB_ROOT_PATH) KB_ROOT_PATH)
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -21,7 +21,7 @@ class KBSummaryService(ABC):
def __init__(self, def __init__(self,
knowledge_base_name: str, knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL embed_model: str = DEFAULT_EMBEDDING_MODEL
): ):
self.kb_name = knowledge_base_name self.kb_name = knowledge_base_name
self.embed_model = embed_model self.embed_model = embed_model

View File

@ -1,5 +1,5 @@
from fastapi import Body from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
OVERLAP_SIZE, OVERLAP_SIZE,
logger, log_verbose, ) logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder) from server.knowledge_base.utils import (list_files_from_folder)
@ -17,11 +17,8 @@ def recreate_summary_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True), allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE), vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
file_description: str = Body(''), file_description: str = Body(''),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"), model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
@ -29,9 +26,6 @@ def recreate_summary_vector_store(
""" """
重建单个知识库文件摘要 重建单个知识库文件摘要
:param max_tokens: :param max_tokens:
:param endpoint_host:
:param endpoint_host_key:
:param endpoint_host_proxy:
:param model_name: :param model_name:
:param temperature: :param temperature:
:param file_description: :param file_description:
@ -54,17 +48,11 @@ def recreate_summary_vector_store(
kb_summary.create_kb_summary() kb_summary.create_kb_summary()
llm = get_ChatOpenAI( llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
) )
reduce_llm = get_ChatOpenAI( reduce_llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -110,20 +98,14 @@ def summary_file_to_vector_store(
file_name: str = Body(..., examples=["test.pdf"]), file_name: str = Body(..., examples=["test.pdf"]),
allow_empty_kb: bool = Body(True), allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE), vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
file_description: str = Body(''), file_description: str = Body(''),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"), model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
): ):
""" """
单个知识库根据文件名称摘要 单个知识库根据文件名称摘要
:param endpoint_host:
:param endpoint_host_key:
:param endpoint_host_proxy:
:param model_name: :param model_name:
:param max_tokens: :param max_tokens:
:param temperature: :param temperature:
@ -146,17 +128,11 @@ def summary_file_to_vector_store(
kb_summary.create_kb_summary() kb_summary.create_kb_summary()
llm = get_ChatOpenAI( llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
) )
reduce_llm = get_ChatOpenAI( reduce_llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -194,11 +170,8 @@ def summary_doc_ids_to_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
doc_ids: List = Body([], examples=[["uuid"]]), doc_ids: List = Body([], examples=[["uuid"]]),
vs_type: str = Body(DEFAULT_VS_TYPE), vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
file_description: str = Body(''), file_description: str = Body(''),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"), model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
@ -206,9 +179,6 @@ def summary_doc_ids_to_vector_store(
""" """
单个知识库根据doc_ids摘要 单个知识库根据doc_ids摘要
:param knowledge_base_name: :param knowledge_base_name:
:param endpoint_host:
:param endpoint_host_key:
:param endpoint_host_proxy:
:param doc_ids: :param doc_ids:
:param model_name: :param model_name:
:param max_tokens: :param max_tokens:
@ -223,17 +193,11 @@ def summary_doc_ids_to_vector_store(
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}) return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
else: else:
llm = get_ChatOpenAI( llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
) )
reduce_llm = get_ChatOpenAI( reduce_llm = get_ChatOpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,

View File

@ -1,5 +1,5 @@
from configs import ( from configs import (
EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE, CHUNK_SIZE, OVERLAP_SIZE,
logger, log_verbose logger, log_verbose
) )
@ -86,7 +86,7 @@ def folder2db(
kb_names: List[str], kb_names: List[str],
mode: Literal["recreate_vs", "update_in_db", "increment"], mode: Literal["recreate_vs", "update_in_db", "increment"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
chunk_size: int = CHUNK_SIZE, chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE, chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE, zh_title_enhance: bool = ZH_TITLE_ENHANCE,

View File

@ -1,4 +1,5 @@
import os import os
from functools import lru_cache
from configs import ( from configs import (
KB_ROOT_PATH, KB_ROOT_PATH,
CHUNK_SIZE, CHUNK_SIZE,
@ -143,6 +144,7 @@ def get_LoaderClass(file_extension):
if file_extension in extensions: if file_extension in extensions:
return LoaderClass return LoaderClass
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
''' '''
根据loader_name和文件路径或内容返回文档加载器 根据loader_name和文件路径或内容返回文档加载器
@ -184,6 +186,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
return loader return loader
@lru_cache()
def make_text_splitter( def make_text_splitter(
splitter_name, splitter_name,
chunk_size, chunk_size,

View File

@ -0,0 +1,363 @@
from __future__ import annotations
import logging
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_community.utils.openai import is_openai_v1
from tenacity import (
AsyncRetrying,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
async_retrying = AsyncRetrying(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def wrap(func: Callable) -> Callable:
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
async for _ in async_retrying:
return await func(*args, **kwargs)
raise AssertionError("this is unreachable")
return wrapped_f
return wrap
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
def _check_response(response: dict) -> dict:
if any([len(d.embedding) == 1 for d in response.data]):
import openai
raise openai.APIError("LocalAI API returned an empty embedding")
return response
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)
return _embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.async_client.acreate(**kwargs)
return _check_response(response)
return await _async_embed_with_retry(**kwargs)
class LocalAIEmbeddings(BaseModel, Embeddings):
"""LocalAI embedding models.
Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
uses the ``openai`` Python package's ``openai.Embedding`` as its client.
Thus, you should have the ``openai`` python package installed, and set
the environment variable ``OPENAI_API_KEY`` to a random string.
You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
service endpoint.
Example:
.. code-block:: python
from langchain_community.embeddings import LocalAIEmbeddings
openai = LocalAIEmbeddings(
openai_api_key="random-string",
openai_api_base="http://localhost:8080"
)
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model
openai_api_version: Optional[str] = None
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
# to support explicit proxy for LocalAI
openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_organization: Optional[str] = Field(default=None, alias="organization")
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
chunk_size: int = 1000
"""Maximum number of texts to embed in each batch"""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout in seconds for the LocalAI request."""
headers: Any = None
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
default_api_version = ""
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
default=default_api_version,
)
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
try:
import openai
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
}
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).embeddings
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).embeddings
elif not values.get("client"):
values["client"] = openai.Embedding
else:
pass
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
return values
@property
def _invocation_params(self) -> Dict:
openai_args = {
"model": self.model,
"timeout": self.request_timeout,
"extra_headers": self.headers,
**self.model_kwargs,
}
if self.openai_proxy:
import openai
openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501
return openai_args
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
self,
input=[text],
**self._invocation_params,
).data[0].embedding
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
).data[0].embedding
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
# call _embedding_func for each text
return [self._embedding_func(text, engine=self.deployment) for text in texts]
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = await self._aembedding_func(text, engine=self.deployment)
embeddings.append(response)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.deployment)
return embedding
async def aembed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint async for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = await self._aembedding_func(text, engine=self.deployment)
return embedding

View File

@ -109,14 +109,13 @@ if __name__ == "__main__":
RERANKER_MODEL, RERANKER_MODEL,
RERANKER_MAX_LENGTH, RERANKER_MAX_LENGTH,
MODEL_PATH) MODEL_PATH)
from server.utils import embedding_device
if USE_RERANKER: if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large") reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
print("-----------------model path------------------") print("-----------------model path------------------")
print(reranker_model_path) print(reranker_model_path)
reranker_model = LangchainReranker(top_n=3, reranker_model = LangchainReranker(top_n=3,
device=embedding_device(), device="cpu",
max_length=RERANKER_MAX_LENGTH, max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path model_name_or_path=reranker_model_path
) )

View File

@ -1,32 +1,28 @@
import pydantic
from pydantic import BaseModel
from typing import List
from fastapi import FastAPI from fastapi import FastAPI
from pathlib import Path from pathlib import Path
import asyncio import asyncio
from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
HTTPX_DEFAULT_TIMEOUT)
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.pydantic_v1 import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain_openai.chat_models import ChatOpenAI from langchain_openai.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI from langchain_openai.llms import OpenAI
import httpx import httpx
from typing import ( from typing import (
TYPE_CHECKING,
Literal,
Optional, Optional,
Callable, Callable,
Generator, Generator,
Dict, Dict,
List,
Any, Any,
Awaitable, Awaitable,
Union, Union,
Tuple Tuple,
Literal,
) )
import logging import logging
import torch
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL
from server.minx_chat_openai import MinxChatOpenAI from server.minx_chat_openai import MinxChatOpenAI
@ -44,10 +40,66 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
event.set() event.set()
def get_config_models(
model_name: str = None,
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
platform_name: str = None,
) -> Dict[str, Dict]:
'''
获取配置的模型列表返回值为:
{model_name: {
"platform_name": xx,
"platform_type": xx,
"model_type": xx,
"model_name": xx,
"api_base_url": xx,
"api_key": xx,
"api_proxy": xx,
}}
'''
import importlib
from configs import model_config
importlib.reload(model_config)
result = {}
for m in model_config.MODEL_PLATFORMS:
if platform_name is not None and platform_name != m.get("platform_name"):
continue
if model_type is not None and f"{model_type}_models" not in m:
continue
if model_type is None:
model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"]
else:
model_types = [f"{model_type}_models"]
for m_type in model_types:
for m_name in m.get(m_type, []):
if model_name is None or model_name == m_name:
result[m_name] = {
"platform_name": m.get("platform_name"),
"platform_type": m.get("platform_type"),
"model_type": m_type.split("_")[0],
"model_name": m_name,
"api_base_url": m.get("api_base_url"),
"api_key": m.get("api_key"),
"api_proxy": m.get("api_proxy"),
}
return result
def get_model_info(model_name: str, platform_name: str = None) -> Dict:
'''
获取配置的模型信息主要是 api_base_url, api_key
'''
result = get_config_models(model_name=model_name, platform_name=platform_name)
if len(result) > 0:
return list(result.values())[0]
else:
return {}
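get_config_models / get_model_info 的用法可参考下面的示意代码(非本次提交内容MODEL_PLATFORMS 的具体取值仅为假设,字段名以上面的实现为准):
from server.utils import get_config_models, get_model_info

# 假设 configs/model_config.py 中配置了类似条目(取值仅为示例):
# MODEL_PLATFORMS = [{
#     "platform_name": "xinference",
#     "platform_type": "xinference",
#     "api_base_url": "http://127.0.0.1:9997/v1",
#     "api_key": "EMPTY",
#     "api_proxy": "",
#     "llm_models": ["chatglm3-6b"],
#     "embed_models": ["bge-large-zh-v1.5"],
# }]

llm_models = get_config_models(model_type="llm")    # {"chatglm3-6b": {"platform_name": "xinference", ...}}
embed_info = get_model_info("bge-large-zh-v1.5")    # 含 api_base_url / api_key / api_proxy 的字典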
def get_ChatOpenAI( def get_ChatOpenAI(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
model_name: str, model_name: str,
temperature: float, temperature: float,
max_tokens: int = None, max_tokens: int = None,
@ -56,29 +108,23 @@ def get_ChatOpenAI(
verbose: bool = True, verbose: bool = True,
**kwargs: Any, **kwargs: Any,
) -> ChatOpenAI: ) -> ChatOpenAI:
config = get_model_worker_config(model_name) model_info = get_model_info(model_name)
if model_name == "openai-api":
model_name = config.get("model_name")
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
model = ChatOpenAI( model = ChatOpenAI(
streaming=streaming, streaming=streaming,
verbose=verbose, verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
**kwargs **kwargs
) )
return model return model
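get_ChatOpenAI 现在只需要模型名即可完成接入,示意用法如下(模型名与参数取值仅为假设,前提是该模型已在 MODEL_PLATFORMS 中配置):
from server.utils import get_ChatOpenAI

llm = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.7, max_tokens=1024)
print(llm.invoke("你好").content)   # api_key / api_base_url / proxy 均由平台配置自动解析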
def get_OpenAI( def get_OpenAI(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
model_name: str, model_name: str,
temperature: float, temperature: float,
max_tokens: int = None, max_tokens: int = None,
@ -89,22 +135,40 @@ def get_OpenAI(
**kwargs: Any, **kwargs: Any,
) -> OpenAI: ) -> OpenAI:
# TODO: 从API获取模型信息 # TODO: 从API获取模型信息
model_info = get_model_info(model_name)
model = OpenAI( model = OpenAI(
streaming=streaming, streaming=streaming,
verbose=verbose, verbose=verbose,
callbacks=callbacks, callbacks=callbacks,
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, openai_api_key=model_info.get("api_key"),
openai_api_base=model_info.get("api_base_url"),
openai_proxy=model_info.get("api_proxy"),
echo=echo, echo=echo,
**kwargs **kwargs
) )
return model return model
def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings:
from langchain_community.embeddings.openai import OpenAIEmbeddings
from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
model_info = get_model_info(model_name=embed_model)
params = {
"model": embed_model,
"base_url": model_info.get("api_base_url"),
"api_key": model_info.get("api_key"),
"openai_proxy": model_info.get("api_proxy"),
}
if model_info.get("platform_type") == "openai":
return OpenAIEmbeddings(**params)
else:
return LocalAIEmbeddings(**params)
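get_Embeddings 的示意用法如下(模型名仅为假设,需已在某个平台的 embed_models 中配置;平台类型为 openai 时返回 OpenAIEmbeddings否则回退到上面的 LocalAIEmbeddings
from server.utils import get_Embeddings

embeddings = get_Embeddings("bge-large-zh-v1.5")
query_vector = embeddings.embed_query("检索用的查询文本")        # List[float]
doc_vectors = embeddings.embed_documents(["文本一", "文本二"])   # List[List[float]]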
class MsgType: class MsgType:
TEXT = 1 TEXT = 1
IMAGE = 2 IMAGE = 2
@ -113,9 +177,9 @@ class MsgType:
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code") code: int = Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message") msg: str = Field("success", description="API status message")
data: Any = pydantic.Field(None, description="API data") data: Any = Field(None, description="API data")
class Config: class Config:
schema_extra = { schema_extra = {
@ -127,7 +191,7 @@ class BaseResponse(BaseModel):
class ListResponse(BaseResponse): class ListResponse(BaseResponse):
data: List[str] = pydantic.Field(..., description="List of names") data: List[str] = Field(..., description="List of names")
class Config: class Config:
schema_extra = { schema_extra = {
@ -140,10 +204,10 @@ class ListResponse(BaseResponse):
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
question: str = pydantic.Field(..., description="Question text") question: str = Field(..., description="Question text")
response: str = pydantic.Field(..., description="Response text") response: str = Field(..., description="Response text")
history: List[List[str]] = pydantic.Field(..., description="History text") history: List[List[str]] = Field(..., description="History text")
source_documents: List[str] = pydantic.Field( source_documents: List[str] = Field(
..., description="List of source documents and their scores" ..., description="List of source documents and their scores"
) )
@ -310,39 +374,40 @@ def MakeFastAPIOffline(
# 从model_config中获取模型信息 # 从model_config中获取模型信息
# TODO: 移出模型加载后,这些功能需要删除或改变实现
def list_embed_models() -> List[str]: # def list_embed_models() -> List[str]:
''' # '''
get names of configured embedding models # get names of configured embedding models
''' # '''
return list(MODEL_PATH["embed_model"]) # return list(MODEL_PATH["embed_model"])
def get_model_path(model_name: str, type: str = None) -> Optional[str]: # def get_model_path(model_name: str, type: str = None) -> Optional[str]:
if type in MODEL_PATH: # if type in MODEL_PATH:
paths = MODEL_PATH[type] # paths = MODEL_PATH[type]
else: # else:
paths = {} # paths = {}
for v in MODEL_PATH.values(): # for v in MODEL_PATH.values():
paths.update(v) # paths.update(v)
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径 # if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
path = Path(path_str) # path = Path(path_str)
if path.is_dir(): # 任意绝对路径 # if path.is_dir(): # 任意绝对路径
return str(path) # return str(path)
root_path = Path(MODEL_ROOT_PATH) # root_path = Path(MODEL_ROOT_PATH)
if root_path.is_dir(): # if root_path.is_dir():
path = root_path / model_name # path = root_path / model_name
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b # if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
return str(path) # return str(path)
path = root_path / path_str # path = root_path / path_str
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new # if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
return str(path) # return str(path)
path = root_path / path_str.split("/")[-1] # path = root_path / path_str.split("/")[-1]
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new # if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
return str(path) # return str(path)
return path_str # THUDM/chatglm06b # return path_str # THUDM/chatglm06b
def api_address() -> str: def api_address() -> str:
@ -429,37 +494,6 @@ def set_httpx_config(
urllib.request.getproxies = _get_proxies urllib.request.getproxies = _get_proxies
def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
import intel_extension_for_pytorch as ipex
if torch.xpu.get_device_properties(0):
return "xpu"
except:
pass
return "cpu"
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or LLM_DEVICE
# if device.isdigit():
# return "cuda:" + device
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device
def run_in_thread_pool( def run_in_thread_pool(
func: Callable, func: Callable,
params: List[Dict] = [], params: List[Dict] = [],
@ -546,56 +580,19 @@ def get_server_configs() -> Dict:
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom} return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
def list_online_embed_models(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str
) -> List[str]:
ret = []
# TODO: 从在线API获取支持的模型列表
client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
try:
headers = {
"Authorization": f"Bearer {endpoint_host_key}",
}
resp = client.get("/models", headers=headers)
if resp.status_code == 200:
models = resp.json().get("data", [])
for model in models:
if "embedding" in model.get("id", None):
ret.append(model.get("id"))
except Exception as e:
msg = f"获取在线Embeddings模型列表失败{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
client.close()
return ret
def load_local_embeddings(model: str = None, device: str = embedding_device()):
'''
从缓存中本地Embeddings模型加载可以避免多线程时竞争加载
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
from configs import EMBEDDING_MODEL
model = model or EMBEDDING_MODEL
return embeddings_pool.load_embeddings(model=model, device=device)
def get_temp_dir(id: str = None) -> Tuple[str, str]: def get_temp_dir(id: str = None) -> Tuple[str, str]:
''' '''
创建一个临时目录返回路径文件夹名称 创建一个临时目录返回路径文件夹名称
''' '''
from configs.basic_config import BASE_TEMP_DIR from configs.basic_config import BASE_TEMP_DIR
import tempfile import uuid
if id is not None: # 如果指定的临时目录已存在,直接返回 if id is not None: # 如果指定的临时目录已存在,直接返回
path = os.path.join(BASE_TEMP_DIR, id) path = os.path.join(BASE_TEMP_DIR, id)
if os.path.isdir(path): if os.path.isdir(path):
return path, id return path, id
path = tempfile.mkdtemp(dir=BASE_TEMP_DIR) id = uuid.uuid4().hex
return path, os.path.basename(path) path = os.path.join(BASE_TEMP_DIR, id)
os.mkdir(path)
return path, id

View File

@ -1,12 +1,11 @@
import asyncio import asyncio
from contextlib import asynccontextmanager
import multiprocessing as mp import multiprocessing as mp
import os import os
import subprocess import subprocess
import sys import sys
from multiprocessing import Process from multiprocessing import Process
from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated
# 设置numexpr最大线程数默认为CPU核心数 # 设置numexpr最大线程数默认为CPU核心数
try: try:
@ -17,38 +16,29 @@ try:
except: except:
pass pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import ( from configs import (
LOG_PATH, LOG_PATH,
log_verbose, log_verbose,
logger, logger,
LLM_MODEL_CONFIG, DEFAULT_EMBEDDING_MODEL,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
API_SERVER, API_SERVER,
WEBUI_SERVER, WEBUI_SERVER,
HTTPX_DEFAULT_TIMEOUT,
) )
from server.utils import (FastAPI, embedding_device) from server.utils import FastAPI
from server.knowledge_base.migrate import create_tables from server.knowledge_base.migrate import create_tables
import argparse import argparse
from typing import List, Dict from typing import List, Dict
from configs import VERSION from configs import VERSION
all_model_names = set()
for model_category in LLM_MODEL_CONFIG.values():
for model_name in model_category.keys():
if model_name not in all_model_names:
all_model_names.add(model_name)
all_model_names_list = list(all_model_names)
def _set_app_event(app: FastAPI, started_event: mp.Event = None): def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup") @asynccontextmanager
async def on_startup(): async def lifespan(app: FastAPI):
if started_event is not None: if started_event is not None:
started_event.set() started_event.set()
yield
app.router.lifespan_context = lifespan
def run_api_server(started_event: mp.Event = None, run_mode: str = None): def run_api_server(started_event: mp.Event = None, run_mode: str = None):
@ -159,7 +149,7 @@ def dump_server_info(after_start=False, args=None):
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}") print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}") print(f"当前Embbedings模型 {DEFAULT_EMBEDDING_MODEL}")
if after_start: if after_start:
print("\n") print("\n")
@ -232,13 +222,13 @@ async def start_main_server():
return len(processes) return len(processes)
loom_started = manager.Event() loom_started = manager.Event()
process = Process( # process = Process(
target=run_loom, # target=run_loom,
name=f"run_loom Server", # name=f"run_loom Server",
kwargs=dict(started_event=loom_started), # kwargs=dict(started_event=loom_started),
daemon=True, # daemon=True,
) # )
processes["run_loom"] = process # processes["run_loom"] = process
api_started = manager.Event() api_started = manager.Event()
if args.api: if args.api:
process = Process( process = Process(
@ -283,7 +273,6 @@ async def start_main_server():
# 等待所有进程退出 # 等待所有进程退出
if p := processes.get("webui"): if p := processes.get("webui"):
p.join() p.join()
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
@ -306,9 +295,7 @@ async def start_main_server():
if __name__ == "__main__": if __name__ == "__main__":
create_tables() create_tables()
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
else: else:

View File

@ -1,7 +1,7 @@
import streamlit as st import streamlit as st
from webui_pages.loom_view_client import update_store # from webui_pages.loom_view_client import update_store
from webui_pages.openai_plugins import openai_plugins_page # from webui_pages.openai_plugins import openai_plugins_page
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_option_menu import option_menu from streamlit_option_menu import option_menu
from webui_pages.dialogue.dialogue import dialogue_page, chat_box from webui_pages.dialogue.dialogue import dialogue_page, chat_box
@ -12,9 +12,9 @@ from configs import VERSION
from server.utils import api_address from server.utils import api_address
def on_change(key): # def on_change(key):
if key: # if key:
update_store() # update_store()
api = ApiRequest(base_url=api_address()) api = ApiRequest(base_url=api_address())
@ -59,18 +59,18 @@ if __name__ == "__main__":
"icon": "hdd-stack", "icon": "hdd-stack",
"func": knowledge_base_page, "func": knowledge_base_page,
}, },
"模型服务": { # "模型服务": {
"icon": "hdd-stack", # "icon": "hdd-stack",
"func": openai_plugins_page, # "func": openai_plugins_page,
}, # },
} }
# 更新状态 # 更新状态
if "status" not in st.session_state \ # if "status" not in st.session_state \
or "run_plugins_list" not in st.session_state \ # or "run_plugins_list" not in st.session_state \
or "launch_subscribe_info" not in st.session_state \ # or "launch_subscribe_info" not in st.session_state \
or "list_running_models" not in st.session_state \ # or "list_running_models" not in st.session_state \
or "model_config" not in st.session_state: # or "model_config" not in st.session_state:
update_store() # update_store()
with st.sidebar: with st.sidebar:
st.image( st.image(
@ -95,7 +95,6 @@ if __name__ == "__main__":
icons=icons, icons=icons,
# menu_icon="chat-quote", # menu_icon="chat-quote",
default_index=default_index, default_index=default_index,
on_change=on_change,
) )
if selected_page in pages: if selected_page in pages:

View File

@ -4,8 +4,8 @@ import streamlit as st
from streamlit_antd_components.utils import ParseItems from streamlit_antd_components.utils import ParseItems
from webui_pages.dialogue.utils import process_files from webui_pages.dialogue.utils import process_files
from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \ # from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
get_select_model_endpoint # get_select_model_endpoint
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_chatbox import * from streamlit_chatbox import *
from streamlit_modal import Modal from streamlit_modal import Modal
@ -13,9 +13,9 @@ from datetime import datetime
import os import os
import re import re
import time import time
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY) from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG)
from server.callback_handler.agent_callback_handler import AgentStatus from server.callback_handler.agent_callback_handler import AgentStatus
from server.utils import MsgType from server.utils import MsgType, get_config_models
import uuid import uuid
from typing import List, Dict from typing import List, Dict
@ -111,8 +111,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {}) st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex) st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
st.session_state.setdefault("file_chat_id", None) st.session_state.setdefault("file_chat_id", None)
st.session_state.setdefault("select_plugins_info", None)
st.session_state.setdefault("select_model_worker", None)
# 弹出自定义命令帮助信息 # 弹出自定义命令帮助信息
modal = Modal("自定义命令", key="cmd_help", max_width="500") modal = Modal("自定义命令", key="cmd_help", max_width="500")
@ -131,18 +129,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.use_chat_name(conversation_name) chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name] conversation_id = st.session_state["conversation_ids"][conversation_name]
with st.expander("模型选择"): platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
plugins_menu = build_providers_model_plugins_name() platform = st.selectbox("选择模型平台", platforms, 1)
llm_models = list(get_config_models(model_type="llm", platform_name=platform))
items, _ = ParseItems(plugins_menu).multi() llm_model = st.selectbox("选择LLM模型", llm_models)
if len(plugins_menu) > 0:
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
set_llm_select(plugins_info, llm_model_worker)
else:
st.info("没有可用的插件")
# 传入后端的内容 # 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
@ -174,10 +164,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if is_selected: if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool] selected_tool_configs[tool] = TOOL_CONFIG[tool]
llm_model = None
if st.session_state["select_model_worker"] is not None:
llm_model = st.session_state["select_model_worker"]['label']
if llm_model is not None: if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
@ -200,23 +186,23 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.output_messages() chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 " chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback( # def on_feedback(
feedback, # feedback,
message_id: str = "", # message_id: str = "",
history_index: int = -1, # history_index: int = -1,
): # ):
reason = feedback["text"] # reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) # score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
api.chat_feedback(message_id=message_id, # api.chat_feedback(message_id=message_id,
score=score_int, # score=score_int,
reason=reason) # reason=reason)
st.session_state["need_rerun"] = True # st.session_state["need_rerun"] = True
feedback_kwargs = { # feedback_kwargs = {
"feedback_type": "thumbs", # "feedback_type": "thumbs",
"optional_text_label": "欢迎反馈您打分的理由", # "optional_text_label": "欢迎反馈您打分的理由",
} # }
if prompt := st.chat_input(chat_input_placeholder, key="prompt"): if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令 if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
@ -244,17 +230,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text_action = "" text_action = ""
element_index = 0 element_index = 0
openai_config = {}
endpoint_host, select_model_name = get_select_model_endpoint()
openai_config["endpoint_host"] = endpoint_host
openai_config["model_name"] = select_model_name
openai_config["endpoint_host_key"] = OPENAI_KEY
openai_config["endpoint_host_proxy"] = OPENAI_PROXY
for d in api.chat_chat(query=prompt, for d in api.chat_chat(query=prompt,
metadata=files_upload, metadata=files_upload,
history=history, history=history,
chat_model_config=chat_model_config, chat_model_config=chat_model_config,
openai_config=openai_config,
conversation_id=conversation_id, conversation_id=conversation_id,
tool_config=selected_tool_configs, tool_config=selected_tool_configs,
): ):

View File

@ -1,8 +1,8 @@
import streamlit as st import streamlit as st
from streamlit_antd_components.utils import ParseItems from streamlit_antd_components.utils import ParseItems
from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \ # from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
set_llm_select, set_embed_select, get_select_embed_endpoint # set_llm_select, set_embed_select, get_select_embed_endpoint
from webui_pages.utils import * from webui_pages.utils import *
from st_aggrid import AgGrid, JsCode from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder from st_aggrid.grid_options_builder import GridOptionsBuilder
@ -10,10 +10,9 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple from typing import Literal, Dict, Tuple
from configs import (kbs_config, from configs import (kbs_config, DEFAULT_VS_TYPE,
EMBEDDING_MODEL, DEFAULT_VS_TYPE, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY) from server.utils import get_config_models
from server.utils import list_embed_models
import streamlit_antd_components as sac import streamlit_antd_components as sac
import os import os
@ -116,25 +115,11 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
col1, _ = st.columns([3, 1]) col1, _ = st.columns([3, 1])
with col1: with col1:
col1.text("Embedding 模型") embed_models = list(get_config_models(model_type="embed"))
plugins_menu = build_providers_embedding_plugins_name() index = 0
if DEFAULT_EMBEDDING_MODEL in embed_models:
embed_models = list_embed_models() index = embed_models.index(DEFAULT_EMBEDDING_MODEL)
menu_item_children = [] embed_model = st.selectbox("Embeddings模型", embed_models, index)
for model in embed_models:
menu_item_children.append(sac.MenuItem(model, description=model))
plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
items, _ = ParseItems(plugins_menu).multi()
if len(plugins_menu) > 0:
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
set_embed_select(plugins_info, llm_model_worker)
else:
st.info("没有可用的插件")
submit_create_kb = st.form_submit_button( submit_create_kb = st.form_submit_button(
"新建", "新建",
@ -143,23 +128,17 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
) )
if submit_create_kb: if submit_create_kb:
endpoint_host, select_embed_model_name = get_select_embed_endpoint()
if not kb_name or not kb_name.strip(): if not kb_name or not kb_name.strip():
st.error(f"知识库名称不能为空!") st.error(f"知识库名称不能为空!")
elif kb_name in kb_list: elif kb_name in kb_list:
st.error(f"名为 {kb_name} 的知识库已经存在!") st.error(f"名为 {kb_name} 的知识库已经存在!")
elif select_embed_model_name is None: elif embed_model is None:
st.error(f"请选择Embedding模型") st.error(f"请选择Embedding模型")
else: else:
ret = api.create_knowledge_base( ret = api.create_knowledge_base(
knowledge_base_name=kb_name, knowledge_base_name=kb_name,
vector_store_type=vs_type, vector_store_type=vs_type,
embed_model=select_embed_model_name, embed_model=embed_model,
endpoint_host=endpoint_host,
endpoint_host_key=OPENAI_KEY,
endpoint_host_proxy=OPENAI_PROXY,
) )
st.toast(ret.get("msg", " ")) st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name st.session_state["selected_kb_name"] = kb_name
@ -169,9 +148,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
elif selected_kb: elif selected_kb:
kb = selected_kb kb = selected_kb
st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']
st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host']
st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key']
st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy']
# 上传文件 # 上传文件
files = st.file_uploader("上传知识文件:", files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls], [i for ls in LOADER_DICT.values() for i in ls],
@ -185,37 +161,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.session_state["selected_kb_info"] = kb_info st.session_state["selected_kb_info"] = kb_info
api.update_kb_info(kb, kb_info) api.update_kb_info(kb, kb_info)
if st.session_state["kb_endpoint_host"] is not None:
with st.expander(
"在线api接入点配置",
expanded=True,
):
endpoint_host = st.text_input(
"接入点地址",
placeholder="接入点地址",
key="endpoint_host",
value=st.session_state["kb_endpoint_host"],
)
endpoint_host_key = st.text_input(
"接入点key",
placeholder="接入点key",
key="endpoint_host_key",
value=st.session_state["kb_endpoint_host_key"],
)
endpoint_host_proxy = st.text_input(
"接入点代理地址",
placeholder="接入点代理地址",
key="endpoint_host_proxy",
value=st.session_state["kb_endpoint_host_proxy"],
)
if endpoint_host != st.session_state["kb_endpoint_host"] \
or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \
or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]:
st.session_state["kb_endpoint_host"] = endpoint_host
st.session_state["kb_endpoint_host_key"] = endpoint_host_key
st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy
api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy)
# with st.sidebar: # with st.sidebar:
with st.expander( with st.expander(
"文件处理配置", "文件处理配置",

View File

@ -4,7 +4,7 @@
from typing import * from typing import *
from pathlib import Path from pathlib import Path
from configs import ( from configs import (
EMBEDDING_MODEL, DEFAULT_EMBEDDING_MODEL,
DEFAULT_VS_TYPE, DEFAULT_VS_TYPE,
LLM_MODEL_CONFIG, LLM_MODEL_CONFIG,
SCORE_THRESHOLD, SCORE_THRESHOLD,
@ -266,7 +266,6 @@ class ApiRequest:
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
chat_model_config: Dict = None, chat_model_config: Dict = None,
openai_config: Dict = None,
tool_config: Dict = None, tool_config: Dict = None,
**kwargs, **kwargs,
): ):
@ -281,7 +280,6 @@ class ApiRequest:
"history": history, "history": history,
"stream": stream, "stream": stream,
"chat_model_config": chat_model_config, "chat_model_config": chat_model_config,
"openai_config": openai_config,
"tool_config": tool_config, "tool_config": tool_config,
} }
@ -381,10 +379,7 @@ class ApiRequest:
self, self,
knowledge_base_name: str, knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE, vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
endpoint_host: str = None,
endpoint_host_key: str = None,
endpoint_host_proxy: str = None
): ):
''' '''
对应api.py/knowledge_base/create_knowledge_base接口 对应api.py/knowledge_base/create_knowledge_base接口
@ -393,9 +388,6 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"vector_store_type": vector_store_type, "vector_store_type": vector_store_type,
"embed_model": embed_model, "embed_model": embed_model,
"endpoint_host": endpoint_host,
"endpoint_host_key": endpoint_host_key,
"endpoint_host_proxy": endpoint_host_proxy,
} }
response = self.post( response = self.post(
@ -459,24 +451,6 @@ class ApiRequest:
) )
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def update_docs_by_id(
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
) -> bool:
'''
对应api.py/knowledge_base/update_docs_by_id接口
'''
data = {
"knowledge_base_name": knowledge_base_name,
"docs": docs,
}
response = self.post(
"/knowledge_base/update_docs_by_id",
json=data
)
return self._get_response_value(response)
def upload_kb_docs( def upload_kb_docs(
self, self,
files: List[Union[str, Path, bytes]], files: List[Union[str, Path, bytes]],
@ -562,26 +536,6 @@ class ApiRequest:
) )
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def update_kb_endpoint(self,
knowledge_base_name,
endpoint_host: str = None,
endpoint_host_key: str = None,
endpoint_host_proxy: str = None):
'''
对应api.py/knowledge_base/update_info接口
'''
data = {
"knowledge_base_name": knowledge_base_name,
"endpoint_host": endpoint_host,
"endpoint_host_key": endpoint_host_key,
"endpoint_host_proxy": endpoint_host_proxy,
}
response = self.post(
"/knowledge_base/update_kb_endpoint",
json=data,
)
return self._get_response_value(response, as_json=True)
def update_kb_docs( def update_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
@ -621,7 +575,7 @@ class ApiRequest:
knowledge_base_name: str, knowledge_base_name: str,
allow_empty_kb: bool = True, allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE, vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
@ -650,7 +604,7 @@ class ApiRequest:
def embed_texts( def embed_texts(
self, self,
texts: List[str], texts: List[str],
embed_model: str = EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
to_query: bool = False, to_query: bool = False,
) -> List[List[float]]: ) -> List[List[float]]:
''' '''