mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
Change the model configuration approach: all models are now accessed through OpenAI-compatible frameworks, and chatchat itself no longer loads models.

Switch the Embeddings models to the framework API instead of loading them manually, and delete the custom Embeddings Keyword code. Update the dependency files, dropping heavy dependencies such as torch and transformers. Temporarily remove the loom integration.
Follow-up:
1. Tidy up the directory structure.
2. Check whether any 0.2.10 content was overwritten during the merge.
This commit is contained in:
parent 988a0e6ad2
commit 5d422ca9a1
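As context for the change described above, the sketch below illustrates (it is not part of the commit) how a model served through an OpenAI-compatible endpoint can be called for both chat and embeddings. The base URL, API key and model names are taken from the xinference entry of the MODEL_PLATFORMS configuration further down in this diff; the client code itself is only an assumed usage example, not chatchat's actual implementation.

```python
# Illustrative sketch only -- not code from this commit.
# Assumes an xinference server exposing an OpenAI-compatible API at the URL
# configured in MODEL_PLATFORMS below, serving chatglm3-6b and bge-large-zh-v1.5.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="EMPTY")

# LLM access goes through the standard chat completions endpoint ...
chat = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "你好"}],
    temperature=0.7,
)
print(chat.choices[0].message.content)

# ... and embeddings go through the embeddings endpoint instead of a locally
# loaded HuggingFace model, which is what allows torch/transformers to be
# dropped from the requirements.
emb = client.embeddings.create(model="bge-large-zh-v1.5", input=["hello", "world"])
print(len(emb.data[0].embedding))
```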
@@ -1,8 +1,8 @@
import logging
import os
from pathlib import Path

import langchain
import tempfile
import shutil

# 是否显示详细日志
@@ -11,6 +11,16 @@ langchain.verbose = False

# 通常情况下不需要更改以下内容

# 用户数据根目录
DATA_PATH = (Path(__file__).absolute().parent.parent) # / "data")
if not os.path.exists(DATA_PATH):
    os.mkdir(DATA_PATH)

# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
@@ -19,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT)

# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
LOG_PATH = os.path.join(DATA_PATH, "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

# 模型生成内容(图片、视频、音频等)保存位置
MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
MEDIA_PATH = os.path.join(DATA_PATH, "media")
if not os.path.exists(MEDIA_PATH):
    os.mkdir(MEDIA_PATH)
    os.mkdir(os.path.join(MEDIA_PATH, "image"))
@@ -32,9 +42,6 @@ if not os.path.exists(MEDIA_PATH):
    os.mkdir(os.path.join(MEDIA_PATH, "video"))

# 临时文件目录,主要用于文件对话
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
if os.path.isdir(BASE_TEMP_DIR):
    shutil.rmtree(BASE_TEMP_DIR)
os.makedirs(BASE_TEMP_DIR, exist_ok=True)

MEDIA_PATH = None
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
if not os.path.exists(BASE_TEMP_DIR):
    os.mkdir(BASE_TEMP_DIR)

@@ -1,6 +1,6 @@
import os

# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"

# 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es
@@ -44,6 +44,7 @@ KB_INFO = {
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
    os.mkdir(KB_ROOT_PATH)

# 数据库默认存储路径。
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")

@@ -1,10 +1,25 @@
import os

MODEL_ROOT_PATH = ""
EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh
EMBEDDING_DEVICE = "auto"
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"

# 默认选用的 LLM 名称
DEFAULT_LLM_MODEL = "chatglm3-6b"

# 默认选用的 Embedding 名称
DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5"

# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])
Agent_MODEL = None

# 历史对话轮数
HISTORY_LEN = 3

# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
MAX_TOKENS = None

# LLM通用对话参数
TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数

SUPPORT_AGENT_MODELS = [
    "chatglm3-6b",
@@ -12,97 +27,103 @@ SUPPORT_AGENT_MODELS = [
    "Qwen-14B-Chat",
    "Qwen-7B-Chat",
]

LLM_MODEL_CONFIG = {
    # 意图识别不需要输出,模型后台知道就行
    "preprocess_model": {
        # "Mixtral-8x7B-v0.1": {
        #     "temperature": 0.01,
        #     "max_tokens": 5,
        #     "prompt_name": "default",
        #     "callbacks": False
        # },
        "chatglm3-6b": {
        DEFAULT_LLM_MODEL: {
            "temperature": 0.05,
            "max_tokens": 4096,
            "history_len": 100,
            "prompt_name": "default",
            "callbacks": False
        },
    },
    "llm_model": {
        # "Mixtral-8x7B-v0.1": {
        #     "temperature": 0.9,
        #     "max_tokens": 4000,
        #     "history_len": 5,
        #     "prompt_name": "default",
        #     "callbacks": True
        # },
        "chatglm3-6b": {
            "temperature": 0.05,
        DEFAULT_LLM_MODEL: {
            "temperature": 0.9,
            "max_tokens": 4096,
            "prompt_name": "default",
            "history_len": 10,
            "prompt_name": "default",
            "callbacks": True
        },
    },
    "action_model": {
        # "Qwen-14B-Chat": {
        #     "temperature": 0.05,
        #     "max_tokens": 4096,
        #     "prompt_name": "qwen",
        #     "callbacks": True
        # },
        "chatglm3-6b": {
            "temperature": 0.05,
        DEFAULT_LLM_MODEL: {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "ChatGLM3",
            "callbacks": True
        },
        # "zhipu-api": {
        #     "temperature": 0.01,
        #     "max_tokens": 4096,
        #     "prompt_name": "ChatGLM3",
        #     "callbacks": True
        # }

        },
    },
    "postprocess_model": {
        "zhipu-api": {
        DEFAULT_LLM_MODEL: {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "default",
            "callbacks": True
        }
    },
}

MODEL_PATH = {
    "embed_model": {
        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
        "ernie-base": "nghuyong/ernie-3.0-base-zh",
        "text2vec-base": "shibing624/text2vec-base-chinese",
        "text2vec": "GanymedeNil/text2vec-large-chinese",
        "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
        "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
        "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
        "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
        "m3e-small": "moka-ai/m3e-small",
        "m3e-base": "moka-ai/m3e-base",
        "m3e-large": "moka-ai/m3e-large",
        "bge-small-zh": "BAAI/bge-small-zh",
        "bge-base-zh": "BAAI/bge-base-zh",
        "bge-large-zh": "BAAI/bge-large-zh",
        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
        "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
        "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
        "piccolo-base-zh": "sensenova/piccolo-base-zh",
        "piccolo-large-zh": "sensenova/piccolo-large-zh",
        "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
        "text-embedding-ada-002": "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",
    }
    "image_model": {
        "sd-turbo": {
            "size": "256*256",
        }
    },
    "multimodal_model": {
        "qwen-vl": {}
    },
}

NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

LOOM_CONFIG = "./loom.yaml"
OPENAI_KEY = None
OPENAI_PROXY = None
# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
MODEL_PLATFORMS = [
    {
        "platform_name": "openai-api",
        "platform_type": "openai",
        "llm_models": [
            "gpt-3.5-turbo",
        ],
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "sk-",
        "api_proxy": "",
    },

    {
        "platform_name": "xinference",
        "platform_type": "xinference",
        "llm_models": [
            "chatglm3-6b",
        ],
        "embed_models": [
            "bge-large-zh-v1.5",
        ],
        "image_models": [
            "sd-turbo",
        ],
        "multimodal_models": [
            "qwen-vl",
        ],
        "api_base_url": "http://127.0.0.1:9997/v1",
        "api_key": "EMPTY",
    },

    {
        "platform_name": "oneapi",
        "platform_type": "oneapi",
        "api_key": "",
        "llm_models": [
            "chatglm3-6b",
        ],
    },

    {
        "platform_name": "loom",
        "platform_type": "loom",
        "api_key": "",
        "llm_models": [
            "chatglm3-6b",
        ],
    },
]

LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")

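The MODEL_PLATFORMS list above replaces the per-knowledge-base endpoint_host / endpoint_host_key / endpoint_host_proxy fields removed elsewhere in this diff. As a rough sketch only (a hypothetical helper, not part of the commit), resolving a model name to an OpenAI-compatible client from that configuration could look like this:

```python
# Hypothetical helper, for illustration only -- not part of this commit.
# Picks the first platform in MODEL_PLATFORMS that lists the requested model
# and builds an OpenAI-compatible client for it.
from openai import OpenAI

def get_client_for_model(model_name: str, model_type: str = "llm_models") -> OpenAI:
    for platform in MODEL_PLATFORMS:
        if model_name in platform.get(model_type, []):
            return OpenAI(
                base_url=platform.get("api_base_url", ""),
                api_key=platform.get("api_key") or "EMPTY",
            )
    raise ValueError(f"no platform configured for {model_name!r}")

# e.g. get_client_for_model("chatglm3-6b") would select the xinference entry above.
```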
@@ -98,7 +98,25 @@ PROMPT_TEMPLATES = {
        'Begin!\n\n'
        'Question: {input}\n\n'
        '{agent_scratchpad}\n\n',

    "structured-chat-agent":
        'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n'
        '{tools}\n\n'
        'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n'
        'Valid "action" values: "Final Answer" or {tool_names}\n\n'
        'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
        '```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\n'
        'Follow this format:\n\n'
        'Question: input question to answer\n'
        'Thought: consider previous and subsequent steps\n'
        'Action:\n```\n$JSON_BLOB\n```\n'
        'Observation: action result\n'
        '... (repeat Thought/Action/Observation N times)\n'
        'Thought: I know what to respond\n'
        'Action:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\n'
        'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n'
        '{input}\n\n'
        '{agent_scratchpad}\n\n'
        # '(reminder to respond in a JSON blob no matter what)'
    },
    "postprocess_model": {
        "default": "{{input}}",
@@ -130,7 +148,7 @@ TOOL_CONFIG = {
    "bing": {
        "result_len": 3,
        "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
        "bing_key": "680a39347d7242c5bd2d7a9576a125b7",
        "bing_key": "",
    },
    "metaphor": {
        "result_len": 3,
@@ -184,4 +202,8 @@ TOOL_CONFIG = {
        "device": "cuda:2"
    },

    "text2images": {
        "use": False,
    },

}

@@ -1,5 +1,5 @@
import sys
from configs.model_config import LLM_DEVICE

# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
@@ -11,10 +11,11 @@ OPEN_CROSS_DOMAIN = True
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"

# webui.py server
WEBUI_SERVER = {
    "host": DEFAULT_BIND_HOST,
    "port": 7870,
    "port": 8501,
}

# api.py server

@@ -1,79 +0,0 @@
'''
该功能是为了将关键词加入到embedding模型中,以便于在embedding模型中进行关键词的embedding
该功能的实现是通过修改embedding模型的tokenizer来实现的
该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效,输出后的模型保存在原本模型
感谢@CharlesJu1和@charlesyju的贡献提出了想法和最基础的PR

保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳
'''
import sys

sys.path.append("..")
import os
import torch

from datetime import datetime
from configs import (
    MODEL_PATH,
    EMBEDDING_MODEL,
    EMBEDDING_KEYWORD_FILE,
)

from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated


@deprecated(
    since="0.3.0",
    message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
    removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words):
    tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokenizer_output['input_ids']
    input_ids = input_ids[:, 1:-1]

    keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
    keyword_embedding = torch.mean(keyword_embedding, 1)
    return keyword_embedding


def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
    key_words = []
    with open(keyword_file, "r") as f:
        for line in f:
            key_words.append(line.strip())

    st_model = SentenceTransformer(model_name)
    key_words_len = len(key_words)
    word_embedding_model = st_model._first_module()
    bert_model = word_embedding_model.auto_model
    tokenizer = word_embedding_model.tokenizer
    key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)

    embedding_weight = bert_model.embeddings.word_embeddings.weight
    embedding_weight_len = len(embedding_weight)
    tokenizer.add_tokens(key_words)
    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
    embedding_weight = bert_model.embeddings.word_embeddings.weight
    with torch.no_grad():
        embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding

    if output_model_path:
        os.makedirs(output_model_path, exist_ok=True)
        word_embedding_model.save(output_model_path)
        safetensors_file = os.path.join(output_model_path, "model.safetensors")
        metadata = {'format': 'pt'}
        save_model(bert_model, safetensors_file, metadata)
        print("save model to {}".format(output_model_path))


def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
    keyword_file = os.path.join(path)
    model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
    model_parent_directory = os.path.dirname(model_name)
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
    output_model_path = os.path.join(model_parent_directory, output_model_name)
    add_keyword_to_model(model_name, keyword_file, output_model_path)
@@ -1,3 +0,0 @@
Langchain-Chatchat
数据科学与大数据技术
人工智能与先进计算
@@ -2,9 +2,7 @@ import sys
sys.path.append(".")
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                            folder2db, prune_db_docs, prune_folder_files)
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL
from datetime import datetime

@@ -19,7 +17,7 @@ if __name__ == "__main__":
        action="store_true",
        help=('''
            recreate vector store.
            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed.
            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/DEFAULT_EMBEDDING_MODEL changed.
            '''
              )
    )
@@ -87,7 +85,7 @@ if __name__ == "__main__":
        "-e",
        "--embed-model",
        type=str,
        default=EMBEDDING_MODEL,
        default=DEFAULT_EMBEDDING_MODEL,
        help=("specify embeddings model.")
    )

@@ -1,57 +0,0 @@
# API requirements

langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai~=1.9.0
fastapi~=0.109.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.27.0.post1
starlette~=0.35.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.8
rapidocr_onnxruntime==1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
httpx_sse==0.4.0
llama-index==0.9.35
pyjwt==2.8.0

# jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2
# dashscope==1.13.6
# arxiv~=2.1.0
# youtube-search~=2.1.2
# duckduckgo-search~=3.9.9
# metaphor-python~=0.1.23

# volcengine>=1.0.119
# pymilvus==2.3.6
# psycopg2==2.9.9
# pgvector>=0.2.4
# chromadb==0.4.13

#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files
@@ -1,42 +0,0 @@
langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat~=0.2.35
openai~=1.9.0
fastapi~=0.109.0
sse_starlette~=1.8.2
nltk~=3.8.1
uvicorn>=0.27.0.post1
starlette~=0.35.0
unstructured[all-docs]~=0.12.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy~=2.0.25
faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.16
rapidocr_onnxruntime~=1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
llama-index==0.9.35
pyjwt==2.8.0
httpx==0.26.0
httpx_sse==0.4.0

dashscope==1.13.6
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23
watchdog~=3.0.0
# volcengine>=1.0.119
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
# chromadb==0.4.13

# jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2

@@ -1,9 +0,0 @@
streamlit==1.30.0
streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
httpx_sse==0.4.0
watchdog>=3.0.0
@@ -6,7 +6,7 @@ from typing import List
import uuid

from langchain.agents import tool
from pydantic.v1 import BaseModel, Field
from langchain.pydantic_v1 import Field
import openai
from pydantic.fields import FieldInfo


@@ -1,13 +1,9 @@
import nltk
import sys
import os

from server.knowledge_base.kb_doc_api import update_kb_endpoint

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from configs import VERSION, MEDIA_PATH
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
@@ -18,14 +14,11 @@ from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings.core.embeddings_api import embed_texts_endpoint

from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                          get_server_configs, get_prompt_template)
from typing import List, Literal

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path


async def document():
    return RedirectResponse(url="/docs")
@@ -95,11 +88,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    app.post("/other/embed_texts",
             tags=["Other"],
             summary="将文本向量化,支持本地模型和在线模型",
             )(embed_texts_endpoint)

    # 媒体文件
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")

@@ -109,8 +97,7 @@ def mount_knowledge_routes(app: FastAPI):
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                  update_docs, download_doc, recreate_vector_store,
                                                  search_docs, DocumentWithVSId, update_info,
                                                  update_docs_by_id,)
                                                  search_docs, update_info)

    app.post("/chat/file_chat",
             tags=["Knowledge Base Management"],
@@ -146,13 +133,6 @@ def mount_knowledge_routes(app: FastAPI):
             summary="搜索知识库"
             )(search_docs)

    app.post("/knowledge_base/update_docs_by_id",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="直接更新知识库文档"
             )(update_docs_by_id)


    app.post("/knowledge_base/upload_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
@@ -171,12 +151,6 @@ def mount_knowledge_routes(app: FastAPI):
             summary="更新知识库介绍"
             )(update_info)

    app.post("/knowledge_base/update_kb_endpoint",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新知识库在线api接入点配置"
             )(update_kb_endpoint)

    app.post("/knowledge_base/update_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,

@@ -1,6 +1,6 @@
import asyncio
import json
from typing import AsyncIterable, List, Union, Dict, Annotated
from typing import AsyncIterable, List

from fastapi import Body
from fastapi.responses import StreamingResponse
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus


def create_models_from_config(configs, openai_config, callbacks, stream):
def create_models_from_config(configs, callbacks, stream):
    if configs is None:
        configs = {}
    models = {}
@@ -30,9 +30,6 @@ def create_models_from_config(configs, openai_config, callbacks, stream):
    for model_name, params in model_configs.items():
        callbacks = callbacks if params.get('callbacks', False) else None
        model_instance = get_ChatOpenAI(
            endpoint_host=openai_config.get('endpoint_host', None),
            endpoint_host_key=openai_config.get('endpoint_host_key', None),
            endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
            model_name=model_name,
            temperature=params.get('temperature', 0.5),
            max_tokens=params.get('max_tokens', 1000),
@@ -116,7 +113,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               ),
               stream: bool = Body(True, description="流式输出"),
               chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
               openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
               tool_config: dict = Body({}, description="工具配置", examples=[]),
               ):
    async def chat_iterator() -> AsyncIterable[str]:
@@ -129,7 +125,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
        callback = AgentExecutorAsyncIteratorCallbackHandler()
        callbacks = [callback]
        models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
                                                    openai_config=openai_config, stream=stream)
                                                    stream=stream)
        tools = [tool for tool in all_tools if tool.name in tool_config]
        tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
        full_chain = create_models_chains(prompts=prompts,

@@ -1,6 +1,5 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL_CONFIG
from sse_starlette.sse import EventSourceResponse
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -14,9 +13,6 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     stream: bool = Body(False, description="流式输出"),
                     echo: bool = Body(False, description="除了输出之外,还回显输入"),
                     endpoint_host: str = Body(None, description="接入点地址"),
                     endpoint_host_key: str = Body(None, description="接入点key"),
                     endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                     model_name: str = Body(None, description="LLM 模型名称。"),
                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -27,9 +23,6 @@ async def completion(query: str = Body(..., description="用户输入", examples

    #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
    async def completion_iterator(query: str,
                                  endpoint_host: str,
                                  endpoint_host_key: str,
                                  endpoint_host_proxy: str,
                                  model_name: str = None,
                                  prompt_name: str = prompt_name,
                                  echo: bool = echo,
@@ -40,9 +33,6 @@ async def completion(query: str = Body(..., description="用户输入", examples
            max_tokens = None

        model = get_OpenAI(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
@@ -72,10 +62,7 @@ async def completion(query: str = Body(..., description="用户输入", examples

        await task

    return StreamingResponse(completion_iterator(query=query,
                                                 endpoint_host=endpoint_host,
                                                 endpoint_host_key=endpoint_host_key,
                                                 endpoint_host_proxy=endpoint_host_proxy,
    return EventSourceResponse(completion_iterator(query=query,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )

@@ -1,8 +1,7 @@
from fastapi import Body, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.embeddings.adapter import load_temp_adapter_embeddings
from server.utils import (wrap_done, get_ChatOpenAI,
from server.utils import (wrap_done, get_ChatOpenAI, get_Embeddings,
                          BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from langchain.chains import LLMChain
@@ -57,9 +56,6 @@ def _parse_files_in_thread(


def upload_temp_docs(
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        prev_id: str = Form(None, description="前知识库ID"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
@@ -86,11 +82,7 @@ def upload_temp_docs(
        else:
            failed_files.append({file: msg})

    with memo_faiss_pool.load_vector_store(kb_name=id,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           ).acquire() as vs:
    with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs:
        vs.add_documents(documents)
    return BaseResponse(data={"id": id, "failed_files": failed_files})

@@ -110,9 +102,6 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                                           "content": "虎头虎脑"}]]
                                      ),
                    stream: bool = Body(False, description="流式输出"),
                    endpoint_host: str = Body(None, description="接入点地址"),
                    endpoint_host_key: str = Body(None, description="接入点key"),
                    endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                    model_name: str = Body(None, description="LLM 模型名称。"),
                    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -131,17 +120,12 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
            max_tokens = None

        model = get_ChatOpenAI(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
                                                  endpoint_host_key=endpoint_host_key,
                                                  endpoint_host_proxy=endpoint_host_proxy)
        embed_func = get_Embeddings()
        embeddings = await embed_func.aembed_query(query)
        with memo_faiss_pool.acquire(knowledge_id) as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)

@@ -12,9 +12,6 @@ class KnowledgeBaseModel(Base):
    kb_name = Column(String(50), comment='知识库名称')
    kb_info = Column(String(200), comment='知识库简介(用于Agent)')
    vs_type = Column(String(50), comment='向量库类型')
    endpoint_host = Column(String(50), comment='接入点地址')
    endpoint_host_key = Column(String(50), comment='接入点key')
    endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
    embed_model = Column(String(50), comment='嵌入模型名称')
    file_count = Column(Integer, default=0, comment='文件数量')
    create_time = Column(DateTime, default=func.now(), comment='创建时间')

@@ -3,22 +3,16 @@ from server.db.session import with_session


@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
                 endpoint_host_key: str = None, endpoint_host_proxy: str = None):
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
    # 创建知识库实例
    kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
    if not kb:
        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
                                endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
                                endpoint_host_proxy=endpoint_host_proxy)
        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
        session.add(kb)
    else: # update kb with new vs_type and embed_model
        kb.kb_info = kb_info
        kb.vs_type = vs_type
        kb.embed_model = embed_model
        kb.endpoint_host = endpoint_host
        kb.endpoint_host_key = endpoint_host_key
        kb.endpoint_host_proxy = endpoint_host_proxy
    return True


@@ -54,16 +48,6 @@ def delete_kb_from_db(session, kb_name):
    return True


@with_session
def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
    if kb:
        kb.endpoint_host = endpoint_host
        kb.endpoint_host_key = endpoint_host_key
        kb.endpoint_host_proxy = endpoint_host_proxy
    return True


@with_session
def get_kb_detail(session, kb_name: str) -> dict:
    kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
@@ -72,9 +56,6 @@ def get_kb_detail(session, kb_name: str) -> dict:
        "kb_name": kb.kb_name,
        "kb_info": kb.kb_info,
        "vs_type": kb.vs_type,
        "endpoint_host": kb.endpoint_host,
        "endpoint_host_key": kb.endpoint_host_key,
        "endpoint_host_proxy": kb.endpoint_host_proxy,
        "embed_model": kb.embed_model,
        "file_count": kb.file_count,
        "create_time": kb.create_time,

@@ -1,130 +0,0 @@
import numpy as np
from typing import List, Union, Dict
from langchain.embeddings.base import Embeddings
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     EMBEDDING_MODEL, KB_INFO)
from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
from server.utils import embedding_device


class EmbeddingsFunAdapter(Embeddings):
    _endpoint_host: str
    _endpoint_host_key: str
    _endpoint_host_proxy: str

    def __init__(self,
                 endpoint_host: str,
                 endpoint_host_key: str,
                 endpoint_host_proxy: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self._endpoint_host = endpoint_host
        self._endpoint_host_key = endpoint_host_key
        self._endpoint_host_proxy = endpoint_host_proxy
        self.embed_model = embed_model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = embed_texts(texts=texts,
                                 endpoint_host=self._endpoint_host,
                                 endpoint_host_key=self._endpoint_host_key,
                                 endpoint_host_proxy=self._endpoint_host_proxy,
                                 embed_model=self.embed_model,
                                 to_query=False).data
        return self._normalize(embeddings=embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        embeddings = embed_texts(texts=[text],
                                 endpoint_host=self._endpoint_host,
                                 endpoint_host_key=self._endpoint_host_key,
                                 endpoint_host_proxy=self._endpoint_host_proxy,
                                 embed_model=self.embed_model,
                                 to_query=True).data
        query_embed = embeddings[0]
        query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
        return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = (await aembed_texts(texts=texts,
                                         endpoint_host=self._endpoint_host,
                                         endpoint_host_key=self._endpoint_host_key,
                                         endpoint_host_proxy=self._endpoint_host_proxy,
                                         embed_model=self.embed_model,
                                         to_query=False)).data
        return self._normalize(embeddings=embeddings).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        embeddings = (await aembed_texts(texts=[text],
                                         endpoint_host=self._endpoint_host,
                                         endpoint_host_key=self._endpoint_host_key,
                                         endpoint_host_proxy=self._endpoint_host_proxy,
                                         embed_model=self.embed_model,
                                         to_query=True)).data
        query_embed = embeddings[0]
        query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
        return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回

    @staticmethod
    def _normalize(embeddings: List[List[float]]) -> np.ndarray:
        '''
        sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
        #TODO 此处内容处理错误
        '''
        norm = np.linalg.norm(embeddings, axis=1)
        norm = np.reshape(norm, (norm.shape[0], 1))
        norm = np.tile(norm, (1, len(embeddings[0])))
        return np.divide(embeddings, norm)


def load_kb_adapter_embeddings(
        kb_name: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    加载知识库配置的Embeddings模型
    本地模型最终会通过load_embeddings加载
    在线模型会在适配器中直接返回
    :param kb_name:
    :param embed_device:
    :param default_embed_model:
    :return:
    """
    from server.db.repository.knowledge_base_repository import get_kb_detail

    kb_detail = get_kb_detail(kb_name)
    embed_model = kb_detail.get("embed_model", default_embed_model)
    endpoint_host = kb_detail.get("endpoint_host", None)
    endpoint_host_key = kb_detail.get("endpoint_host_key", None)
    endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None)

    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
                                endpoint_host_key=endpoint_host_key,
                                endpoint_host_proxy=endpoint_host_proxy,
                                embed_model=embed_model)


def load_temp_adapter_embeddings(
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    加载临时的Embeddings模型
    本地模型最终会通过load_embeddings加载
    在线模型会在适配器中直接返回
    :param endpoint_host:
    :param endpoint_host_key:
    :param endpoint_host_proxy:
    :param embed_device:
    :param default_embed_model:
    :return:
    """

    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
                                endpoint_host_key=endpoint_host_key,
                                endpoint_host_proxy=endpoint_host_proxy,
                                embed_model=default_embed_model)
@@ -1,94 +0,0 @@
from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List


def embed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    TODO: 也许需要加入缓存机制,减少 token 消耗
    '''
    try:
        if embed_model in list_embed_models():  # 使用本地Embeddings模型
            from server.utils import load_local_embeddings
            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=embeddings.embed_documents(texts))

        # 使用在线API
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            from langchain.embeddings.openai import OpenAIEmbeddings
            embeddings = OpenAIEmbeddings(model=embed_model,
                                          openai_api_key=endpoint_host_key if endpoint_host_key else "None",
                                          openai_api_base=endpoint_host if endpoint_host else "None",
                                          openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
                                          chunk_size=CHUNK_SIZE)
            return BaseResponse(data=embeddings.embed_documents(texts))

        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")


async def aembed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    '''
    try:
        if embed_model in list_embed_models():  # 使用本地Embeddings模型
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # 使用在线API
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           embed_model=embed_model,
                                           to_query=to_query)
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")


def embed_texts_endpoint(
        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    '''
    接入api,对文本进行向量化,返回 BaseResponse(data=List[List[float]])
    '''
    return embed_texts(texts=texts,
                       endpoint_host=endpoint_host,
                       endpoint_host_key=endpoint_host_key,
                       endpoint_host_proxy=endpoint_host_proxy,
                       embed_model=embed_model, to_query=to_query)
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs import EMBEDDING_MODEL, logger, log_verbose
from configs import DEFAULT_EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body


@@ -14,10 +14,7 @@ def list_kbs():

def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
              vector_store_type: str = Body("faiss"),
              embed_model: str = Body(EMBEDDING_MODEL),
              endpoint_host: str = Body(None, description="接入点地址"),
              endpoint_host_key: str = Body(None, description="接入点key"),
              endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
              embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
              ) -> BaseResponse:
    # Create selected knowledge base
    if not validate_kb_name(knowledge_base_name):
@@ -31,7 +28,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),

    kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
    try:
        kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
        kb.create_kb()
    except Exception as e:
        msg = f"创建知识库出错: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',

@@ -1,9 +1,8 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
from configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE,
                     logger, log_verbose)
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
@@ -98,50 +97,3 @@ class CachePool:
        else:
            return cache


class EmbeddingsPool(CachePool):

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """
        本地Embeddings模型加载
        :param model:
        :param device:
        :return:
        """
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        device = embedding_device()
        key = (model, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                self.atomic.release()
                if 'bge-' in model:
                    from langchain.embeddings import HuggingFaceBgeEmbeddings
                    if 'zh' in model:
                        # for chinese model
                        query_instruction = "为这个句子生成表示以用于检索相关文章:"
                    elif 'en' in model:
                        # for english model
                        query_instruction = "Represent this sentence for searching relevant passages:"
                    else:
                        # maybe ReRanker or else, just use empty string instead
                        query_instruction = ""
                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                          model_kwargs={'device': device},
                                                          query_instruction=query_instruction)
                    if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                        embeddings.query_instruction = ""
                else:
                    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
                                                       model_kwargs={'device': device})
                item.obj = embeddings
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj


embeddings_pool = EmbeddingsPool(cache_num=1)

@@ -1,7 +1,6 @@
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
from server.knowledge_base.kb_cache.base import *
# from server.utils import load_local_embeddings
from server.utils import get_Embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
@@ -53,13 +52,11 @@ class _FaissPool(CachePool):
    def new_vector_store(
            self,
            kb_name: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
            embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> FAISS:

        # create an empty vector store
        embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
                                                embed_device=embed_device, default_embed_model=embed_model)
        embeddings = get_Embeddings(embed_model=embed_model)
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
        ids = list(vector_store.docstore._dict.keys())
@@ -68,18 +65,11 @@ class _FaissPool(CachePool):

    def new_temp_vector_store(
            self,
            endpoint_host: str,
            endpoint_host_key: str,
            endpoint_host_proxy: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
            embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> FAISS:

        # create an empty vector store
        embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
                                                  endpoint_host_key=endpoint_host_key,
                                                  endpoint_host_proxy=endpoint_host_proxy,
                                                  embed_device=embed_device, default_embed_model=embed_model)
        embeddings = get_Embeddings(embed_model=embed_model)
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
        ids = list(vector_store.docstore._dict.keys())
@@ -102,8 +92,7 @@ class KBFaissPool(_FaissPool):
            kb_name: str,
            vector_name: str = None,
            create: bool = True,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
            embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> ThreadSafeFaiss:
        self.atomic.acquire()
        vector_name = vector_name or embed_model
@@ -118,15 +107,13 @@ class KBFaissPool(_FaissPool):
            vs_path = get_vs_path(kb_name, vector_name)

            if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
                                                        embed_device=embed_device, default_embed_model=embed_model)
                embeddings = get_Embeddings(embed_model=embed_model)
                vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
            elif create:
                # create an empty vector store
                if not os.path.exists(vs_path):
                    os.makedirs(vs_path)
                vector_store = self.new_vector_store(kb_name=kb_name,
                                                     embed_model=embed_model, embed_device=embed_device)
                vector_store = self.new_vector_store(kb_name=kb_name, embed_model=embed_model)
                vector_store.save_local(vs_path)
            else:
                raise RuntimeError(f"knowledge base {kb_name} not exist.")
@@ -148,11 +135,7 @@ class MemoFaissPool(_FaissPool):
    def load_vector_store(
            self,
            kb_name: str,
            endpoint_host: str,
            endpoint_host_key: str,
            endpoint_host_proxy: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
            embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> ThreadSafeFaiss:
        self.atomic.acquire()
        cache = self.get(kb_name)
@@ -163,10 +146,7 @@ class MemoFaissPool(_FaissPool):
            self.atomic.release()
            logger.info(f"loading vector store in '{kb_name}' to memory.")
            # create an empty vector store
            vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
                                                      endpoint_host_key=endpoint_host_key,
                                                      endpoint_host_proxy=endpoint_host_proxy,
                                                      embed_model=embed_model, embed_device=embed_device)
            vector_store = self.new_temp_vector_store(embed_model=embed_model)
            item.obj = vector_store
            item.finish_loading()
        else:

@@ -1,7 +1,7 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                     logger, log_verbose, )
@@ -42,22 +42,6 @@ def search_docs(
    return data


def update_docs_by_id(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
    '''
    按照文档 ID 更新文档内容
    '''
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
    if kb.update_doc_by_ids(docs=docs):
        return BaseResponse(msg=f"文档更新成功")
    else:
        return BaseResponse(msg=f"文档更新失败")


def list_files(
        knowledge_base_name: str
) -> ListResponse:
@@ -230,26 +214,6 @@ def update_info(
    return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})


def update_kb_endpoint(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
):
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy)

    return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成",
                        data={"endpoint_host": endpoint_host,
                              "endpoint_host_key": endpoint_host_key,
                              "endpoint_host_proxy": endpoint_host_proxy})


def update_docs(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
@@ -366,10 +330,7 @@ def recreate_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        embed_model: str = Body(EMBEDDING_MODEL),
        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
@@ -389,9 +350,7 @@ def recreate_vector_store(
    else:
        if kb.exists():
            kb.clear_vs()
        kb.create_kb(endpoint_host=endpoint_host,
                     endpoint_host_key=endpoint_host_key,
                     endpoint_host_proxy=endpoint_host_proxy)
        kb.create_kb()
    files = list_files_from_folder(knowledge_base_name)
    kb_files = [(file, knowledge_base_name) for file in files]
    i = 0

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document

from server.db.repository.knowledge_base_repository import (
    add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
    load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db,
    load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
    add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
@@ -15,7 +15,7 @@ from server.db.repository.knowledge_file_repository import (
)

from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     EMBEDDING_MODEL, KB_INFO)
                     DEFAULT_EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
    get_kb_path, get_doc_path, KnowledgeFile,
    list_kbs_from_folder, list_files_from_folder,
@@ -40,7 +40,7 @@ class KBService(ABC):

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 embed_model: str = DEFAULT_EMBEDDING_MODEL,
                 ):
        self.kb_name = knowledge_base_name
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
@@ -58,20 +58,14 @@ class KBService(ABC):
        '''
        pass

    def create_kb(self,
                  endpoint_host: str = None,
                  endpoint_host_key: str = None,
                  endpoint_host_proxy: str = None):
    def create_kb(self):
        """
        创建知识库
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)

        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model,
                              endpoint_host=endpoint_host,
                              endpoint_host_key=endpoint_host_key,
                              endpoint_host_proxy=endpoint_host_proxy)
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)

        if status:
            self.do_create_kb()
@@ -144,16 +138,6 @@ class KBService(ABC):
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_kb_endpoint(self,
                           endpoint_host: str = None,
                           endpoint_host_key: str = None,
                           endpoint_host_proxy: str = None):
        """
        更新知识库在线api接入点配置
        """
        status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
        """
        使用content中的文件更新向量库
@@ -297,7 +281,7 @@ class KBServiceFactory:
    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    embed_model: str = DEFAULT_EMBEDDING_MODEL,
                    ) -> KBService:
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

@@ -1,13 +1,11 @@
from typing import List
import os
import shutil
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings
from server.utils import get_Embeddings
from elasticsearch import Elasticsearch,BadRequestError
from configs import logger
from configs import kbs_config
@@ -22,7 +20,7 @@ class ESKBService(KBService):
        self.user = kbs_config[self.vs_type()].get("user",'')
        self.password = kbs_config[self.vs_type()].get("password",'')
        self.dims_length = kbs_config[self.vs_type()].get("dims_length",None)
        self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)
        self.embeddings_model = get_Embeddings(self.embed_model)
        try:
            # ES python客户端连接(仅连接)
            if self.user != "" and self.password != "":

@@ -1,7 +1,7 @@
from typing import List

from configs import (
    EMBEDDING_MODEL,
    DEFAULT_EMBEDDING_MODEL,
    KB_ROOT_PATH)

from abc import ABC, abstractmethod
@@ -21,7 +21,7 @@ class KBSummaryService(ABC):

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL
                 embed_model: str = DEFAULT_EMBEDDING_MODEL
                 ):
        self.kb_name = knowledge_base_name
        self.embed_model = embed_model

@@ -1,5 +1,5 @@
from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                     OVERLAP_SIZE,
                     logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
@@ -17,11 +17,8 @@ def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
        file_description: str = Body(''),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        model_name: str = Body(None, description="LLM 模型名称。"),
        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -29,9 +26,6 @@ def recreate_summary_vector_store(
    """
    重建单个知识库文件摘要
    :param max_tokens:
    :param endpoint_host:
    :param endpoint_host_key:
    :param endpoint_host_proxy:
    :param model_name:
    :param temperature:
    :param file_description:
@@ -54,17 +48,11 @@ def recreate_summary_vector_store(
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
@@ -110,20 +98,14 @@ def summary_file_to_vector_store(
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
        file_description: str = Body(''),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
):
|
||||
"""
|
||||
单个知识库根据文件名称摘要
|
||||
:param endpoint_host:
|
||||
:param endpoint_host_key:
|
||||
:param endpoint_host_proxy:
|
||||
:param model_name:
|
||||
:param max_tokens:
|
||||
:param temperature:
|
||||
@ -146,17 +128,11 @@ def summary_file_to_vector_store(
|
||||
kb_summary.create_kb_summary()
|
||||
|
||||
llm = get_ChatOpenAI(
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
@ -194,11 +170,8 @@ def summary_doc_ids_to_vector_store(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_ids: List = Body([], examples=[["uuid"]]),
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
@ -206,9 +179,6 @@ def summary_doc_ids_to_vector_store(
|
||||
"""
|
||||
单个知识库根据doc_ids摘要
|
||||
:param knowledge_base_name:
|
||||
:param endpoint_host:
|
||||
:param endpoint_host_key:
|
||||
:param endpoint_host_proxy:
|
||||
:param doc_ids:
|
||||
:param model_name:
|
||||
:param max_tokens:
|
||||
@ -223,17 +193,11 @@ def summary_doc_ids_to_vector_store(
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
|
||||
else:
|
||||
llm = get_ChatOpenAI(
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from configs import (
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE,
|
||||
logger, log_verbose
|
||||
)
|
||||
@ -86,7 +86,7 @@ def folder2db(
|
||||
kb_names: List[str],
|
||||
mode: Literal["recreate_vs", "update_in_db", "increment"],
|
||||
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from configs import (
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
@ -143,6 +144,7 @@ def get_LoaderClass(file_extension):
|
||||
if file_extension in extensions:
|
||||
return LoaderClass
|
||||
|
||||
|
||||
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||
'''
|
||||
根据loader_name和文件路径或内容返回文档加载器。
|
||||
@ -184,6 +186,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||
return loader
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def make_text_splitter(
|
||||
splitter_name,
|
||||
chunk_size,
|
||||
|
||||
server/localai_embeddings.py (new file, 363 lines)
@ -0,0 +1,363 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
|
||||
import openai
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(embeddings.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(openai.Timeout)
|
||||
| retry_if_exception_type(openai.APIError)
|
||||
| retry_if_exception_type(openai.APIConnectionError)
|
||||
| retry_if_exception_type(openai.RateLimitError)
|
||||
| retry_if_exception_type(openai.InternalServerError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
|
||||
import openai
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
async_retrying = AsyncRetrying(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(embeddings.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(openai.Timeout)
|
||||
| retry_if_exception_type(openai.APIError)
|
||||
| retry_if_exception_type(openai.APIConnectionError)
|
||||
| retry_if_exception_type(openai.RateLimitError)
|
||||
| retry_if_exception_type(openai.InternalServerError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
def wrap(func: Callable) -> Callable:
|
||||
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
|
||||
async for _ in async_retrying:
|
||||
return await func(*args, **kwargs)
|
||||
raise AssertionError("this is unreachable")
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
|
||||
def _check_response(response: dict) -> dict:
|
||||
if any([len(d.embedding) == 1 for d in response.data]):
|
||||
import openai
|
||||
|
||||
raise openai.APIError("LocalAI API returned an empty embedding")
|
||||
return response
|
||||
|
||||
|
||||
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
retry_decorator = _create_retry_decorator(embeddings)
|
||||
|
||||
@retry_decorator
|
||||
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = embeddings.client.create(**kwargs)
|
||||
return _check_response(response)
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
|
||||
@_async_retry_decorator(embeddings)
|
||||
async def _async_embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = await embeddings.async_client.acreate(**kwargs)
|
||||
return _check_response(response)
|
||||
|
||||
return await _async_embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||
"""LocalAI embedding models.
|
||||
|
||||
Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
|
||||
uses the ``openai`` Python package's ``openai.Embedding`` as its client.
|
||||
Thus, you should have the ``openai`` python package installed, and set
the environment variable ``OPENAI_API_KEY`` to a random string: LocalAI does
not validate it, but the client library requires one.
|
||||
You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
|
||||
service endpoint.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import LocalAIEmbeddings
|
||||
openai = LocalAIEmbeddings(
|
||||
openai_api_key="random-string",
|
||||
openai_api_base="http://localhost:8080"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
model: str = "text-embedding-ada-002"
|
||||
deployment: str = model
|
||||
openai_api_version: Optional[str] = None
|
||||
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
# to support explicit proxy for LocalAI
|
||||
openai_proxy: Optional[str] = None
|
||||
embedding_ctx_length: int = 8191
|
||||
"""The maximum number of tokens to embed at once."""
|
||||
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set()
|
||||
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
|
||||
chunk_size: int = 1000
|
||||
"""Maximum number of texts to embed in each batch"""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout in seconds for the LocalAI request."""
|
||||
headers: Any = None
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
default="",
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
|
||||
default_api_version = ""
|
||||
values["openai_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
default=default_api_version,
|
||||
)
|
||||
values["openai_organization"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_organization",
|
||||
"OPENAI_ORGANIZATION",
|
||||
default="",
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
if is_openai_v1():
|
||||
client_params = {
|
||||
"api_key": values["openai_api_key"],
|
||||
"organization": values["openai_organization"],
|
||||
"base_url": values["openai_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
}
|
||||
|
||||
if not values.get("client"):
|
||||
values["client"] = openai.OpenAI(**client_params).embeddings
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params
|
||||
).embeddings
|
||||
elif not values.get("client"):
|
||||
values["client"] = openai.Embedding
|
||||
else:
|
||||
pass
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict:
|
||||
openai_args = {
|
||||
"model": self.model,
|
||||
"timeout": self.request_timeout,
|
||||
"extra_headers": self.headers,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.openai_proxy:
|
||||
import openai
|
||||
|
||||
openai.proxy = {
|
||||
"http": self.openai_proxy,
|
||||
"https": self.openai_proxy,
|
||||
} # type: ignore[assignment] # noqa: E501
|
||||
return openai_args
|
||||
|
||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
).data[0].embedding
|
||||
|
||||
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
await async_embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
)
|
||||
).data[0].embedding
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
) -> List[List[float]]:
|
||||
"""Call out to LocalAI's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# call _embedding_func for each text
|
||||
return [self._embedding_func(text, engine=self.deployment) for text in texts]
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
) -> List[List[float]]:
|
||||
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = await self._aembedding_func(text, engine=self.deployment)
|
||||
embeddings.append(response)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embedding = self._embedding_func(text, engine=self.deployment)
|
||||
return embedding
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embedding = await self._aembedding_func(text, engine=self.deployment)
|
||||
return embedding
|
||||
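A minimal usage sketch for the vendored class, assuming an OpenAI-compatible embedding endpoint at a hypothetical local address (the key is only a placeholder, as the docstring notes; not part of the diff):

from server.localai_embeddings import LocalAIEmbeddings

embeddings = LocalAIEmbeddings(
    model="bge-large-zh-v1.5",
    openai_api_key="EMPTY",                      # placeholder, the local server does not validate it
    openai_api_base="http://127.0.0.1:9997/v1",  # hypothetical OpenAI-compatible endpoint
)
doc_vectors = embeddings.embed_documents(["文本一", "文本二"])
query_vector = embeddings.embed_query("查询文本")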
@ -109,14 +109,13 @@ if __name__ == "__main__":
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
MODEL_PATH)
|
||||
from server.utils import embedding_device
|
||||
|
||||
if USE_RERANKER:
|
||||
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
|
||||
print("-----------------model path------------------")
|
||||
print(reranker_model_path)
|
||||
reranker_model = LangchainReranker(top_n=3,
|
||||
device=embedding_device(),
|
||||
device="cpu",
|
||||
max_length=RERANKER_MAX_LENGTH,
|
||||
model_name_or_path=reranker_model_path
|
||||
)
|
||||
|
||||
server/utils.py (269 changed lines)
@ -1,32 +1,28 @@
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
|
||||
HTTPX_DEFAULT_TIMEOUT)
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_openai.llms import OpenAI
|
||||
import httpx
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
Optional,
|
||||
Callable,
|
||||
Generator,
|
||||
Dict,
|
||||
List,
|
||||
Any,
|
||||
Awaitable,
|
||||
Union,
|
||||
Tuple
|
||||
Tuple,
|
||||
Literal,
|
||||
)
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL
|
||||
from server.minx_chat_openai import MinxChatOpenAI
|
||||
|
||||
|
||||
@ -44,10 +40,66 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
event.set()
|
||||
|
||||
|
||||
def get_config_models(
        model_name: str = None,
        model_type: Literal["llm", "embed", "image", "multimodal"] = None,
        platform_name: str = None,
) -> Dict[str, Dict]:
    '''
    获取配置的模型列表,返回值为:
    {model_name: {
        "platform_name": xx,
        "platform_type": xx,
        "model_type": xx,
        "model_name": xx,
        "api_base_url": xx,
        "api_key": xx,
        "api_proxy": xx,
    }}
    '''
    import importlib
    from configs import model_config
    importlib.reload(model_config)

    result = {}
    for m in model_config.MODEL_PLATFORMS:
        if platform_name is not None and platform_name != m.get("platform_name"):
            continue
        if model_type is not None and f"{model_type}_models" not in m:
            continue

        if model_type is None:
            model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"]
        else:
            model_types = [f"{model_type}_models"]

        for m_type in model_types:
            for m_name in m.get(m_type, []):
                if model_name is None or model_name == m_name:
                    result[m_name] = {
                        "platform_name": m.get("platform_name"),
                        "platform_type": m.get("platform_type"),
                        "model_type": m_type.split("_")[0],
                        "model_name": m_name,
                        "api_base_url": m.get("api_base_url"),
                        "api_key": m.get("api_key"),
                        "api_proxy": m.get("api_proxy"),
                    }
    return result


def get_model_info(model_name: str, platform_name: str = None) -> Dict:
    '''
    获取配置的模型信息,主要是 api_base_url, api_key
    '''
    result = get_config_models(model_name=model_name, platform_name=platform_name)
    if len(result) > 0:
        return list(result.values())[0]
    else:
        return {}
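A short usage sketch for these helpers, assuming MODEL_PLATFORMS in model_config declares at least one platform with llm_models and embed_models (model names below are illustrative):

# all embedding models configured across platforms
embed_models = list(get_config_models(model_type="embed"))

# connection info (api_base_url / api_key / api_proxy) for a single model
info = get_model_info("chatglm3-6b")
base_url, api_key = info.get("api_base_url"), info.get("api_key")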
|
||||
|
||||
|
||||
def get_ChatOpenAI(
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
@ -56,29 +108,23 @@ def get_ChatOpenAI(
|
||||
verbose: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> ChatOpenAI:
|
||||
config = get_model_worker_config(model_name)
|
||||
if model_name == "openai-api":
|
||||
model_name = config.get("model_name")
|
||||
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
|
||||
model_info = get_model_info(model_name)
|
||||
model = ChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
|
||||
openai_api_base=endpoint_host if endpoint_host else "None",
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
|
||||
openai_api_key=model_info.get("api_key"),
|
||||
openai_api_base=model_info.get("api_base_url"),
|
||||
openai_proxy=model_info.get("api_proxy"),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
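Because the endpoint is now resolved through get_model_info, callers only pass the model name and sampling parameters. A hedged sketch, assuming the endpoint_host* parameters are dropped from the final signature together with the manual endpoint configuration:

llm = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.7, max_tokens=None)
answer = llm.invoke("你好")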
|
||||
|
||||
def get_OpenAI(
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
@ -89,22 +135,40 @@ def get_OpenAI(
|
||||
**kwargs: Any,
|
||||
) -> OpenAI:
|
||||
# TODO: 从API获取模型信息
|
||||
model_info = get_model_info(model_name)
|
||||
model = OpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
|
||||
openai_api_base=endpoint_host if endpoint_host else "None",
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
|
||||
openai_api_key=model_info.get("api_key"),
|
||||
openai_api_base=model_info.get("api_base_url"),
|
||||
openai_proxy=model_info.get("api_proxy"),
|
||||
echo=echo,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings:
    from langchain_community.embeddings.openai import OpenAIEmbeddings
    from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154

    model_info = get_model_info(model_name=embed_model)
    params = {
        "model": embed_model,
        "base_url": model_info.get("api_base_url"),
        "api_key": model_info.get("api_key"),
        "openai_proxy": model_info.get("api_proxy"),
    }
    if model_info.get("platform_type") == "openai":
        return OpenAIEmbeddings(**params)
    else:
        return LocalAIEmbeddings(**params)
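Callers such as ESKBService above only pass the model name; a minimal sketch (illustrative, not part of the diff):

embeddings = get_Embeddings()                       # DEFAULT_EMBEDDING_MODEL, e.g. bge-large-zh-v1.5
vectors = embeddings.embed_documents(["你好", "hello"])
query_vector = embeddings.embed_query("检索问题")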
|
||||
|
||||
|
||||
class MsgType:
|
||||
TEXT = 1
|
||||
IMAGE = 2
|
||||
@ -113,9 +177,9 @@ class MsgType:
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
data: Any = pydantic.Field(None, description="API data")
|
||||
code: int = Field(200, description="API status code")
|
||||
msg: str = Field("success", description="API status message")
|
||||
data: Any = Field(None, description="API data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -127,7 +191,7 @@ class BaseResponse(BaseModel):
|
||||
|
||||
|
||||
class ListResponse(BaseResponse):
|
||||
data: List[str] = pydantic.Field(..., description="List of names")
|
||||
data: List[str] = Field(..., description="List of names")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -140,10 +204,10 @@ class ListResponse(BaseResponse):
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
question: str = pydantic.Field(..., description="Question text")
|
||||
response: str = pydantic.Field(..., description="Response text")
|
||||
history: List[List[str]] = pydantic.Field(..., description="History text")
|
||||
source_documents: List[str] = pydantic.Field(
|
||||
question: str = Field(..., description="Question text")
|
||||
response: str = Field(..., description="Response text")
|
||||
history: List[List[str]] = Field(..., description="History text")
|
||||
source_documents: List[str] = Field(
|
||||
..., description="List of source documents and their scores"
|
||||
)
|
||||
|
||||
@ -310,39 +374,40 @@ def MakeFastAPIOffline(
|
||||
|
||||
|
||||
# 从model_config中获取模型信息
|
||||
# TODO: 移出模型加载后,这些功能需要删除或改变实现
|
||||
|
||||
def list_embed_models() -> List[str]:
|
||||
'''
|
||||
get names of configured embedding models
|
||||
'''
|
||||
return list(MODEL_PATH["embed_model"])
|
||||
# def list_embed_models() -> List[str]:
|
||||
# '''
|
||||
# get names of configured embedding models
|
||||
# '''
|
||||
# return list(MODEL_PATH["embed_model"])
|
||||
|
||||
|
||||
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
if type in MODEL_PATH:
|
||||
paths = MODEL_PATH[type]
|
||||
else:
|
||||
paths = {}
|
||||
for v in MODEL_PATH.values():
|
||||
paths.update(v)
|
||||
# def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
# if type in MODEL_PATH:
|
||||
# paths = MODEL_PATH[type]
|
||||
# else:
|
||||
# paths = {}
|
||||
# for v in MODEL_PATH.values():
|
||||
# paths.update(v)
|
||||
|
||||
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
|
||||
path = Path(path_str)
|
||||
if path.is_dir(): # 任意绝对路径
|
||||
return str(path)
|
||||
# if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
|
||||
# path = Path(path_str)
|
||||
# if path.is_dir(): # 任意绝对路径
|
||||
# return str(path)
|
||||
|
||||
root_path = Path(MODEL_ROOT_PATH)
|
||||
if root_path.is_dir():
|
||||
path = root_path / model_name
|
||||
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
|
||||
return str(path)
|
||||
path = root_path / path_str
|
||||
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
|
||||
return str(path)
|
||||
path = root_path / path_str.split("/")[-1]
|
||||
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
|
||||
return str(path)
|
||||
return path_str # THUDM/chatglm06b
|
||||
# root_path = Path(MODEL_ROOT_PATH)
|
||||
# if root_path.is_dir():
|
||||
# path = root_path / model_name
|
||||
# if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
|
||||
# return str(path)
|
||||
# path = root_path / path_str
|
||||
# if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
|
||||
# return str(path)
|
||||
# path = root_path / path_str.split("/")[-1]
|
||||
# if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
|
||||
# return str(path)
|
||||
# return path_str # THUDM/chatglm06b
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
@ -429,37 +494,6 @@ def set_httpx_config(
|
||||
urllib.request.getproxies = _get_proxies
|
||||
|
||||
|
||||
def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.get_device_properties(0):
|
||||
return "xpu"
|
||||
except:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
|
||||
device = device or LLM_DEVICE
|
||||
# if device.isdigit():
|
||||
# return "cuda:" + device
|
||||
if device not in ["cuda", "mps", "cpu", "xpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
|
||||
device = device or EMBEDDING_DEVICE
|
||||
if device not in ["cuda", "mps", "cpu", "xpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def run_in_thread_pool(
|
||||
func: Callable,
|
||||
params: List[Dict] = [],
|
||||
@ -546,56 +580,19 @@ def get_server_configs() -> Dict:
|
||||
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||
|
||||
|
||||
def list_online_embed_models(
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str
|
||||
) -> List[str]:
|
||||
ret = []
|
||||
# TODO: 从在线API获取支持的模型列表
|
||||
client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {endpoint_host_key}",
|
||||
}
|
||||
resp = client.get("/models", headers=headers)
|
||||
if resp.status_code == 200:
|
||||
models = resp.json().get("data", [])
|
||||
for model in models:
|
||||
if "embedding" in model.get("id", None):
|
||||
ret.append(model.get("id"))
|
||||
|
||||
except Exception as e:
|
||||
msg = f"获取在线Embeddings模型列表失败:{e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
finally:
|
||||
client.close()
|
||||
return ret
|
||||
|
||||
|
||||
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
||||
'''
|
||||
从缓存中本地Embeddings模型加载,可以避免多线程时竞争加载。
|
||||
'''
|
||||
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||
from configs import EMBEDDING_MODEL
|
||||
|
||||
model = model or EMBEDDING_MODEL
|
||||
return embeddings_pool.load_embeddings(model=model, device=device)
|
||||
|
||||
|
||||
def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
||||
'''
|
||||
创建一个临时目录,返回(路径,文件夹名称)
|
||||
'''
|
||||
from configs.basic_config import BASE_TEMP_DIR
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||
path = os.path.join(BASE_TEMP_DIR, id)
|
||||
if os.path.isdir(path):
|
||||
return path, id
|
||||
|
||||
path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
|
||||
return path, os.path.basename(path)
|
||||
id = uuid.uuid4().hex
|
||||
path = os.path.join(BASE_TEMP_DIR, id)
|
||||
os.mkdir(path)
|
||||
return path, id
|
||||
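A usage sketch of the id-based reuse that replaces the old tempfile.mkdtemp behaviour:

path, temp_id = get_temp_dir()        # creates a new directory under BASE_TEMP_DIR, named by uuid4().hex
same_path, _ = get_temp_dir(temp_id)  # passing the id back returns the existing directory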
|
||||
startup.py (45 changed lines)
@ -1,12 +1,11 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Process
|
||||
from datetime import datetime
|
||||
from pprint import pprint
|
||||
from langchain_core._api import deprecated
|
||||
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
try:
|
||||
@ -17,38 +16,29 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs import (
|
||||
LOG_PATH,
|
||||
log_verbose,
|
||||
logger,
|
||||
LLM_MODEL_CONFIG,
|
||||
EMBEDDING_MODEL,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
TEXT_SPLITTER_NAME,
|
||||
API_SERVER,
|
||||
WEBUI_SERVER,
|
||||
HTTPX_DEFAULT_TIMEOUT,
|
||||
)
|
||||
from server.utils import (FastAPI, embedding_device)
|
||||
from server.utils import FastAPI
|
||||
from server.knowledge_base.migrate import create_tables
|
||||
import argparse
|
||||
from typing import List, Dict
|
||||
from configs import VERSION
|
||||
|
||||
all_model_names = set()
|
||||
for model_category in LLM_MODEL_CONFIG.values():
|
||||
for model_name in model_category.keys():
|
||||
if model_name not in all_model_names:
|
||||
all_model_names.add(model_name)
|
||||
|
||||
all_model_names_list = list(all_model_names)
|
||||
|
||||
|
||||
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
if started_event is not None:
|
||||
started_event.set()
|
||||
yield
|
||||
app.router.lifespan_context = lifespan
|
||||
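The @app.on_event("startup") hook is replaced by FastAPI's lifespan context manager. For reference, the equivalent standalone pattern looks like this (illustrative only, not part of the diff):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup work, e.g. signal that the server is ready
    yield
    # shutdown work

app = FastAPI(lifespan=lifespan)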
|
||||
|
||||
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
|
||||
@ -159,7 +149,7 @@ def dump_server_info(after_start=False, args=None):
|
||||
|
||||
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
|
||||
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||
print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}")
|
||||
|
||||
if after_start:
|
||||
print("\n")
|
||||
@ -232,13 +222,13 @@ async def start_main_server():
|
||||
return len(processes)
|
||||
|
||||
loom_started = manager.Event()
|
||||
process = Process(
|
||||
target=run_loom,
|
||||
name=f"run_loom Server",
|
||||
kwargs=dict(started_event=loom_started),
|
||||
daemon=True,
|
||||
)
|
||||
processes["run_loom"] = process
|
||||
# process = Process(
|
||||
# target=run_loom,
|
||||
# name=f"run_loom Server",
|
||||
# kwargs=dict(started_event=loom_started),
|
||||
# daemon=True,
|
||||
# )
|
||||
# processes["run_loom"] = process
|
||||
api_started = manager.Event()
|
||||
if args.api:
|
||||
process = Process(
|
||||
@ -283,7 +273,6 @@ async def start_main_server():
|
||||
|
||||
# 等待所有进程退出
|
||||
if p := processes.get("webui"):
|
||||
|
||||
p.join()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@ -306,9 +295,7 @@ async def start_main_server():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
create_tables()
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
loop = asyncio.get_event_loop()
|
||||
else:
|
||||
|
||||
webui.py (31 changed lines)
@ -1,7 +1,7 @@
|
||||
import streamlit as st
|
||||
|
||||
from webui_pages.loom_view_client import update_store
|
||||
from webui_pages.openai_plugins import openai_plugins_page
|
||||
# from webui_pages.loom_view_client import update_store
|
||||
# from webui_pages.openai_plugins import openai_plugins_page
|
||||
from webui_pages.utils import *
|
||||
from streamlit_option_menu import option_menu
|
||||
from webui_pages.dialogue.dialogue import dialogue_page, chat_box
|
||||
@ -12,9 +12,9 @@ from configs import VERSION
|
||||
from server.utils import api_address
|
||||
|
||||
|
||||
def on_change(key):
|
||||
if key:
|
||||
update_store()
|
||||
# def on_change(key):
|
||||
# if key:
|
||||
# update_store()
|
||||
|
||||
|
||||
api = ApiRequest(base_url=api_address())
|
||||
@ -59,18 +59,18 @@ if __name__ == "__main__":
|
||||
"icon": "hdd-stack",
|
||||
"func": knowledge_base_page,
|
||||
},
|
||||
"模型服务": {
|
||||
"icon": "hdd-stack",
|
||||
"func": openai_plugins_page,
|
||||
},
|
||||
# "模型服务": {
|
||||
# "icon": "hdd-stack",
|
||||
# "func": openai_plugins_page,
|
||||
# },
|
||||
}
|
||||
# 更新状态
|
||||
if "status" not in st.session_state \
|
||||
or "run_plugins_list" not in st.session_state \
|
||||
or "launch_subscribe_info" not in st.session_state \
|
||||
or "list_running_models" not in st.session_state \
|
||||
or "model_config" not in st.session_state:
|
||||
update_store()
|
||||
# if "status" not in st.session_state \
|
||||
# or "run_plugins_list" not in st.session_state \
|
||||
# or "launch_subscribe_info" not in st.session_state \
|
||||
# or "list_running_models" not in st.session_state \
|
||||
# or "model_config" not in st.session_state:
|
||||
# update_store()
|
||||
|
||||
with st.sidebar:
|
||||
st.image(
|
||||
@ -95,7 +95,6 @@ if __name__ == "__main__":
|
||||
icons=icons,
|
||||
# menu_icon="chat-quote",
|
||||
default_index=default_index,
|
||||
on_change=on_change,
|
||||
)
|
||||
|
||||
if selected_page in pages:
|
||||
|
||||
@ -4,8 +4,8 @@ import streamlit as st
|
||||
from streamlit_antd_components.utils import ParseItems
|
||||
|
||||
from webui_pages.dialogue.utils import process_files
|
||||
from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
|
||||
get_select_model_endpoint
|
||||
# from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
|
||||
# get_select_model_endpoint
|
||||
from webui_pages.utils import *
|
||||
from streamlit_chatbox import *
|
||||
from streamlit_modal import Modal
|
||||
@ -13,9 +13,9 @@ from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY)
|
||||
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG)
|
||||
from server.callback_handler.agent_callback_handler import AgentStatus
|
||||
from server.utils import MsgType
|
||||
from server.utils import MsgType, get_config_models
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
@ -111,8 +111,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
st.session_state.setdefault("conversation_ids", {})
|
||||
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
|
||||
st.session_state.setdefault("file_chat_id", None)
|
||||
st.session_state.setdefault("select_plugins_info", None)
|
||||
st.session_state.setdefault("select_model_worker", None)
|
||||
|
||||
# 弹出自定义命令帮助信息
|
||||
modal = Modal("自定义命令", key="cmd_help", max_width="500")
|
||||
@ -131,18 +129,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
chat_box.use_chat_name(conversation_name)
|
||||
conversation_id = st.session_state["conversation_ids"][conversation_name]
|
||||
|
||||
with st.expander("模型选择"):
|
||||
plugins_menu = build_providers_model_plugins_name()
|
||||
|
||||
items, _ = ParseItems(plugins_menu).multi()
|
||||
|
||||
if len(plugins_menu) > 0:
|
||||
|
||||
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
|
||||
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
|
||||
set_llm_select(plugins_info, llm_model_worker)
|
||||
else:
|
||||
st.info("没有可用的插件")
|
||||
platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
|
||||
platform = st.selectbox("选择模型平台", platforms, 1)
|
||||
llm_models = list(get_config_models(model_type="llm", platform_name=platform))
|
||||
llm_model = st.selectbox("选择LLM模型", llm_models)
|
||||
|
||||
# 传入后端的内容
|
||||
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
|
||||
@ -174,10 +164,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
if is_selected:
|
||||
selected_tool_configs[tool] = TOOL_CONFIG[tool]
|
||||
|
||||
llm_model = None
|
||||
if st.session_state["select_model_worker"] is not None:
|
||||
llm_model = st.session_state["select_model_worker"]['label']
|
||||
|
||||
if llm_model is not None:
|
||||
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
|
||||
|
||||
@ -200,23 +186,23 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
chat_box.output_messages()
|
||||
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
|
||||
|
||||
def on_feedback(
|
||||
feedback,
|
||||
message_id: str = "",
|
||||
history_index: int = -1,
|
||||
):
|
||||
# def on_feedback(
|
||||
# feedback,
|
||||
# message_id: str = "",
|
||||
# history_index: int = -1,
|
||||
# ):
|
||||
|
||||
reason = feedback["text"]
|
||||
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
|
||||
api.chat_feedback(message_id=message_id,
|
||||
score=score_int,
|
||||
reason=reason)
|
||||
st.session_state["need_rerun"] = True
|
||||
# reason = feedback["text"]
|
||||
# score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
|
||||
# api.chat_feedback(message_id=message_id,
|
||||
# score=score_int,
|
||||
# reason=reason)
|
||||
# st.session_state["need_rerun"] = True
|
||||
|
||||
feedback_kwargs = {
|
||||
"feedback_type": "thumbs",
|
||||
"optional_text_label": "欢迎反馈您打分的理由",
|
||||
}
|
||||
# feedback_kwargs = {
|
||||
# "feedback_type": "thumbs",
|
||||
# "optional_text_label": "欢迎反馈您打分的理由",
|
||||
# }
|
||||
|
||||
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
|
||||
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
|
||||
@ -244,17 +230,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
text_action = ""
|
||||
element_index = 0
|
||||
|
||||
openai_config = {}
|
||||
endpoint_host, select_model_name = get_select_model_endpoint()
|
||||
openai_config["endpoint_host"] = endpoint_host
|
||||
openai_config["model_name"] = select_model_name
|
||||
openai_config["endpoint_host_key"] = OPENAI_KEY
|
||||
openai_config["endpoint_host_proxy"] = OPENAI_PROXY
|
||||
for d in api.chat_chat(query=prompt,
|
||||
metadata=files_upload,
|
||||
history=history,
|
||||
chat_model_config=chat_model_config,
|
||||
openai_config=openai_config,
|
||||
conversation_id=conversation_id,
|
||||
tool_config=selected_tool_configs,
|
||||
):
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import streamlit as st
|
||||
from streamlit_antd_components.utils import ParseItems
|
||||
|
||||
from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
|
||||
set_llm_select, set_embed_select, get_select_embed_endpoint
|
||||
# from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
|
||||
# set_llm_select, set_embed_select, get_select_embed_endpoint
|
||||
from webui_pages.utils import *
|
||||
from st_aggrid import AgGrid, JsCode
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
@ -10,10 +10,9 @@ import pandas as pd
|
||||
from server.knowledge_base.utils import get_file_path, LOADER_DICT
|
||||
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
|
||||
from typing import Literal, Dict, Tuple
|
||||
from configs import (kbs_config,
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY)
|
||||
from server.utils import list_embed_models
|
||||
from configs import (kbs_config, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
from server.utils import get_config_models
|
||||
|
||||
import streamlit_antd_components as sac
|
||||
import os
|
||||
@ -116,25 +115,11 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
|
||||
col1, _ = st.columns([3, 1])
|
||||
with col1:
|
||||
col1.text("Embedding 模型")
|
||||
plugins_menu = build_providers_embedding_plugins_name()
|
||||
|
||||
embed_models = list_embed_models()
|
||||
menu_item_children = []
|
||||
for model in embed_models:
|
||||
menu_item_children.append(sac.MenuItem(model, description=model))
|
||||
|
||||
plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
|
||||
|
||||
items, _ = ParseItems(plugins_menu).multi()
|
||||
|
||||
if len(plugins_menu) > 0:
|
||||
|
||||
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
|
||||
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
|
||||
set_embed_select(plugins_info, llm_model_worker)
|
||||
else:
|
||||
st.info("没有可用的插件")
|
||||
embed_models = list(get_config_models(model_type="embed"))
|
||||
index = 0
|
||||
if DEFAULT_EMBEDDING_MODEL in embed_models:
|
||||
index = embed_models.index(DEFAULT_EMBEDDING_MODEL)
|
||||
embed_model = st.selectbox("Embeddings模型", embed_models, index)
|
||||
|
||||
submit_create_kb = st.form_submit_button(
|
||||
"新建",
|
||||
@ -143,23 +128,17 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
)
|
||||
|
||||
if submit_create_kb:
|
||||
|
||||
endpoint_host, select_embed_model_name = get_select_embed_endpoint()
|
||||
if not kb_name or not kb_name.strip():
|
||||
st.error(f"知识库名称不能为空!")
|
||||
elif kb_name in kb_list:
|
||||
st.error(f"名为 {kb_name} 的知识库已经存在!")
|
||||
elif select_embed_model_name is None:
|
||||
elif embed_model is None:
|
||||
st.error(f"请选择Embedding模型!")
|
||||
else:
|
||||
|
||||
ret = api.create_knowledge_base(
|
||||
knowledge_base_name=kb_name,
|
||||
vector_store_type=vs_type,
|
||||
embed_model=select_embed_model_name,
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=OPENAI_KEY,
|
||||
endpoint_host_proxy=OPENAI_PROXY,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
@ -169,9 +148,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
elif selected_kb:
|
||||
kb = selected_kb
|
||||
st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']
|
||||
st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host']
|
||||
st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key']
|
||||
st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy']
|
||||
# 上传文件
|
||||
files = st.file_uploader("上传知识文件:",
|
||||
[i for ls in LOADER_DICT.values() for i in ls],
|
||||
@ -185,37 +161,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
st.session_state["selected_kb_info"] = kb_info
|
||||
api.update_kb_info(kb, kb_info)
|
||||
|
||||
if st.session_state["kb_endpoint_host"] is not None:
|
||||
with st.expander(
|
||||
"在线api接入点配置",
|
||||
expanded=True,
|
||||
):
|
||||
endpoint_host = st.text_input(
|
||||
"接入点地址",
|
||||
placeholder="接入点地址",
|
||||
key="endpoint_host",
|
||||
value=st.session_state["kb_endpoint_host"],
|
||||
)
|
||||
endpoint_host_key = st.text_input(
|
||||
"接入点key",
|
||||
placeholder="接入点key",
|
||||
key="endpoint_host_key",
|
||||
value=st.session_state["kb_endpoint_host_key"],
|
||||
)
|
||||
endpoint_host_proxy = st.text_input(
|
||||
"接入点代理地址",
|
||||
placeholder="接入点代理地址",
|
||||
key="endpoint_host_proxy",
|
||||
value=st.session_state["kb_endpoint_host_proxy"],
|
||||
)
|
||||
if endpoint_host != st.session_state["kb_endpoint_host"] \
|
||||
or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \
|
||||
or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]:
|
||||
st.session_state["kb_endpoint_host"] = endpoint_host
|
||||
st.session_state["kb_endpoint_host_key"] = endpoint_host_key
|
||||
st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy
|
||||
api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy)
|
||||
|
||||
# with st.sidebar:
|
||||
with st.expander(
|
||||
"文件处理配置",
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from configs import (
|
||||
EMBEDDING_MODEL,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_VS_TYPE,
|
||||
LLM_MODEL_CONFIG,
|
||||
SCORE_THRESHOLD,
|
||||
@ -266,7 +266,6 @@ class ApiRequest:
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
chat_model_config: Dict = None,
|
||||
openai_config: Dict = None,
|
||||
tool_config: Dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -281,7 +280,6 @@ class ApiRequest:
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"chat_model_config": chat_model_config,
|
||||
"openai_config": openai_config,
|
||||
"tool_config": tool_config,
|
||||
}
|
||||
|
||||
@ -381,10 +379,7 @@ class ApiRequest:
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
vector_store_type: str = DEFAULT_VS_TYPE,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
endpoint_host: str = None,
|
||||
endpoint_host_key: str = None,
|
||||
endpoint_host_proxy: str = None
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/create_knowledge_base接口
|
||||
@ -393,9 +388,6 @@ class ApiRequest:
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"vector_store_type": vector_store_type,
|
||||
"embed_model": embed_model,
|
||||
"endpoint_host": endpoint_host,
|
||||
"endpoint_host_key": endpoint_host_key,
|
||||
"endpoint_host_proxy": endpoint_host_proxy,
|
||||
}
|
||||
|
||||
response = self.post(
|
||||
@ -459,24 +451,6 @@ class ApiRequest:
|
||||
)
|
||||
return self._get_response_value(response, as_json=True)
|
||||
|
||||
def update_docs_by_id(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
docs: Dict[str, Dict],
|
||||
) -> bool:
|
||||
'''
|
||||
对应api.py/knowledge_base/update_docs_by_id接口
|
||||
'''
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"docs": docs,
|
||||
}
|
||||
response = self.post(
|
||||
"/knowledge_base/update_docs_by_id",
|
||||
json=data
|
||||
)
|
||||
return self._get_response_value(response)
|
||||
|
||||
def upload_kb_docs(
|
||||
self,
|
||||
files: List[Union[str, Path, bytes]],
|
||||
@ -562,26 +536,6 @@ class ApiRequest:
|
||||
)
|
||||
return self._get_response_value(response, as_json=True)
|
||||
|
||||
def update_kb_endpoint(self,
|
||||
knowledge_base_name,
|
||||
endpoint_host: str = None,
|
||||
endpoint_host_key: str = None,
|
||||
endpoint_host_proxy: str = None):
|
||||
'''
|
||||
对应api.py/knowledge_base/update_info接口
|
||||
'''
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"endpoint_host": endpoint_host,
|
||||
"endpoint_host_key": endpoint_host_key,
|
||||
"endpoint_host_proxy": endpoint_host_proxy,
|
||||
}
|
||||
|
||||
response = self.post(
|
||||
"/knowledge_base/update_kb_endpoint",
|
||||
json=data,
|
||||
)
|
||||
return self._get_response_value(response, as_json=True)
|
||||
def update_kb_docs(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
@ -621,7 +575,7 @@ class ApiRequest:
|
||||
knowledge_base_name: str,
|
||||
allow_empty_kb: bool = True,
|
||||
vs_type: str = DEFAULT_VS_TYPE,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||
@ -650,7 +604,7 @@ class ApiRequest:
|
||||
def embed_texts(
|
||||
self,
|
||||
texts: List[str],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
to_query: bool = False,
|
||||
) -> List[List[float]]:
|
||||
'''
|
||||
|
||||