Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 22:33:24 +08:00
Rework the model configuration: all models are now accessed through an OpenAI-compatible framework API, and chatchat itself no longer loads any model.
Switch the Embeddings model to the framework API as well instead of loading it manually, and delete the custom Embeddings keyword code. Update the dependency files to drop heavy dependencies such as torch and transformers. Temporarily remove the loom integration. Follow-ups: 1. tidy up the directory structure; 2. check whether any 0.2.10 content was overwritten during the merge.
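(For orientation only, not part of the commit: a minimal sketch of what "OpenAI-compatible access" looks like from the client side. It assumes a local xinference instance serving the models registered in the new MODEL_PLATFORMS config below — base URL http://127.0.0.1:9997/v1, api_key "EMPTY" — and uses the openai package the project already depends on.)

    from openai import OpenAI

    # Point the standard OpenAI client at whatever platform serves the models.
    client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="EMPTY")

    # Chat completion goes through the platform instead of loading the model in-process.
    reply = client.chat.completions.create(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": "你好"}],
    )
    print(reply.choices[0].message.content)

    # Embeddings are requested the same way; chatchat no longer loads the embedding model itself.
    vectors = client.embeddings.create(model="bge-large-zh-v1.5", input=["hello", "world"])
    print(len(vectors.data[0].embedding))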
This commit is contained in:
parent 988a0e6ad2
commit 5d422ca9a1
@@ -1,8 +1,8 @@
 import logging
 import os
+from pathlib import Path
+
 import langchain
-import tempfile
-import shutil
 
 
 # 是否显示详细日志
@@ -11,6 +11,16 @@ langchain.verbose = False
 
 # 通常情况下不需要更改以下内容
 
+# 用户数据根目录
+DATA_PATH = (Path(__file__).absolute().parent.parent)  # / "data")
+if not os.path.exists(DATA_PATH):
+    os.mkdir(DATA_PATH)
+
+# nltk 模型存储路径
+NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
+import nltk
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+
 # 日志格式
 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
 logger = logging.getLogger()
@@ -19,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT)
 
 
 # 日志存储路径
-LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
+LOG_PATH = os.path.join(DATA_PATH, "logs")
 if not os.path.exists(LOG_PATH):
     os.mkdir(LOG_PATH)
 
 # 模型生成内容(图片、视频、音频等)保存位置
-MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
+MEDIA_PATH = os.path.join(DATA_PATH, "media")
 if not os.path.exists(MEDIA_PATH):
     os.mkdir(MEDIA_PATH)
     os.mkdir(os.path.join(MEDIA_PATH, "image"))
@@ -32,9 +42,6 @@ if not os.path.exists(MEDIA_PATH):
     os.mkdir(os.path.join(MEDIA_PATH, "video"))
 
 # 临时文件目录,主要用于文件对话
-BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
-if os.path.isdir(BASE_TEMP_DIR):
-    shutil.rmtree(BASE_TEMP_DIR)
-os.makedirs(BASE_TEMP_DIR, exist_ok=True)
-
-MEDIA_PATH = None
+BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
+if not os.path.exists(BASE_TEMP_DIR):
+    os.mkdir(BASE_TEMP_DIR)
@@ -1,6 +1,6 @@
 import os
 
+# 默认使用的知识库
 DEFAULT_KNOWLEDGE_BASE = "samples"
 
 # 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es
@@ -44,6 +44,7 @@ KB_INFO = {
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
 if not os.path.exists(KB_ROOT_PATH):
     os.mkdir(KB_ROOT_PATH)
+
 # 数据库默认存储路径。
 # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
@@ -1,10 +1,25 @@
 import os
 
-MODEL_ROOT_PATH = ""
-EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh
-EMBEDDING_DEVICE = "auto"
-EMBEDDING_KEYWORD_FILE = "keywords.txt"
-EMBEDDING_MODEL_OUTPUT_PATH = "output"
+# 默认选用的 LLM 名称
+DEFAULT_LLM_MODEL = "chatglm3-6b"
+
+# 默认选用的 Embedding 名称
+DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5"
+
+
+# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])
+Agent_MODEL = None
+
+# 历史对话轮数
+HISTORY_LEN = 3
+
+# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
+MAX_TOKENS = None
+
+# LLM通用对话参数
+TEMPERATURE = 0.7
+# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
+
 SUPPORT_AGENT_MODELS = [
     "chatglm3-6b",
@@ -12,97 +27,103 @@ SUPPORT_AGENT_MODELS = [
     "Qwen-14B-Chat",
     "Qwen-7B-Chat",
 ]
 
 
 LLM_MODEL_CONFIG = {
+    # 意图识别不需要输出,模型后台知道就行
     "preprocess_model": {
-        # "Mixtral-8x7B-v0.1": {
-        #     "temperature": 0.01,
-        #     "max_tokens": 5,
-        #     "prompt_name": "default",
-        #     "callbacks": False
-        # },
-        "chatglm3-6b": {
+        DEFAULT_LLM_MODEL: {
             "temperature": 0.05,
             "max_tokens": 4096,
+            "history_len": 100,
             "prompt_name": "default",
             "callbacks": False
         },
     },
     "llm_model": {
-        # "Mixtral-8x7B-v0.1": {
-        #     "temperature": 0.9,
-        #     "max_tokens": 4000,
-        #     "history_len": 5,
-        #     "prompt_name": "default",
-        #     "callbacks": True
-        # },
-        "chatglm3-6b": {
-            "temperature": 0.05,
+        DEFAULT_LLM_MODEL: {
+            "temperature": 0.9,
             "max_tokens": 4096,
-            "prompt_name": "default",
             "history_len": 10,
+            "prompt_name": "default",
             "callbacks": True
         },
     },
     "action_model": {
-        # "Qwen-14B-Chat": {
-        #     "temperature": 0.05,
-        #     "max_tokens": 4096,
-        #     "prompt_name": "qwen",
-        #     "callbacks": True
-        # },
-        "chatglm3-6b": {
-            "temperature": 0.05,
+        DEFAULT_LLM_MODEL: {
+            "temperature": 0.01,
             "max_tokens": 4096,
             "prompt_name": "ChatGLM3",
             "callbacks": True
         },
-        # "zhipu-api": {
-        #     "temperature": 0.01,
-        #     "max_tokens": 4096,
-        #     "prompt_name": "ChatGLM3",
-        #     "callbacks": True
-        # }
-
     },
     "postprocess_model": {
-        "zhipu-api": {
+        DEFAULT_LLM_MODEL: {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "default",
            "callbacks": True
        }
     },
-}
-
-
-MODEL_PATH = {
-    "embed_model": {
-        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-        "ernie-base": "nghuyong/ernie-3.0-base-zh",
-        "text2vec-base": "shibing624/text2vec-base-chinese",
-        "text2vec": "GanymedeNil/text2vec-large-chinese",
-        "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
-        "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
-        "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
-        "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
-        "m3e-small": "moka-ai/m3e-small",
-        "m3e-base": "moka-ai/m3e-base",
-        "m3e-large": "moka-ai/m3e-large",
-        "bge-small-zh": "BAAI/bge-small-zh",
-        "bge-base-zh": "BAAI/bge-base-zh",
-        "bge-large-zh": "BAAI/bge-large-zh",
-        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
-        "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
-        "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
-        "piccolo-base-zh": "sensenova/piccolo-base-zh",
-        "piccolo-large-zh": "sensenova/piccolo-large-zh",
-        "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
-        "text-embedding-ada-002": "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",
-    }
-}
-
-NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
-
-LOOM_CONFIG = "./loom.yaml"
-OPENAI_KEY = None
-OPENAI_PROXY = None
+    "image_model": {
+        "sd-turbo": {
+            "size": "256*256",
+        }
+    },
+    "multimodal_model": {
+        "qwen-vl": {}
+    },
 }
 
+# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。
+MODEL_PLATFORMS = [
+    {
+        "platform_name": "openai-api",
+        "platform_type": "openai",
+        "llm_models": [
+            "gpt-3.5-turbo",
+        ],
+        "api_base_url": "https://api.openai.com/v1",
+        "api_key": "sk-",
+        "api_proxy": "",
+    },
+
+    {
+        "platform_name": "xinference",
+        "platform_type": "xinference",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+        "embed_models": [
+            "bge-large-zh-v1.5",
+        ],
+        "image_models": [
+            "sd-turbo",
+        ],
+        "multimodal_models": [
+            "qwen-vl",
+        ],
+        "api_base_url": "http://127.0.0.1:9997/v1",
+        "api_key": "EMPTY",
+    },
+
+    {
+        "platform_name": "oneapi",
+        "platform_type": "oneapi",
+        "api_key": "",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+    },
+
+    {
+        "platform_name": "loom",
+        "platform_type": "loom",
+        "api_key": "",
+        "llm_models": [
+            "chatglm3-6b",
+        ],
+    },
+]
+
+LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
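(A hypothetical helper, not included in this commit, showing how the MODEL_PLATFORMS list above can be consumed: given a model name, find the platform entry that serves it. It assumes MODEL_PLATFORMS and DEFAULT_LLM_MODEL are importable from configs.model_config, as the surrounding hunks suggest.)

    from typing import Optional
    from configs.model_config import MODEL_PLATFORMS, DEFAULT_LLM_MODEL

    def find_platform(model_name: str, kind: str = "llm_models") -> Optional[dict]:
        # Return the first platform whose model list of the given kind contains the model.
        for platform in MODEL_PLATFORMS:
            if model_name in platform.get(kind, []):
                return platform
        return None

    platform = find_platform(DEFAULT_LLM_MODEL)
    if platform is not None:
        print(platform["platform_name"], platform.get("api_base_url"))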
@@ -98,7 +98,25 @@ PROMPT_TEMPLATES = {
         'Begin!\n\n'
         'Question: {input}\n\n'
         '{agent_scratchpad}\n\n',
+    "structured-chat-agent":
+        'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n'
+        '{tools}\n\n'
+        'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n'
+        'Valid "action" values: "Final Answer" or {tool_names}\n\n'
+        'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
+        '```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\n'
+        'Follow this format:\n\n'
+        'Question: input question to answer\n'
+        'Thought: consider previous and subsequent steps\n'
+        'Action:\n```\n$JSON_BLOB\n```\n'
+        'Observation: action result\n'
+        '... (repeat Thought/Action/Observation N times)\n'
+        'Thought: I know what to respond\n'
+        'Action:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\n'
+        'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n'
+        '{input}\n\n'
+        '{agent_scratchpad}\n\n'
+        # '(reminder to respond in a JSON blob no matter what)'
     },
     "postprocess_model": {
         "default": "{{input}}",
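(Side note, not part of the commit: the new "structured-chat-agent" template above uses {tools}, {tool_names}, {input} and {agent_scratchpad} as placeholders and doubles the braces that must stay literal JSON. A minimal sketch of rendering such a template with langchain's PromptTemplate, using a shortened stand-in for the full string:)

    from langchain.prompts import PromptTemplate

    # Shortened stand-in for the template above; any {{ }} would survive as literal braces.
    template = (
        'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n'
        '{tools}\n\n'
        'Valid "action" values: "Final Answer" or {tool_names}\n\n'
        'Question: {input}\n\n'
        '{agent_scratchpad}\n\n'
    )
    prompt = PromptTemplate.from_template(template)
    print(prompt.format(tools="search: web search", tool_names="search",
                        input="今天天气如何?", agent_scratchpad=""))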
@@ -130,7 +148,7 @@ TOOL_CONFIG = {
     "bing": {
         "result_len": 3,
         "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
-        "bing_key": "680a39347d7242c5bd2d7a9576a125b7",
+        "bing_key": "",
     },
     "metaphor": {
         "result_len": 3,
@@ -184,4 +202,8 @@ TOOL_CONFIG = {
         "device": "cuda:2"
     },
 
+    "text2images": {
+        "use": False,
+    },
+
 }
@@ -1,5 +1,5 @@
 import sys
-from configs.model_config import LLM_DEVICE
+
 
 # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
 HTTPX_DEFAULT_TIMEOUT = 300.0
@@ -11,10 +11,11 @@ OPEN_CROSS_DOMAIN = True
 # 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
 DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
 
+
 # webui.py server
 WEBUI_SERVER = {
     "host": DEFAULT_BIND_HOST,
-    "port": 7870,
+    "port": 8501,
 }
 
 # api.py server
@@ -1,79 +0,0 @@
-'''
-该功能是为了将关键词加入到embedding模型中,以便于在embedding模型中进行关键词的embedding
-该功能的实现是通过修改embedding模型的tokenizer来实现的
-该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效,输出后的模型保存在原本模型
-感谢@CharlesJu1和@charlesyju的贡献提出了想法和最基础的PR
-
-保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳
-'''
-import sys
-
-sys.path.append("..")
-import os
-import torch
-
-from datetime import datetime
-from configs import (
-    MODEL_PATH,
-    EMBEDDING_MODEL,
-    EMBEDDING_KEYWORD_FILE,
-)
-
-from safetensors.torch import save_model
-from sentence_transformers import SentenceTransformer
-from langchain_core._api import deprecated
-
-
-@deprecated(
-    since="0.3.0",
-    message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
-    removal="0.3.0"
-)
-def get_keyword_embedding(bert_model, tokenizer, key_words):
-    tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
-    input_ids = tokenizer_output['input_ids']
-    input_ids = input_ids[:, 1:-1]
-
-    keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
-    keyword_embedding = torch.mean(keyword_embedding, 1)
-    return keyword_embedding
-
-
-def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
-    key_words = []
-    with open(keyword_file, "r") as f:
-        for line in f:
-            key_words.append(line.strip())
-
-    st_model = SentenceTransformer(model_name)
-    key_words_len = len(key_words)
-    word_embedding_model = st_model._first_module()
-    bert_model = word_embedding_model.auto_model
-    tokenizer = word_embedding_model.tokenizer
-    key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
-
-    embedding_weight = bert_model.embeddings.word_embeddings.weight
-    embedding_weight_len = len(embedding_weight)
-    tokenizer.add_tokens(key_words)
-    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
-    embedding_weight = bert_model.embeddings.word_embeddings.weight
-    with torch.no_grad():
-        embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
-
-    if output_model_path:
-        os.makedirs(output_model_path, exist_ok=True)
-        word_embedding_model.save(output_model_path)
-        safetensors_file = os.path.join(output_model_path, "model.safetensors")
-        metadata = {'format': 'pt'}
-        save_model(bert_model, safetensors_file, metadata)
-        print("save model to {}".format(output_model_path))
-
-
-def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
-    keyword_file = os.path.join(path)
-    model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
-    model_parent_directory = os.path.dirname(model_name)
-    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
-    output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
-    output_model_path = os.path.join(model_parent_directory, output_model_name)
-    add_keyword_to_model(model_name, keyword_file, output_model_path)
@@ -1,3 +0,0 @@
-Langchain-Chatchat
-数据科学与大数据技术
-人工智能与先进计算
@@ -2,9 +2,7 @@ import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                             folder2db, prune_db_docs, prune_folder_files)
-from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
-import nltk
-nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL
 from datetime import datetime
 
 
@@ -19,7 +17,7 @@ if __name__ == "__main__":
         action="store_true",
         help=('''
             recreate vector store.
-            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed.
+            use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/DEFAULT_EMBEDDING_MODEL changed.
             '''
         )
     )
@@ -87,7 +85,7 @@ if __name__ == "__main__":
         "-e",
         "--embed-model",
         type=str,
-        default=EMBEDDING_MODEL,
+        default=DEFAULT_EMBEDDING_MODEL,
         help=("specify embeddings model.")
     )
 
@@ -1,57 +0,0 @@
-# API requirements
-
-langchain>=0.0.350
-langchain-experimental>=0.0.42
-pydantic==1.10.13
-fschat==0.2.35
-openai~=1.9.0
-fastapi~=0.109.0
-sse_starlette==1.8.2
-nltk>=3.8.1
-uvicorn>=0.27.0.post1
-starlette~=0.35.0
-unstructured[all-docs]==0.11.0
-python-magic-bin; sys_platform == 'win32'
-SQLAlchemy==2.0.19
-faiss-cpu~=1.7.4
-accelerate~=0.24.1
-spacy~=3.7.2
-PyMuPDF~=1.23.8
-rapidocr_onnxruntime==1.3.8
-requests~=2.31.0
-pathlib~=1.0.1
-pytest~=7.4.3
-numexpr~=2.8.6
-strsimpy~=0.2.1
-markdownify~=0.11.6
-tiktoken~=0.5.2
-tqdm>=4.66.1
-websockets>=12.0
-numpy~=1.24.4
-pandas~=2.0.3
-einops>=0.7.0
-transformers_stream_generator==0.0.4
-vllm==0.2.7; sys_platform == "linux"
-httpx==0.26.0
-httpx_sse==0.4.0
-llama-index==0.9.35
-pyjwt==2.8.0
-
-# jq==1.6.0
-# beautifulsoup4~=4.12.2
-# pysrt~=1.1.2
-# dashscope==1.13.6
-# arxiv~=2.1.0
-# youtube-search~=2.1.2
-# duckduckgo-search~=3.9.9
-# metaphor-python~=0.1.23
-
-# volcengine>=1.0.119
-# pymilvus==2.3.6
-# psycopg2==2.9.9
-# pgvector>=0.2.4
-# chromadb==0.4.13
-
-#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
-#autoawq==0.1.8 # For Int4
-#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files
@@ -1,42 +0,0 @@
-langchain==0.0.354
-langchain-experimental==0.0.47
-pydantic==1.10.13
-fschat~=0.2.35
-openai~=1.9.0
-fastapi~=0.109.0
-sse_starlette~=1.8.2
-nltk~=3.8.1
-uvicorn>=0.27.0.post1
-starlette~=0.35.0
-unstructured[all-docs]~=0.12.0
-python-magic-bin; sys_platform == 'win32'
-SQLAlchemy~=2.0.25
-faiss-cpu~=1.7.4
-accelerate~=0.24.1
-spacy~=3.7.2
-PyMuPDF~=1.23.16
-rapidocr_onnxruntime~=1.3.8
-requests~=2.31.0
-pathlib~=1.0.1
-pytest~=7.4.3
-llama-index==0.9.35
-pyjwt==2.8.0
-httpx==0.26.0
-httpx_sse==0.4.0
-
-dashscope==1.13.6
-arxiv~=2.1.0
-youtube-search~=2.1.2
-duckduckgo-search~=3.9.9
-metaphor-python~=0.1.23
-watchdog~=3.0.0
-# volcengine>=1.0.119
-# pymilvus>=2.3.4
-# psycopg2==2.9.9
-# pgvector>=0.2.4
-# chromadb==0.4.13
-
-# jq==1.6.0
-# beautifulsoup4~=4.12.2
-# pysrt~=1.1.2
-
@@ -1,9 +0,0 @@
-streamlit==1.30.0
-streamlit-option-menu==0.3.12
-streamlit-antd-components==0.3.1
-streamlit-chatbox==1.1.11
-streamlit-modal==0.1.0
-streamlit-aggrid==0.3.4.post3
-httpx==0.26.0
-httpx_sse==0.4.0
-watchdog>=3.0.0
@@ -6,7 +6,7 @@ from typing import List
 import uuid
 
 from langchain.agents import tool
-from pydantic.v1 import BaseModel, Field
+from langchain.pydantic_v1 import Field
 import openai
 from pydantic.fields import FieldInfo
 
@@ -1,13 +1,9 @@
-import nltk
 import sys
 import os
 
-from server.knowledge_base.kb_doc_api import update_kb_endpoint
-
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from configs import VERSION, MEDIA_PATH
-from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
@@ -18,14 +14,11 @@ from starlette.responses import RedirectResponse
 from server.chat.chat import chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
-from server.embeddings.core.embeddings_api import embed_texts_endpoint
-
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
 from typing import List, Literal
 
-nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
-
 
 async def document():
     return RedirectResponse(url="/docs")
@@ -95,11 +88,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              summary="要求llm模型补全(通过LLMChain)",
              )(completion)
 
-    app.post("/other/embed_texts",
-             tags=["Other"],
-             summary="将文本向量化,支持本地模型和在线模型",
-             )(embed_texts_endpoint)
-
     # 媒体文件
     app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
 
@@ -109,8 +97,7 @@ def mount_knowledge_routes(app: FastAPI):
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                   update_docs, download_doc, recreate_vector_store,
-                                                  search_docs, DocumentWithVSId, update_info,
-                                                  update_docs_by_id,)
+                                                  search_docs, update_info)
 
     app.post("/chat/file_chat",
              tags=["Knowledge Base Management"],
@@ -146,13 +133,6 @@ def mount_knowledge_routes(app: FastAPI):
              summary="搜索知识库"
              )(search_docs)
 
-    app.post("/knowledge_base/update_docs_by_id",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="直接更新知识库文档"
-             )(update_docs_by_id)
-
-
     app.post("/knowledge_base/upload_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
@@ -171,12 +151,6 @@ def mount_knowledge_routes(app: FastAPI):
              summary="更新知识库介绍"
              )(update_info)
 
-    app.post("/knowledge_base/update_kb_endpoint",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="更新知识库在线api接入点配置"
-             )(update_kb_endpoint)
-
     app.post("/knowledge_base/update_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterable, List, Union, Dict, Annotated
+from typing import AsyncIterable, List
 
 from fastapi import Body
 from fastapi.responses import StreamingResponse
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
 
 
-def create_models_from_config(configs, openai_config, callbacks, stream):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
@@ -30,9 +30,6 @@ def create_models_from_config(configs, openai_config, callbacks, stream):
     for model_name, params in model_configs.items():
         callbacks = callbacks if params.get('callbacks', False) else None
         model_instance = get_ChatOpenAI(
-            endpoint_host=openai_config.get('endpoint_host', None),
-            endpoint_host_key=openai_config.get('endpoint_host_key', None),
-            endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
             model_name=model_name,
             temperature=params.get('temperature', 0.5),
            max_tokens=params.get('max_tokens', 1000),
@@ -116,7 +113,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                ),
                stream: bool = Body(True, description="流式输出"),
                chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
-               openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
                tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
@@ -129,7 +125,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
-                                                    openai_config=openai_config, stream=stream)
+                                                    stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         full_chain = create_models_chains(prompts=prompts,
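(Illustrative only: with openai_config removed, the chat endpoint now receives just a per-request chat_model_config shaped like LLM_MODEL_CONFIG, and endpoint resolution happens server-side via the platform configuration. A sketch of such a payload and the call as it now reads in the hunk above, assuming it runs where create_models_from_config is in scope:)

    chat_model_config = {
        "llm_model": {
            "chatglm3-6b": {
                "temperature": 0.9,
                "max_tokens": 4096,
                "history_len": 10,
                "prompt_name": "default",
                "callbacks": True,
            },
        },
    }
    models, prompts = create_models_from_config(configs=chat_model_config, callbacks=[], stream=False)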
@@ -1,6 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs import LLM_MODEL_CONFIG
+from sse_starlette.sse import EventSourceResponse
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -14,9 +13,6 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
-                     endpoint_host: str = Body(None, description="接入点地址"),
-                     endpoint_host_key: str = Body(None, description="接入点key"),
-                     endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                      model_name: str = Body(None, description="LLM 模型名称。"),
                      temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -27,9 +23,6 @@ async def completion(query: str = Body(..., description="用户输入", examples
 
     #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
-                                  endpoint_host: str,
-                                  endpoint_host_key: str,
-                                  endpoint_host_proxy: str,
                                   model_name: str = None,
                                   prompt_name: str = prompt_name,
                                   echo: bool = echo,
@@ -40,9 +33,6 @@ async def completion(query: str = Body(..., description="用户输入", examples
             max_tokens = None
 
         model = get_OpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -72,10 +62,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
 
         await task
 
-    return StreamingResponse(completion_iterator(query=query,
-                                                 endpoint_host=endpoint_host,
-                                                 endpoint_host_key=endpoint_host_key,
-                                                 endpoint_host_proxy=endpoint_host_proxy,
+    return EventSourceResponse(completion_iterator(query=query,
                                                    model_name=model_name,
                                                    prompt_name=prompt_name),
                               )
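(Background for the StreamingResponse → EventSourceResponse switch above: sse_starlette wraps an async iterator into a Server-Sent Events stream. A self-contained sketch of the pattern, not this project's code:)

    import asyncio
    from fastapi import FastAPI
    from sse_starlette.sse import EventSourceResponse

    app = FastAPI()

    @app.get("/demo/stream")
    async def demo_stream():
        async def event_iterator():
            # Each yielded item is sent to the client as one "data: ..." SSE frame.
            for token in ["恼", "羞", "成", "怒"]:
                yield token
                await asyncio.sleep(0.1)
        return EventSourceResponse(event_iterator())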
@@ -1,8 +1,7 @@
 from fastapi import Body, File, Form, UploadFile
 from fastapi.responses import StreamingResponse
 from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.embeddings.adapter import load_temp_adapter_embeddings
-from server.utils import (wrap_done, get_ChatOpenAI,
+from server.utils import (wrap_done, get_ChatOpenAI, get_Embeddings,
                           BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
 from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
 from langchain.chains import LLMChain
@@ -57,9 +56,6 @@ def _parse_files_in_thread(
 
 
 def upload_temp_docs(
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
         prev_id: str = Form(None, description="前知识库ID"),
         chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
@@ -86,11 +82,7 @@ def upload_temp_docs(
         else:
             failed_files.append({file: msg})
 
-    with memo_faiss_pool.load_vector_store(kb_name=id,
-                                           endpoint_host=endpoint_host,
-                                           endpoint_host_key=endpoint_host_key,
-                                           endpoint_host_proxy=endpoint_host_proxy,
-                                           ).acquire() as vs:
+    with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs:
         vs.add_documents(documents)
     return BaseResponse(data={"id": id, "failed_files": failed_files})
 
@@ -110,9 +102,6 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                                           "content": "虎头虎脑"}]]
                     ),
                     stream: bool = Body(False, description="流式输出"),
-                    endpoint_host: str = Body(None, description="接入点地址"),
-                    endpoint_host_key: str = Body(None, description="接入点key"),
-                    endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                     model_name: str = Body(None, description="LLM 模型名称。"),
                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -131,17 +120,12 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
             max_tokens = None
 
         model = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
         )
-        embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
-                                                  endpoint_host_key=endpoint_host_key,
-                                                  endpoint_host_proxy=endpoint_host_proxy)
+        embed_func = get_Embeddings()
         embeddings = await embed_func.aembed_query(query)
         with memo_faiss_pool.acquire(knowledge_id) as vs:
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@@ -12,9 +12,6 @@ class KnowledgeBaseModel(Base):
     kb_name = Column(String(50), comment='知识库名称')
     kb_info = Column(String(200), comment='知识库简介(用于Agent)')
     vs_type = Column(String(50), comment='向量库类型')
-    endpoint_host = Column(String(50), comment='接入点地址')
-    endpoint_host_key = Column(String(50), comment='接入点key')
-    endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
     embed_model = Column(String(50), comment='嵌入模型名称')
     file_count = Column(Integer, default=0, comment='文件数量')
     create_time = Column(DateTime, default=func.now(), comment='创建时间')
@@ -3,22 +3,16 @@ from server.db.session import with_session
 
 
 @with_session
-def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
-                 endpoint_host_key: str = None, endpoint_host_proxy: str = None):
+def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
     # 创建知识库实例
     kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
     if not kb:
-        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
-                                endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
-                                endpoint_host_proxy=endpoint_host_proxy)
+        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
         session.add(kb)
     else:  # update kb with new vs_type and embed_model
         kb.kb_info = kb_info
         kb.vs_type = vs_type
         kb.embed_model = embed_model
-        kb.endpoint_host = endpoint_host
-        kb.endpoint_host_key = endpoint_host_key
-        kb.endpoint_host_proxy = endpoint_host_proxy
     return True
 
 
@@ -54,16 +48,6 @@ def delete_kb_from_db(session, kb_name):
     return True
 
 
-@with_session
-def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
-    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
-    if kb:
-        kb.endpoint_host = endpoint_host
-        kb.endpoint_host_key = endpoint_host_key
-        kb.endpoint_host_proxy = endpoint_host_proxy
-    return True
-
-
 @with_session
 def get_kb_detail(session, kb_name: str) -> dict:
     kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
@@ -72,9 +56,6 @@ def get_kb_detail(session, kb_name: str) -> dict:
             "kb_name": kb.kb_name,
             "kb_info": kb.kb_info,
             "vs_type": kb.vs_type,
-            "endpoint_host": kb.endpoint_host,
-            "endpoint_host_key": kb.endpoint_host_key,
-            "endpoint_host_proxy": kb.endpoint_host_proxy,
             "embed_model": kb.embed_model,
             "file_count": kb.file_count,
             "create_time": kb.create_time,
@@ -1,130 +0,0 @@
-import numpy as np
-from typing import List, Union, Dict
-from langchain.embeddings.base import Embeddings
-from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     EMBEDDING_MODEL, KB_INFO)
-from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
-from server.utils import embedding_device
-
-
-class EmbeddingsFunAdapter(Embeddings):
-    _endpoint_host: str
-    _endpoint_host_key: str
-    _endpoint_host_proxy: str
-
-    def __init__(self,
-                 endpoint_host: str,
-                 endpoint_host_key: str,
-                 endpoint_host_proxy: str,
-                 embed_model: str = EMBEDDING_MODEL,
-                 ):
-        self._endpoint_host = endpoint_host
-        self._endpoint_host_key = endpoint_host_key
-        self._endpoint_host_proxy = endpoint_host_proxy
-        self.embed_model = embed_model
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        embeddings = embed_texts(texts=texts,
-                                 endpoint_host=self._endpoint_host,
-                                 endpoint_host_key=self._endpoint_host_key,
-                                 endpoint_host_proxy=self._endpoint_host_proxy,
-                                 embed_model=self.embed_model,
-                                 to_query=False).data
-        return self._normalize(embeddings=embeddings).tolist()
-
-    def embed_query(self, text: str) -> List[float]:
-        embeddings = embed_texts(texts=[text],
-                                 endpoint_host=self._endpoint_host,
-                                 endpoint_host_key=self._endpoint_host_key,
-                                 endpoint_host_proxy=self._endpoint_host_proxy,
-                                 embed_model=self.embed_model,
-                                 to_query=True).data
-        query_embed = embeddings[0]
-        query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
-        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
-        return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回
-
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        embeddings = (await aembed_texts(texts=texts,
-                                         endpoint_host=self._endpoint_host,
-                                         endpoint_host_key=self._endpoint_host_key,
-                                         endpoint_host_proxy=self._endpoint_host_proxy,
-                                         embed_model=self.embed_model,
-                                         to_query=False)).data
-        return self._normalize(embeddings=embeddings).tolist()
-
-    async def aembed_query(self, text: str) -> List[float]:
-        embeddings = (await aembed_texts(texts=[text],
-                                         endpoint_host=self._endpoint_host,
-                                         endpoint_host_key=self._endpoint_host_key,
-                                         endpoint_host_proxy=self._endpoint_host_proxy,
-                                         embed_model=self.embed_model,
-                                         to_query=True)).data
-        query_embed = embeddings[0]
-        query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
-        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
-        return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回
-
-    @staticmethod
-    def _normalize(embeddings: List[List[float]]) -> np.ndarray:
-        '''
-        sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
-        #TODO 此处内容处理错误
-        '''
-        norm = np.linalg.norm(embeddings, axis=1)
-        norm = np.reshape(norm, (norm.shape[0], 1))
-        norm = np.tile(norm, (1, len(embeddings[0])))
-        return np.divide(embeddings, norm)
-
-
-def load_kb_adapter_embeddings(
-        kb_name: str,
-        embed_device: str = embedding_device(),
-        default_embed_model: str = EMBEDDING_MODEL,
-) -> "EmbeddingsFunAdapter":
-    """
-    加载知识库配置的Embeddings模型
-    本地模型最终会通过load_embeddings加载
-    在线模型会在适配器中直接返回
-    :param kb_name:
-    :param embed_device:
-    :param default_embed_model:
-    :return:
-    """
-    from server.db.repository.knowledge_base_repository import get_kb_detail
-
-    kb_detail = get_kb_detail(kb_name)
-    embed_model = kb_detail.get("embed_model", default_embed_model)
-    endpoint_host = kb_detail.get("endpoint_host", None)
-    endpoint_host_key = kb_detail.get("endpoint_host_key", None)
-    endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None)
-
-    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
-                                endpoint_host_key=endpoint_host_key,
-                                endpoint_host_proxy=endpoint_host_proxy,
-                                embed_model=embed_model)
-
-
-def load_temp_adapter_embeddings(
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str,
-        embed_device: str = embedding_device(),
-        default_embed_model: str = EMBEDDING_MODEL,
-) -> "EmbeddingsFunAdapter":
-    """
-    加载临时的Embeddings模型
-    本地模型最终会通过load_embeddings加载
-    在线模型会在适配器中直接返回
-    :param endpoint_host:
-    :param endpoint_host_key:
-    :param endpoint_host_proxy:
-    :param embed_device:
-    :param default_embed_model:
-    :return:
-    """
-
-    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
-                                endpoint_host_key=endpoint_host_key,
-                                endpoint_host_proxy=endpoint_host_proxy,
-                                embed_model=default_embed_model)
@@ -1,94 +0,0 @@
-from langchain.docstore.document import Document
-from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
-from server.utils import BaseResponse, list_embed_models, list_online_embed_models
-from fastapi import Body
-from fastapi.concurrency import run_in_threadpool
-from typing import Dict, List
-
-
-def embed_texts(
-        texts: List[str],
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str,
-        embed_model: str = EMBEDDING_MODEL,
-        to_query: bool = False,
-) -> BaseResponse:
-    '''
-    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
-    TODO: 也许需要加入缓存机制,减少 token 消耗
-    '''
-    try:
-        if embed_model in list_embed_models():  # 使用本地Embeddings模型
-            from server.utils import load_local_embeddings
-            embeddings = load_local_embeddings(model=embed_model)
-            return BaseResponse(data=embeddings.embed_documents(texts))
-
-        # 使用在线API
-        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
-                                                   endpoint_host_key=endpoint_host_key,
-                                                   endpoint_host_proxy=endpoint_host_proxy):
-            from langchain.embeddings.openai import OpenAIEmbeddings
-            embeddings = OpenAIEmbeddings(model=embed_model,
-                                          openai_api_key=endpoint_host_key if endpoint_host_key else "None",
-                                          openai_api_base=endpoint_host if endpoint_host else "None",
-                                          openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
-                                          chunk_size=CHUNK_SIZE)
-            return BaseResponse(data=embeddings.embed_documents(texts))
-
-        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
-    except Exception as e:
-        logger.error(e)
-        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
-
-
-async def aembed_texts(
-        texts: List[str],
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str,
-        embed_model: str = EMBEDDING_MODEL,
-        to_query: bool = False,
-) -> BaseResponse:
-    '''
-    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
-    '''
-    try:
-        if embed_model in list_embed_models():  # 使用本地Embeddings模型
-            from server.utils import load_local_embeddings
-
-            embeddings = load_local_embeddings(model=embed_model)
-            return BaseResponse(data=await embeddings.aembed_documents(texts))
-
-        # 使用在线API
-        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
-                                                   endpoint_host_key=endpoint_host_key,
-                                                   endpoint_host_proxy=endpoint_host_proxy):
-            return await run_in_threadpool(embed_texts,
-                                           texts=texts,
-                                           endpoint_host=endpoint_host,
-                                           endpoint_host_key=endpoint_host_key,
-                                           endpoint_host_proxy=endpoint_host_proxy,
-                                           embed_model=embed_model,
-                                           to_query=to_query)
-    except Exception as e:
-        logger.error(e)
-        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
-
-
-def embed_texts_endpoint(
-        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
-        embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
-        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
-) -> BaseResponse:
-    '''
-    接入api,对文本进行向量化,返回 BaseResponse(data=List[List[float]])
-    '''
-    return embed_texts(texts=texts,
-                       endpoint_host=endpoint_host,
-                       endpoint_host_key=endpoint_host_key,
-                       endpoint_host_proxy=endpoint_host_proxy,
-                       embed_model=embed_model, to_query=to_query)
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs import EMBEDDING_MODEL, logger, log_verbose
+from configs import DEFAULT_EMBEDDING_MODEL, logger, log_verbose
 from fastapi import Body
 
 
@@ -14,10 +14,7 @@ def list_kbs():
 
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
-              embed_model: str = Body(EMBEDDING_MODEL),
-              endpoint_host: str = Body(None, description="接入点地址"),
-              endpoint_host_key: str = Body(None, description="接入点key"),
-              endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
+              embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
@@ -31,7 +28,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 
     kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
     try:
-        kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
+        kb.create_kb()
     except Exception as e:
         msg = f"创建知识库出错: {e}"
         logger.error(f'{e.__class__.__name__}: {msg}',
@@ -1,9 +1,8 @@
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.faiss import FAISS
 import threading
-from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
+from configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -98,50 +97,3 @@ class CachePool:
         else:
             return cache
 
 
-class EmbeddingsPool(CachePool):
-
-    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
-        """
-        本地Embeddings模型加载
-        :param model:
-        :param device:
-        :return:
-        """
-        self.atomic.acquire()
-        model = model or EMBEDDING_MODEL
-        device = embedding_device()
-        key = (model, device)
-        if not self.get(key):
-            item = ThreadSafeObject(key, pool=self)
-            self.set(key, item)
-            with item.acquire(msg="初始化"):
-                self.atomic.release()
-                if 'bge-' in model:
-                    from langchain.embeddings import HuggingFaceBgeEmbeddings
-                    if 'zh' in model:
-                        # for chinese model
-                        query_instruction = "为这个句子生成表示以用于检索相关文章:"
-                    elif 'en' in model:
-                        # for english model
-                        query_instruction = "Represent this sentence for searching relevant passages:"
-                    else:
-                        # maybe ReRanker or else, just use empty string instead
-                        query_instruction = ""
-                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
-                                                          model_kwargs={'device': device},
-                                                          query_instruction=query_instruction)
-                    if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
-                        embeddings.query_instruction = ""
-                else:
-                    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
-                                                       model_kwargs={'device': device})
-                item.obj = embeddings
-                item.finish_loading()
-        else:
-            self.atomic.release()
-        return self.get(key).obj
-
-
-embeddings_pool = EmbeddingsPool(cache_num=1)
|
|||||||
@@ -1,7 +1,6 @@
 from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
-from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
 from server.knowledge_base.kb_cache.base import *
-# from server.utils import load_local_embeddings
+from server.utils import get_Embeddings
 from server.knowledge_base.utils import get_vs_path
 from langchain.vectorstores.faiss import FAISS
 from langchain.docstore.in_memory import InMemoryDocstore
@@ -53,13 +52,11 @@ class _FaissPool(CachePool):
     def new_vector_store(
             self,
             kb_name: str,
-            embed_model: str = EMBEDDING_MODEL,
-            embed_device: str = embedding_device(),
+            embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> FAISS:

         # create an empty vector store
-        embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
-                                                embed_device=embed_device, default_embed_model=embed_model)
+        embeddings = get_Embeddings(embed_model=embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
         ids = list(vector_store.docstore._dict.keys())
@@ -68,18 +65,11 @@ class _FaissPool(CachePool):

     def new_temp_vector_store(
             self,
-            endpoint_host: str,
-            endpoint_host_key: str,
-            endpoint_host_proxy: str,
-            embed_model: str = EMBEDDING_MODEL,
-            embed_device: str = embedding_device(),
+            embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> FAISS:

         # create an empty vector store
-        embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
-                                                  endpoint_host_key=endpoint_host_key,
-                                                  endpoint_host_proxy=endpoint_host_proxy,
-                                                  embed_device=embed_device, default_embed_model=embed_model)
+        embeddings = get_Embeddings(embed_model=embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
         ids = list(vector_store.docstore._dict.keys())
@@ -102,8 +92,7 @@ class KBFaissPool(_FaissPool):
                           kb_name: str,
                           vector_name: str = None,
                           create: bool = True,
-                          embed_model: str = EMBEDDING_MODEL,
-                          embed_device: str = embedding_device(),
+                          embed_model: str = DEFAULT_EMBEDDING_MODEL,
                           ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         vector_name = vector_name or embed_model
@@ -118,15 +107,13 @@ class KBFaissPool(_FaissPool):
             vs_path = get_vs_path(kb_name, vector_name)

             if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
-                                                        embed_device=embed_device, default_embed_model=embed_model)
+                embeddings = get_Embeddings(embed_model=embed_model)
                 vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
             elif create:
                 # create an empty vector store
                 if not os.path.exists(vs_path):
                     os.makedirs(vs_path)
-                vector_store = self.new_vector_store(kb_name=kb_name,
-                                                     embed_model=embed_model, embed_device=embed_device)
+                vector_store = self.new_vector_store(kb_name=kb_name, embed_model=embed_model)
                 vector_store.save_local(vs_path)
             else:
                 raise RuntimeError(f"knowledge base {kb_name} not exist.")
@@ -148,11 +135,7 @@ class MemoFaissPool(_FaissPool):
     def load_vector_store(
             self,
             kb_name: str,
-            endpoint_host: str,
-            endpoint_host_key: str,
-            endpoint_host_proxy: str,
-            embed_model: str = EMBEDDING_MODEL,
-            embed_device: str = embedding_device(),
+            embed_model: str = DEFAULT_EMBEDDING_MODEL,
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         cache = self.get(kb_name)
@@ -163,10 +146,7 @@ class MemoFaissPool(_FaissPool):
                 self.atomic.release()
                 logger.info(f"loading vector store in '{kb_name}' to memory.")
                 # create an empty vector store
-                vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
-                                                          endpoint_host_key=endpoint_host_key,
-                                                          endpoint_host_proxy=endpoint_host_proxy,
-                                                          embed_model=embed_model, embed_device=embed_device)
+                vector_store = self.new_temp_vector_store(embed_model=embed_model)
                 item.obj = vector_store
                 item.finish_loading()
             else:
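Reviewer note: both FAISS pools now obtain their Embeddings through get_Embeddings(embed_model=...), so device handling disappears from every call site. A rough usage sketch; the pool instance below is illustrative and not part of this diff:

from configs import CACHED_VS_NUM, DEFAULT_EMBEDDING_MODEL

pool = KBFaissPool(cache_num=CACHED_VS_NUM)  # hypothetical instance
ts_faiss = pool.load_vector_store(kb_name="samples",
                                  embed_model=DEFAULT_EMBEDDING_MODEL)
# the returned ThreadSafeFaiss wraps a FAISS store whose embeddings come from the
# OpenAI-compatible endpoint configured for embed_model; no embed_device argument.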
@@ -1,7 +1,7 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                      logger, log_verbose, )
@@ -42,22 +42,6 @@ def search_docs(
     return data


-def update_docs_by_id(
-        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
-) -> BaseResponse:
-    '''
-    按照文档 ID 更新文档内容
-    '''
-    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
-    if kb is None:
-        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
-    if kb.update_doc_by_ids(docs=docs):
-        return BaseResponse(msg=f"文档更新成功")
-    else:
-        return BaseResponse(msg=f"文档更新失败")
-
-
 def list_files(
         knowledge_base_name: str
 ) -> ListResponse:
@@ -230,26 +214,6 @@ def update_info(
     return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})


-def update_kb_endpoint(
-        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
-):
-    if not validate_kb_name(knowledge_base_name):
-        return BaseResponse(code=403, msg="Don't attack me")
-
-    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
-    if kb is None:
-        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
-    kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy)
-
-    return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成",
-                        data={"endpoint_host": endpoint_host,
-                              "endpoint_host_key": endpoint_host_key,
-                              "endpoint_host_proxy": endpoint_host_proxy})
-
-
 def update_docs(
         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
         file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
@@ -366,10 +330,7 @@ def recreate_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
-        embed_model: str = Body(EMBEDDING_MODEL),
+        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
         chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
         chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
         zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
@@ -389,9 +350,7 @@ def recreate_vector_store(
     else:
         if kb.exists():
             kb.clear_vs()
-        kb.create_kb(endpoint_host=endpoint_host,
-                     endpoint_host_key=endpoint_host_key,
-                     endpoint_host_proxy=endpoint_host_proxy)
+        kb.create_kb()
         files = list_files_from_folder(knowledge_base_name)
         kb_files = [(file, knowledge_base_name) for file in files]
         i = 0
@@ -6,7 +6,7 @@ from langchain.docstore.document import Document

 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
-    load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db,
+    load_kb_from_db, get_kb_detail,
 )
 from server.db.repository.knowledge_file_repository import (
     add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
@@ -15,7 +15,7 @@ from server.db.repository.knowledge_file_repository import (
 )

 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     EMBEDDING_MODEL, KB_INFO)
+                     DEFAULT_EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
@@ -40,7 +40,7 @@ class KBService(ABC):

     def __init__(self,
                  knowledge_base_name: str,
-                 embed_model: str = EMBEDDING_MODEL,
+                 embed_model: str = DEFAULT_EMBEDDING_MODEL,
                  ):
         self.kb_name = knowledge_base_name
         self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
@@ -58,20 +58,14 @@ class KBService(ABC):
         '''
         pass

-    def create_kb(self,
-                  endpoint_host: str = None,
-                  endpoint_host_key: str = None,
-                  endpoint_host_proxy: str = None):
+    def create_kb(self):
         """
         创建知识库
         """
         if not os.path.exists(self.doc_path):
             os.makedirs(self.doc_path)

-        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model,
-                              endpoint_host=endpoint_host,
-                              endpoint_host_key=endpoint_host_key,
-                              endpoint_host_proxy=endpoint_host_proxy)
+        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)

         if status:
             self.do_create_kb()
@@ -144,16 +138,6 @@ class KBService(ABC):
         status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
         return status

-    def update_kb_endpoint(self,
-                           endpoint_host: str = None,
-                           endpoint_host_key: str = None,
-                           endpoint_host_proxy: str = None):
-        """
-        更新知识库在线api接入点配置
-        """
-        status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy)
-        return status
-
     def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         使用content中的文件更新向量库
@@ -297,7 +281,7 @@ class KBServiceFactory:
     @staticmethod
     def get_service(kb_name: str,
                     vector_store_type: Union[str, SupportedVSType],
-                    embed_model: str = EMBEDDING_MODEL,
+                    embed_model: str = DEFAULT_EMBEDDING_MODEL,
                     ) -> KBService:
         if isinstance(vector_store_type, str):
             vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
@@ -1,13 +1,11 @@
 from typing import List
 import os
 import shutil
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
 from langchain.vectorstores.elasticsearch import ElasticsearchStore
-from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
-from server.utils import load_local_embeddings
+from server.utils import get_Embeddings
 from elasticsearch import Elasticsearch,BadRequestError
 from configs import logger
 from configs import kbs_config
@@ -22,7 +20,7 @@ class ESKBService(KBService):
         self.user = kbs_config[self.vs_type()].get("user",'')
         self.password = kbs_config[self.vs_type()].get("password",'')
         self.dims_length = kbs_config[self.vs_type()].get("dims_length",None)
-        self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        self.embeddings_model = get_Embeddings(self.embed_model)
         try:
             # ES python客户端连接(仅连接)
             if self.user != "" and self.password != "":
@@ -1,7 +1,7 @@
 from typing import List

 from configs import (
-    EMBEDDING_MODEL,
+    DEFAULT_EMBEDDING_MODEL,
     KB_ROOT_PATH)

 from abc import ABC, abstractmethod
@@ -21,7 +21,7 @@ class KBSummaryService(ABC):

     def __init__(self,
                  knowledge_base_name: str,
-                 embed_model: str = EMBEDDING_MODEL
+                 embed_model: str = DEFAULT_EMBEDDING_MODEL
                  ):
         self.kb_name = knowledge_base_name
         self.embed_model = embed_model
@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                      OVERLAP_SIZE,
                      logger, log_verbose, )
 from server.knowledge_base.utils import (list_files_from_folder)
@@ -17,11 +17,8 @@ def recreate_summary_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
-        embed_model: str = Body(EMBEDDING_MODEL),
+        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -29,9 +26,6 @@ def recreate_summary_vector_store(
     """
     重建单个知识库文件摘要
     :param max_tokens:
-    :param endpoint_host:
-    :param endpoint_host_key:
-    :param endpoint_host_proxy:
     :param model_name:
     :param temperature:
     :param file_description:
@@ -54,17 +48,11 @@ def recreate_summary_vector_store(
             kb_summary.create_kb_summary()

             llm = get_ChatOpenAI(
-                endpoint_host=endpoint_host,
-                endpoint_host_key=endpoint_host_key,
-                endpoint_host_proxy=endpoint_host_proxy,
                 model_name=model_name,
                 temperature=temperature,
                 max_tokens=max_tokens,
             )
             reduce_llm = get_ChatOpenAI(
-                endpoint_host=endpoint_host,
-                endpoint_host_key=endpoint_host_key,
-                endpoint_host_proxy=endpoint_host_proxy,
                 model_name=model_name,
                 temperature=temperature,
                 max_tokens=max_tokens,
@@ -110,20 +98,14 @@ def summary_file_to_vector_store(
         file_name: str = Body(..., examples=["test.pdf"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
-        embed_model: str = Body(EMBEDDING_MODEL),
+        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
 ):
     """
     单个知识库根据文件名称摘要
-    :param endpoint_host:
-    :param endpoint_host_key:
-    :param endpoint_host_proxy:
     :param model_name:
     :param max_tokens:
     :param temperature:
@@ -146,17 +128,11 @@ def summary_file_to_vector_store(
         kb_summary.create_kb_summary()

         llm = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
         )
         reduce_llm = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -194,11 +170,8 @@ def summary_doc_ids_to_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         doc_ids: List = Body([], examples=[["uuid"]]),
         vs_type: str = Body(DEFAULT_VS_TYPE),
-        embed_model: str = Body(EMBEDDING_MODEL),
+        embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(None, description="接入点地址"),
-        endpoint_host_key: str = Body(None, description="接入点key"),
-        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -206,9 +179,6 @@ def summary_doc_ids_to_vector_store(
     """
     单个知识库根据doc_ids摘要
     :param knowledge_base_name:
-    :param endpoint_host:
-    :param endpoint_host_key:
-    :param endpoint_host_proxy:
     :param doc_ids:
     :param model_name:
     :param max_tokens:
@@ -223,17 +193,11 @@ def summary_doc_ids_to_vector_store(
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
     else:
         llm = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
         )
         reduce_llm = get_ChatOpenAI(
-            endpoint_host=endpoint_host,
-            endpoint_host_key=endpoint_host_key,
-            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
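Reviewer note: every summary endpoint now calls get_ChatOpenAI without endpoint arguments; base URL, key and proxy are looked up from the platform entry that lists model_name. A hedged sketch of the simplified call (the model name is a placeholder, not from this commit):

llm = get_ChatOpenAI(
    model_name="gpt-3.5-turbo",  # placeholder; any model configured in MODEL_PLATFORMS
    temperature=0.01,
    max_tokens=None,
)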
@@ -1,5 +1,5 @@
 from configs import (
-    EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
+    DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
     CHUNK_SIZE, OVERLAP_SIZE,
     logger, log_verbose
 )
@@ -86,7 +86,7 @@ def folder2db(
         kb_names: List[str],
         mode: Literal["recreate_vs", "update_in_db", "increment"],
         vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
-        embed_model: str = EMBEDDING_MODEL,
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         zh_title_enhance: bool = ZH_TITLE_ENHANCE,
@@ -1,4 +1,5 @@
 import os
+from functools import lru_cache
 from configs import (
     KB_ROOT_PATH,
     CHUNK_SIZE,
@@ -143,6 +144,7 @@ def get_LoaderClass(file_extension):
         if file_extension in extensions:
             return LoaderClass

+
 def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
     '''
     根据loader_name和文件路径或内容返回文档加载器。
@@ -184,6 +186,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
     return loader


+@lru_cache()
 def make_text_splitter(
     splitter_name,
     chunk_size,
server/localai_embeddings.py (new file, 363 lines added):

from __future__ import annotations

import logging
import warnings
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_community.utils.openai import is_openai_v1
from tenacity import (
    AsyncRetrying,
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
    import openai

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.Timeout)
            | retry_if_exception_type(openai.APIError)
            | retry_if_exception_type(openai.APIConnectionError)
            | retry_if_exception_type(openai.RateLimitError)
            | retry_if_exception_type(openai.InternalServerError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
    import openai

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    async_retrying = AsyncRetrying(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.Timeout)
            | retry_if_exception_type(openai.APIError)
            | retry_if_exception_type(openai.APIConnectionError)
            | retry_if_exception_type(openai.RateLimitError)
            | retry_if_exception_type(openai.InternalServerError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

    def wrap(func: Callable) -> Callable:
        async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
            async for _ in async_retrying:
                return await func(*args, **kwargs)
            raise AssertionError("this is unreachable")

        return wrapped_f

    return wrap


# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
def _check_response(response: dict) -> dict:
    if any([len(d.embedding) == 1 for d in response.data]):
        import openai

        raise openai.APIError("LocalAI API returned an empty embedding")
    return response


def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call."""
    retry_decorator = _create_retry_decorator(embeddings)

    @retry_decorator
    def _embed_with_retry(**kwargs: Any) -> Any:
        response = embeddings.client.create(**kwargs)
        return _check_response(response)

    return _embed_with_retry(**kwargs)


async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call."""

    @_async_retry_decorator(embeddings)
    async def _async_embed_with_retry(**kwargs: Any) -> Any:
        response = await embeddings.async_client.acreate(**kwargs)
        return _check_response(response)

    return await _async_embed_with_retry(**kwargs)


class LocalAIEmbeddings(BaseModel, Embeddings):
    """LocalAI embedding models.

    Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
    uses the ``openai`` Python package's ``openai.Embedding`` as its client.
    Thus, you should have the ``openai`` python package installed, and defeat
    the environment variable ``OPENAI_API_KEY`` by setting to a random string.
    You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
    service endpoint.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import LocalAIEmbeddings
            openai = LocalAIEmbeddings(
                openai_api_key="random-string",
                openai_api_base="http://localhost:8080"
            )

    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "text-embedding-ada-002"
    deployment: str = model
    openai_api_version: Optional[str] = None
    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
    # to support explicit proxy for LocalAI
    openai_proxy: Optional[str] = None
    embedding_ctx_length: int = 8191
    """The maximum number of tokens to embed at once."""
    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
    openai_organization: Optional[str] = Field(default=None, alias="organization")
    allowed_special: Union[Literal["all"], Set[str]] = set()
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch"""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout in seconds for the LocalAI request."""
    headers: Any = None
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
            "OPENAI_API_BASE",
            default="",
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )

        default_api_version = ""
        values["openai_api_version"] = get_from_dict_or_env(
            values,
            "openai_api_version",
            "OPENAI_API_VERSION",
            default=default_api_version,
        )
        values["openai_organization"] = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        try:
            import openai

            if is_openai_v1():
                client_params = {
                    "api_key": values["openai_api_key"],
                    "organization": values["openai_organization"],
                    "base_url": values["openai_api_base"],
                    "timeout": values["request_timeout"],
                    "max_retries": values["max_retries"],
                }

                if not values.get("client"):
                    values["client"] = openai.OpenAI(**client_params).embeddings
                if not values.get("async_client"):
                    values["async_client"] = openai.AsyncOpenAI(
                        **client_params
                    ).embeddings
            elif not values.get("client"):
                values["client"] = openai.Embedding
            else:
                pass
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @property
    def _invocation_params(self) -> Dict:
        openai_args = {
            "model": self.model,
            "timeout": self.request_timeout,
            "extra_headers": self.headers,
            **self.model_kwargs,
        }
        if self.openai_proxy:
            import openai

            openai.proxy = {
                "http": self.openai_proxy,
                "https": self.openai_proxy,
            }  # type: ignore[assignment]  # noqa: E501
        return openai_args

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint."""
        # handle large input text
        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return embed_with_retry(
            self,
            input=[text],
            **self._invocation_params,
        ).data[0].embedding

    async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint."""
        # handle large input text
        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return (
            await async_embed_with_retry(
                self,
                input=[text],
                **self._invocation_params,
            )
        ).data[0].embedding

    def embed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to LocalAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size of embeddings. If None, will use the chunk size
                specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        # call _embedding_func for each text
        return [self._embedding_func(text, engine=self.deployment) for text in texts]

    async def aembed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to LocalAI's embedding endpoint async for embedding search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size of embeddings. If None, will use the chunk size
                specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            response = await self._aembedding_func(text, engine=self.deployment)
            embeddings.append(response)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = self._embedding_func(text, engine=self.deployment)
        return embedding

    async def aembed_query(self, text: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint async for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = await self._aembedding_func(text, engine=self.deployment)
        return embedding
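Reviewer note: a quick smoke test for the vendored class might look like the following; the endpoint URL and model name are placeholders, not values from this commit:

from server.localai_embeddings import LocalAIEmbeddings

emb = LocalAIEmbeddings(model="bge-large-zh-v1.5",            # placeholder model name
                        openai_api_key="EMPTY",               # any non-empty string works
                        openai_api_base="http://127.0.0.1:9997/v1")
vec = emb.embed_query("hello world")
print(len(vec))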
@@ -109,14 +109,13 @@ if __name__ == "__main__":
                         RERANKER_MODEL,
                         RERANKER_MAX_LENGTH,
                         MODEL_PATH)
-    from server.utils import embedding_device

     if USE_RERANKER:
         reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
         print("-----------------model path------------------")
         print(reranker_model_path)
         reranker_model = LangchainReranker(top_n=3,
-                                           device=embedding_device(),
+                                           device="cpu",
                                            max_length=RERANKER_MAX_LENGTH,
                                            model_name_or_path=reranker_model_path
                                            )
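Reviewer note: the server/utils.py changes below introduce get_config_models / get_model_info, which read MODEL_PLATFORMS from configs.model_config. A guessed shape for one platform entry that would satisfy those helpers (field names come from the code below; concrete values are placeholders):

MODEL_PLATFORMS = [
    {
        "platform_name": "xinference",             # placeholder
        "platform_type": "xinference",             # placeholder
        "api_base_url": "http://127.0.0.1:9997/v1",
        "api_key": "EMPTY",
        "api_proxy": "",
        "llm_models": ["chatglm3-6b"],              # placeholder model names
        "embed_models": ["bge-large-zh-v1.5"],
    },
]

# get_config_models(model_type="embed") would then map each embed model name to its
# platform info, and get_Embeddings(embed_model=...) picks OpenAIEmbeddings or the
# vendored LocalAIEmbeddings based on platform_type.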
269
server/utils.py
269
server/utils.py
@ -1,32 +1,28 @@
|
|||||||
import pydantic
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
|
|
||||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
|
|
||||||
HTTPX_DEFAULT_TIMEOUT)
|
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain_openai.chat_models import ChatOpenAI
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
from langchain_community.llms import OpenAI
|
from langchain_openai.llms import OpenAI
|
||||||
import httpx
|
import httpx
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
Optional,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
Generator,
|
||||||
Dict,
|
Dict,
|
||||||
|
List,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Union,
|
Union,
|
||||||
Tuple
|
Tuple,
|
||||||
|
Literal,
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
import torch
|
|
||||||
|
|
||||||
|
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL
|
||||||
from server.minx_chat_openai import MinxChatOpenAI
|
from server.minx_chat_openai import MinxChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
@ -44,10 +40,66 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
|||||||
event.set()
|
event.set()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_models(
|
||||||
|
model_name: str = None,
|
||||||
|
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
|
||||||
|
platform_name: str = None,
|
||||||
|
) -> Dict[str, Dict]:
|
||||||
|
'''
|
||||||
|
获取配置的模型列表,返回值为:
|
||||||
|
{model_name: {
|
||||||
|
"platform_name": xx,
|
||||||
|
"platform_type": xx,
|
||||||
|
"model_type": xx,
|
||||||
|
"model_name": xx,
|
||||||
|
"api_base_url": xx,
|
||||||
|
"api_key": xx,
|
||||||
|
"api_proxy": xx,
|
||||||
|
}}
|
||||||
|
'''
|
||||||
|
import importlib
|
||||||
|
from configs import model_config
|
||||||
|
importlib.reload(model_config)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for m in model_config.MODEL_PLATFORMS:
|
||||||
|
if platform_name is not None and platform_name != m.get("platform_name"):
|
||||||
|
continue
|
||||||
|
if model_type is not None and f"{model_type}_models" not in m:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if model_type is None:
|
||||||
|
model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"]
|
||||||
|
else:
|
||||||
|
model_types = [f"{model_type}_models"]
|
||||||
|
|
||||||
|
for m_type in model_types:
|
||||||
|
for m_name in m.get(m_type, []):
|
||||||
|
if model_name is None or model_name == m_name:
|
||||||
|
result[m_name] = {
|
||||||
|
"platform_name": m.get("platform_name"),
|
||||||
|
"platform_type": m.get("platform_type"),
|
||||||
|
"model_type": m_type.split("_")[0],
|
||||||
|
"model_name": m_name,
|
||||||
|
"api_base_url": m.get("api_base_url"),
|
||||||
|
"api_key": m.get("api_key"),
|
||||||
|
"api_proxy": m.get("api_proxy"),
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_info(model_name: str, platform_name: str = None) -> Dict:
|
||||||
|
'''
|
||||||
|
获取配置的模型信息,主要是 api_base_url, api_key
|
||||||
|
'''
|
||||||
|
result = get_config_models(model_name=model_name, platform_name=platform_name)
|
||||||
|
if len(result) > 0:
|
||||||
|
return list(result.values())[0]
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_ChatOpenAI(
|
def get_ChatOpenAI(
|
||||||
endpoint_host: str,
|
|
||||||
endpoint_host_key: str,
|
|
||||||
endpoint_host_proxy: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
@ -56,29 +108,23 @@ def get_ChatOpenAI(
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatOpenAI:
|
) -> ChatOpenAI:
|
||||||
config = get_model_worker_config(model_name)
|
model_info = get_model_info(model_name)
|
||||||
if model_name == "openai-api":
|
|
||||||
model_name = config.get("model_name")
|
|
||||||
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
|
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
|
|
||||||
openai_api_base=endpoint_host if endpoint_host else "None",
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
|
openai_api_key=model_info.get("api_key"),
|
||||||
|
openai_api_base=model_info.get("api_base_url"),
|
||||||
|
openai_proxy=model_info.get("api_proxy"),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_OpenAI(
|
def get_OpenAI(
|
||||||
endpoint_host: str,
|
|
||||||
endpoint_host_key: str,
|
|
||||||
endpoint_host_proxy: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
@ -89,22 +135,40 @@ def get_OpenAI(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> OpenAI:
|
) -> OpenAI:
|
||||||
# TODO: 从API获取模型信息
|
# TODO: 从API获取模型信息
|
||||||
|
model_info = get_model_info(model_name)
|
||||||
model = OpenAI(
|
model = OpenAI(
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
|
|
||||||
openai_api_base=endpoint_host if endpoint_host else "None",
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
|
openai_api_key=model_info.get("api_key"),
|
||||||
|
openai_api_base=model_info.get("api_base_url"),
|
||||||
|
openai_proxy=model_info.get("api_proxy"),
|
||||||
echo=echo,
|
echo=echo,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings:
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
|
||||||
|
|
||||||
|
model_info = get_model_info(model_name=embed_model)
|
||||||
|
params = {
|
||||||
|
"model": embed_model,
|
||||||
|
"base_url": model_info.get("api_base_url"),
|
||||||
|
"api_key": model_info.get("api_key"),
|
||||||
|
"openai_proxy": model_info.get("api_proxy"),
|
||||||
|
}
|
||||||
|
if model_info.get("platform_type") == "openai":
|
||||||
|
return OpenAIEmbeddings(**params)
|
||||||
|
else:
|
||||||
|
return LocalAIEmbeddings(**params)
|
||||||
|
|
||||||
|
|
||||||
class MsgType:
|
class MsgType:
|
||||||
TEXT = 1
|
TEXT = 1
|
||||||
IMAGE = 2
|
IMAGE = 2
|
||||||
@ -113,9 +177,9 @@ class MsgType:
|
|||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
class BaseResponse(BaseModel):
|
||||||
code: int = pydantic.Field(200, description="API status code")
|
code: int = Field(200, description="API status code")
|
||||||
msg: str = pydantic.Field("success", description="API status message")
|
msg: str = Field("success", description="API status message")
|
||||||
data: Any = pydantic.Field(None, description="API data")
|
data: Any = Field(None, description="API data")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -127,7 +191,7 @@ class BaseResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ListResponse(BaseResponse):
|
class ListResponse(BaseResponse):
|
||||||
data: List[str] = pydantic.Field(..., description="List of names")
|
data: List[str] = Field(..., description="List of names")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -140,10 +204,10 @@ class ListResponse(BaseResponse):
|
|||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
question: str = pydantic.Field(..., description="Question text")
|
question: str = Field(..., description="Question text")
|
||||||
response: str = pydantic.Field(..., description="Response text")
|
response: str = Field(..., description="Response text")
|
||||||
history: List[List[str]] = pydantic.Field(..., description="History text")
|
history: List[List[str]] = Field(..., description="History text")
|
||||||
source_documents: List[str] = pydantic.Field(
|
source_documents: List[str] = Field(
|
||||||
..., description="List of source documents and their scores"
|
..., description="List of source documents and their scores"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,39 +374,40 @@ def MakeFastAPIOffline(


# 从model_config中获取模型信息
+# TODO: 移出模型加载后,这些功能需要删除或改变实现

-def list_embed_models() -> List[str]:
-    '''
-    get names of configured embedding models
-    '''
-    return list(MODEL_PATH["embed_model"])
+# def list_embed_models() -> List[str]:
+#     '''
+#     get names of configured embedding models
+#     '''
+#     return list(MODEL_PATH["embed_model"])


-def get_model_path(model_name: str, type: str = None) -> Optional[str]:
-    if type in MODEL_PATH:
-        paths = MODEL_PATH[type]
-    else:
-        paths = {}
-        for v in MODEL_PATH.values():
-            paths.update(v)
+# def get_model_path(model_name: str, type: str = None) -> Optional[str]:
+#     if type in MODEL_PATH:
+#         paths = MODEL_PATH[type]
+#     else:
+#         paths = {}
+#         for v in MODEL_PATH.values():
+#             paths.update(v)

-    if path_str := paths.get(model_name):  # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
-        path = Path(path_str)
-        if path.is_dir():  # 任意绝对路径
-            return str(path)
+#     if path_str := paths.get(model_name):  # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
+#         path = Path(path_str)
+#         if path.is_dir():  # 任意绝对路径
+#             return str(path)

-        root_path = Path(MODEL_ROOT_PATH)
-        if root_path.is_dir():
-            path = root_path / model_name
-            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
-                return str(path)
-            path = root_path / path_str
-            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
-                return str(path)
-            path = root_path / path_str.split("/")[-1]
-            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
-                return str(path)
-    return path_str  # THUDM/chatglm06b
+#         root_path = Path(MODEL_ROOT_PATH)
+#         if root_path.is_dir():
+#             path = root_path / model_name
+#             if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
+#                 return str(path)
+#             path = root_path / path_str
+#             if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+#                 return str(path)
+#             path = root_path / path_str.split("/")[-1]
+#             if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
+#                 return str(path)
+#     return path_str  # THUDM/chatglm06b


def api_address() -> str:
@ -429,37 +494,6 @@ def set_httpx_config(
    urllib.request.getproxies = _get_proxies


-def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-        import intel_extension_for_pytorch as ipex
-        if torch.xpu.get_device_properties(0):
-            return "xpu"
-    except:
-        pass
-    return "cpu"
-
-
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
-    # if device.isdigit():
-    #     return "cuda:" + device
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        device = detect_device()
-    return device
-
-
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or EMBEDDING_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        device = detect_device()
-    return device
-
-
def run_in_thread_pool(
    func: Callable,
    params: List[Dict] = [],
@ -546,56 +580,19 @@ def get_server_configs() -> Dict:
    return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}


-def list_online_embed_models(
-        endpoint_host: str,
-        endpoint_host_key: str,
-        endpoint_host_proxy: str
-) -> List[str]:
-    ret = []
-    # TODO: 从在线API获取支持的模型列表
-    client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
-    try:
-        headers = {
-            "Authorization": f"Bearer {endpoint_host_key}",
-        }
-        resp = client.get("/models", headers=headers)
-        if resp.status_code == 200:
-            models = resp.json().get("data", [])
-            for model in models:
-                if "embedding" in model.get("id", None):
-                    ret.append(model.get("id"))
-
-    except Exception as e:
-        msg = f"获取在线Embeddings模型列表失败:{e}"
-        logger.error(f'{e.__class__.__name__}: {msg}',
-                     exc_info=e if log_verbose else None)
-    finally:
-        client.close()
-    return ret
-
-
-def load_local_embeddings(model: str = None, device: str = embedding_device()):
-    '''
-    从缓存中本地Embeddings模型加载,可以避免多线程时竞争加载。
-    '''
-    from server.knowledge_base.kb_cache.base import embeddings_pool
-    from configs import EMBEDDING_MODEL
-
-    model = model or EMBEDDING_MODEL
-    return embeddings_pool.load_embeddings(model=model, device=device)
-
-
def get_temp_dir(id: str = None) -> Tuple[str, str]:
    '''
    创建一个临时目录,返回(路径,文件夹名称)
    '''
    from configs.basic_config import BASE_TEMP_DIR
-    import tempfile
+    import uuid

    if id is not None:  # 如果指定的临时目录已存在,直接返回
        path = os.path.join(BASE_TEMP_DIR, id)
        if os.path.isdir(path):
            return path, id

-    path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
-    return path, os.path.basename(path)
+    id = uuid.uuid4().hex
+    path = os.path.join(BASE_TEMP_DIR, id)
+    os.mkdir(path)
+    return path, id
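get_temp_dir keeps its signature but now names the directory itself: instead of a tempfile.mkdtemp() path, it creates BASE_TEMP_DIR/<uuid4 hex> and returns (path, id). A rough usage sketch, assuming the function is imported from server.utils:

    from server.utils import get_temp_dir

    # first call: creates BASE_TEMP_DIR/<hex id> and returns its path and id
    path, dir_id = get_temp_dir()

    # later calls with the same id reuse the directory instead of creating a new one
    same_path, same_id = get_temp_dir(dir_id)
    assert same_path == path and same_id == dir_id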
45  startup.py
@ -1,12 +1,11 @@
import asyncio
+from contextlib import asynccontextmanager
import multiprocessing as mp
import os
import subprocess
import sys
from multiprocessing import Process
-from datetime import datetime
-from pprint import pprint
-from langchain_core._api import deprecated

# 设置numexpr最大线程数,默认为CPU核心数
try:
@ -17,38 +16,29 @@ try:
except:
    pass

-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import (
    LOG_PATH,
    log_verbose,
    logger,
-    LLM_MODEL_CONFIG,
-    EMBEDDING_MODEL,
+    DEFAULT_EMBEDDING_MODEL,
    TEXT_SPLITTER_NAME,
    API_SERVER,
    WEBUI_SERVER,
-    HTTPX_DEFAULT_TIMEOUT,
)
-from server.utils import (FastAPI, embedding_device)
+from server.utils import FastAPI
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
from configs import VERSION

-all_model_names = set()
-for model_category in LLM_MODEL_CONFIG.values():
-    for model_name in model_category.keys():
-        if model_name not in all_model_names:
-            all_model_names.add(model_name)
-
-all_model_names_list = list(all_model_names)


def _set_app_event(app: FastAPI, started_event: mp.Event = None):
-    @app.on_event("startup")
-    async def on_startup():
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
        if started_event is not None:
            started_event.set()
+        yield
+    app.router.lifespan_context = lifespan


def run_api_server(started_event: mp.Event = None, run_mode: str = None):
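The startup hook moves from the deprecated @app.on_event("startup") decorator to a lifespan context manager attached via app.router.lifespan_context. A minimal standalone sketch of the same pattern (the print calls are placeholders for the real startup/shutdown work):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    app = FastAPI()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # everything before `yield` runs at startup (the old on_startup body)
        print("startup")
        yield
        # everything after `yield` would run at shutdown
        print("shutdown")

    # attach the lifespan to an already-constructed app, as the hunk above does;
    # for a freshly created app, FastAPI(lifespan=lifespan) is the equivalent form
    app.router.lifespan_context = lifespan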
@ -159,7 +149,7 @@ def dump_server_info(after_start=False, args=None):

    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")

-    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
+    print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}")

    if after_start:
        print("\n")
@ -232,13 +222,13 @@ async def start_main_server():
        return len(processes)

    loom_started = manager.Event()
-    process = Process(
-        target=run_loom,
-        name=f"run_loom Server",
-        kwargs=dict(started_event=loom_started),
-        daemon=True,
-    )
-    processes["run_loom"] = process
+    # process = Process(
+    #     target=run_loom,
+    #     name=f"run_loom Server",
+    #     kwargs=dict(started_event=loom_started),
+    #     daemon=True,
+    # )
+    # processes["run_loom"] = process
    api_started = manager.Event()
    if args.api:
        process = Process(
@ -283,7 +273,6 @@ async def start_main_server():

        # 等待所有进程退出
        if p := processes.get("webui"):
-
            p.join()
    except Exception as e:
        logger.error(e)
@ -306,9 +295,7 @@ async def start_main_server():


if __name__ == "__main__":

    create_tables()

    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
    else:
31  webui.py
@ -1,7 +1,7 @@
import streamlit as st

-from webui_pages.loom_view_client import update_store
-from webui_pages.openai_plugins import openai_plugins_page
+# from webui_pages.loom_view_client import update_store
+# from webui_pages.openai_plugins import openai_plugins_page
from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages.dialogue.dialogue import dialogue_page, chat_box
@ -12,9 +12,9 @@ from configs import VERSION
from server.utils import api_address


-def on_change(key):
-    if key:
-        update_store()
+# def on_change(key):
+#     if key:
+#         update_store()


api = ApiRequest(base_url=api_address())
@ -59,18 +59,18 @@ if __name__ == "__main__":
            "icon": "hdd-stack",
            "func": knowledge_base_page,
        },
-        "模型服务": {
-            "icon": "hdd-stack",
-            "func": openai_plugins_page,
-        },
+        # "模型服务": {
+        #     "icon": "hdd-stack",
+        #     "func": openai_plugins_page,
+        # },
    }
    # 更新状态
-    if "status" not in st.session_state \
-            or "run_plugins_list" not in st.session_state \
-            or "launch_subscribe_info" not in st.session_state \
-            or "list_running_models" not in st.session_state \
-            or "model_config" not in st.session_state:
-        update_store()
+    # if "status" not in st.session_state \
+    #         or "run_plugins_list" not in st.session_state \
+    #         or "launch_subscribe_info" not in st.session_state \
+    #         or "list_running_models" not in st.session_state \
+    #         or "model_config" not in st.session_state:
+    #     update_store()

    with st.sidebar:
        st.image(
@ -95,7 +95,6 @@ if __name__ == "__main__":
            icons=icons,
            # menu_icon="chat-quote",
            default_index=default_index,
-            on_change=on_change,
        )

    if selected_page in pages:
@ -4,8 +4,8 @@ import streamlit as st
from streamlit_antd_components.utils import ParseItems

from webui_pages.dialogue.utils import process_files
-from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
-    get_select_model_endpoint
+# from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
+#     get_select_model_endpoint
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
@ -13,9 +13,9 @@ from datetime import datetime
import os
import re
import time
-from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY)
+from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG)
from server.callback_handler.agent_callback_handler import AgentStatus
-from server.utils import MsgType
+from server.utils import MsgType, get_config_models
import uuid
from typing import List, Dict

@ -111,8 +111,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    st.session_state.setdefault("conversation_ids", {})
    st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
    st.session_state.setdefault("file_chat_id", None)
-    st.session_state.setdefault("select_plugins_info", None)
-    st.session_state.setdefault("select_model_worker", None)

    # 弹出自定义命令帮助信息
    modal = Modal("自定义命令", key="cmd_help", max_width="500")
@ -131,18 +129,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        chat_box.use_chat_name(conversation_name)
        conversation_id = st.session_state["conversation_ids"][conversation_name]

-        with st.expander("模型选择"):
-            plugins_menu = build_providers_model_plugins_name()
-
-            items, _ = ParseItems(plugins_menu).multi()
-
-            if len(plugins_menu) > 0:
-
-                llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
-                plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
-                set_llm_select(plugins_info, llm_model_worker)
-            else:
-                st.info("没有可用的插件")
+        platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
+        platform = st.selectbox("选择模型平台", platforms, 1)
+        llm_models = list(get_config_models(model_type="llm", platform_name=platform))
+        llm_model = st.selectbox("选择LLM模型", llm_models)

        # 传入后端的内容
        chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
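The plugins menu is replaced by two plain selectboxes driven by the static platform config. A condensed sketch of how the selection feeds the request, combining the lines above with the chat_model_config assignment that appears later in this file (MODEL_PLATFORMS entries are assumed to be dicts carrying a "platform_name" key, as the list comprehension above implies):

    platforms = [x["platform_name"] for x in MODEL_PLATFORMS]
    platform = st.selectbox("选择模型平台", platforms, 1)   # index 1: second platform preselected
    llm_models = list(get_config_models(model_type="llm", platform_name=platform))
    llm_model = st.selectbox("选择LLM模型", llm_models)

    # the chosen name is later copied into the per-request model config
    chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
    if llm_model is not None:
        chat_model_config["llm_model"][llm_model] = LLM_MODEL_CONFIG["llm_model"].get(llm_model, {})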
@ -174,10 +164,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            if is_selected:
                selected_tool_configs[tool] = TOOL_CONFIG[tool]

-        llm_model = None
-        if st.session_state["select_model_worker"] is not None:
-            llm_model = st.session_state["select_model_worker"]['label']

        if llm_model is not None:
            chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})

@ -200,23 +186,23 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    chat_box.output_messages()
    chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

-    def on_feedback(
-            feedback,
-            message_id: str = "",
-            history_index: int = -1,
-    ):
+    # def on_feedback(
+    #         feedback,
+    #         message_id: str = "",
+    #         history_index: int = -1,
+    # ):

-        reason = feedback["text"]
-        score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
-        api.chat_feedback(message_id=message_id,
-                          score=score_int,
-                          reason=reason)
-        st.session_state["need_rerun"] = True
+    #     reason = feedback["text"]
+    #     score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
+    #     api.chat_feedback(message_id=message_id,
+    #                       score=score_int,
+    #                       reason=reason)
+    #     st.session_state["need_rerun"] = True

-    feedback_kwargs = {
-        "feedback_type": "thumbs",
-        "optional_text_label": "欢迎反馈您打分的理由",
-    }
+    # feedback_kwargs = {
+    #     "feedback_type": "thumbs",
+    #     "optional_text_label": "欢迎反馈您打分的理由",
+    # }

    if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
        if parse_command(text=prompt, modal=modal):  # 用户输入自定义命令
@ -244,17 +230,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            text_action = ""
            element_index = 0

-            openai_config = {}
-            endpoint_host, select_model_name = get_select_model_endpoint()
-            openai_config["endpoint_host"] = endpoint_host
-            openai_config["model_name"] = select_model_name
-            openai_config["endpoint_host_key"] = OPENAI_KEY
-            openai_config["endpoint_host_proxy"] = OPENAI_PROXY
            for d in api.chat_chat(query=prompt,
                                   metadata=files_upload,
                                   history=history,
                                   chat_model_config=chat_model_config,
-                                   openai_config=openai_config,
                                   conversation_id=conversation_id,
                                   tool_config=selected_tool_configs,
                                   ):
@ -1,8 +1,8 @@
import streamlit as st
from streamlit_antd_components.utils import ParseItems

-from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
-    set_llm_select, set_embed_select, get_select_embed_endpoint
+# from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
+#     set_llm_select, set_embed_select, get_select_embed_endpoint
from webui_pages.utils import *
from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
@ -10,10 +10,9 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
-from configs import (kbs_config,
-                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY)
-from server.utils import list_embed_models
+from configs import (kbs_config, DEFAULT_VS_TYPE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import get_config_models

import streamlit_antd_components as sac
import os
@ -116,25 +115,11 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):

            col1, _ = st.columns([3, 1])
            with col1:
-                col1.text("Embedding 模型")
-                plugins_menu = build_providers_embedding_plugins_name()
-
-                embed_models = list_embed_models()
-                menu_item_children = []
-                for model in embed_models:
-                    menu_item_children.append(sac.MenuItem(model, description=model))
-
-                plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
-
-                items, _ = ParseItems(plugins_menu).multi()
-
-                if len(plugins_menu) > 0:
-
-                    llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
-                    plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
-                    set_embed_select(plugins_info, llm_model_worker)
-                else:
-                    st.info("没有可用的插件")
+                embed_models = list(get_config_models(model_type="embed"))
+                index = 0
+                if DEFAULT_EMBEDDING_MODEL in embed_models:
+                    index = embed_models.index(DEFAULT_EMBEDDING_MODEL)
+                embed_model = st.selectbox("Embeddings模型", embed_models, index)

            submit_create_kb = st.form_submit_button(
                "新建",
@ -143,23 +128,17 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
            )

            if submit_create_kb:
-                endpoint_host, select_embed_model_name = get_select_embed_endpoint()
                if not kb_name or not kb_name.strip():
                    st.error(f"知识库名称不能为空!")
                elif kb_name in kb_list:
                    st.error(f"名为 {kb_name} 的知识库已经存在!")
-                elif select_embed_model_name is None:
+                elif embed_model is None:
                    st.error(f"请选择Embedding模型!")
                else:
                    ret = api.create_knowledge_base(
                        knowledge_base_name=kb_name,
                        vector_store_type=vs_type,
-                        embed_model=select_embed_model_name,
-                        endpoint_host=endpoint_host,
-                        endpoint_host_key=OPENAI_KEY,
-                        endpoint_host_proxy=OPENAI_PROXY,
+                        embed_model=embed_model,
                    )
                    st.toast(ret.get("msg", " "))
                    st.session_state["selected_kb_name"] = kb_name
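Knowledge-base creation no longer collects an endpoint host/key/proxy; the embedding model is just a name that the configured platform resolves. A condensed sketch of the new flow, reusing the names from the hunk above (kb_name and vs_type come from the surrounding form, which is omitted here):

    embed_models = list(get_config_models(model_type="embed"))
    # preselect DEFAULT_EMBEDDING_MODEL when the platform actually offers it
    index = embed_models.index(DEFAULT_EMBEDDING_MODEL) if DEFAULT_EMBEDDING_MODEL in embed_models else 0
    embed_model = st.selectbox("Embeddings模型", embed_models, index)

    ret = api.create_knowledge_base(
        knowledge_base_name=kb_name,
        vector_store_type=vs_type,
        embed_model=embed_model,   # endpoint_host / key / proxy parameters are gone
    )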
@ -169,9 +148,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
    elif selected_kb:
        kb = selected_kb
        st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']
-        st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host']
-        st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key']
-        st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy']
        # 上传文件
        files = st.file_uploader("上传知识文件:",
                                 [i for ls in LOADER_DICT.values() for i in ls],
@ -185,37 +161,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
            st.session_state["selected_kb_info"] = kb_info
            api.update_kb_info(kb, kb_info)

-        if st.session_state["kb_endpoint_host"] is not None:
-            with st.expander(
-                    "在线api接入点配置",
-                    expanded=True,
-            ):
-                endpoint_host = st.text_input(
-                    "接入点地址",
-                    placeholder="接入点地址",
-                    key="endpoint_host",
-                    value=st.session_state["kb_endpoint_host"],
-                )
-                endpoint_host_key = st.text_input(
-                    "接入点key",
-                    placeholder="接入点key",
-                    key="endpoint_host_key",
-                    value=st.session_state["kb_endpoint_host_key"],
-                )
-                endpoint_host_proxy = st.text_input(
-                    "接入点代理地址",
-                    placeholder="接入点代理地址",
-                    key="endpoint_host_proxy",
-                    value=st.session_state["kb_endpoint_host_proxy"],
-                )
-                if endpoint_host != st.session_state["kb_endpoint_host"] \
-                        or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \
-                        or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]:
-                    st.session_state["kb_endpoint_host"] = endpoint_host
-                    st.session_state["kb_endpoint_host_key"] = endpoint_host_key
-                    st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy
-                    api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy)

    # with st.sidebar:
    with st.expander(
        "文件处理配置",
@ -4,7 +4,7 @@
from typing import *
from pathlib import Path
from configs import (
-    EMBEDDING_MODEL,
+    DEFAULT_EMBEDDING_MODEL,
    DEFAULT_VS_TYPE,
    LLM_MODEL_CONFIG,
    SCORE_THRESHOLD,
@ -266,7 +266,6 @@ class ApiRequest:
        history: List[Dict] = [],
        stream: bool = True,
        chat_model_config: Dict = None,
-        openai_config: Dict = None,
        tool_config: Dict = None,
        **kwargs,
    ):
@ -281,7 +280,6 @@ class ApiRequest:
            "history": history,
            "stream": stream,
            "chat_model_config": chat_model_config,
-            "openai_config": openai_config,
            "tool_config": tool_config,
        }

@ -381,10 +379,7 @@ class ApiRequest:
        self,
        knowledge_base_name: str,
        vector_store_type: str = DEFAULT_VS_TYPE,
-        embed_model: str = EMBEDDING_MODEL,
-        endpoint_host: str = None,
-        endpoint_host_key: str = None,
-        endpoint_host_proxy: str = None
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ):
        '''
        对应api.py/knowledge_base/create_knowledge_base接口
@ -393,9 +388,6 @@ class ApiRequest:
            "knowledge_base_name": knowledge_base_name,
            "vector_store_type": vector_store_type,
            "embed_model": embed_model,
-            "endpoint_host": endpoint_host,
-            "endpoint_host_key": endpoint_host_key,
-            "endpoint_host_proxy": endpoint_host_proxy,
        }

        response = self.post(
@ -459,24 +451,6 @@ class ApiRequest:
        )
        return self._get_response_value(response, as_json=True)

-    def update_docs_by_id(
-            self,
-            knowledge_base_name: str,
-            docs: Dict[str, Dict],
-    ) -> bool:
-        '''
-        对应api.py/knowledge_base/update_docs_by_id接口
-        '''
-        data = {
-            "knowledge_base_name": knowledge_base_name,
-            "docs": docs,
-        }
-        response = self.post(
-            "/knowledge_base/update_docs_by_id",
-            json=data
-        )
-        return self._get_response_value(response)

    def upload_kb_docs(
        self,
        files: List[Union[str, Path, bytes]],
@ -562,26 +536,6 @@ class ApiRequest:
        )
        return self._get_response_value(response, as_json=True)

-    def update_kb_endpoint(self,
-                           knowledge_base_name,
-                           endpoint_host: str = None,
-                           endpoint_host_key: str = None,
-                           endpoint_host_proxy: str = None):
-        '''
-        对应api.py/knowledge_base/update_info接口
-        '''
-        data = {
-            "knowledge_base_name": knowledge_base_name,
-            "endpoint_host": endpoint_host,
-            "endpoint_host_key": endpoint_host_key,
-            "endpoint_host_proxy": endpoint_host_proxy,
-        }
-
-        response = self.post(
-            "/knowledge_base/update_kb_endpoint",
-            json=data,
-        )
-        return self._get_response_value(response, as_json=True)
    def update_kb_docs(
        self,
        knowledge_base_name: str,
@ -621,7 +575,7 @@ class ApiRequest:
        knowledge_base_name: str,
        allow_empty_kb: bool = True,
        vs_type: str = DEFAULT_VS_TYPE,
-        embed_model: str = EMBEDDING_MODEL,
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
@ -650,7 +604,7 @@ class ApiRequest:
    def embed_texts(
        self,
        texts: List[str],
-        embed_model: str = EMBEDDING_MODEL,
+        embed_model: str = DEFAULT_EMBEDDING_MODEL,
        to_query: bool = False,
    ) -> List[List[float]]:
        '''
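Net effect on the client wrapper: ApiRequest.chat_chat drops the openai_config argument, and the knowledge-base calls default their embed_model to DEFAULT_EMBEDDING_MODEL with no endpoint fields in the payload. A rough usage sketch against the new signatures (the base URL is illustrative; chat_model_config and conversation_id are the values built in the dialogue page above):

    api = ApiRequest(base_url="http://127.0.0.1:7861")  # illustrative address

    # chat: model routing now travels only inside chat_model_config
    for chunk in api.chat_chat(query="你好",
                               history=[],
                               chat_model_config=chat_model_config,
                               conversation_id=conversation_id,
                               tool_config={}):
        print(chunk)

    # knowledge base: the embedding model is just a name the platform resolves
    api.create_knowledge_base(knowledge_base_name="samples",
                              vector_store_type="faiss",
                              embed_model=DEFAULT_EMBEDDING_MODEL)
    vectors = api.embed_texts(["测试文本"], embed_model=DEFAULT_EMBEDDING_MODEL)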