mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
新功能:
- 支持在线 Embeddings:zhipu-api, qwen-api, minimax-api, qianfan-api
- API 增加 /other/embed_texts 接口
- init_database.py 增加 --embed-model 参数,可以指定使用的嵌入模型(本地或在线均可)
- 对于 FAISS 知识库,支持多向量库,默认位置:{KB_PATH}/vector_store/{embed_model}
- Lite 模式支持所有知识库相关功能。此模式下最主要的限制是:
  - 不能使用本地 LLM 和 Embeddings 模型
  - 知识库不支持 PDF 文件
- init_database.py 重建知识库时不再默认清空数据库表,增加 --clear-tables 参数手动控制。
- API 和 WEBUI 中 score_threshold 参数范围改为 [0, 2],以更好的适应在线嵌入模型
问题修复:
- API 中 list_config_models 会删除 ONLINE_LLM_MODEL 中的敏感信息,导致第二轮API请求错误
开发者:
- 统一向量库的识别:以(kb_name,embed_model)为判断向量库唯一性的依据,避免 FAISS 知识库缓存加载逻辑错误
- KBServiceFactory.get_service_by_name 中添加 default_embed_model 参数,用于在构建新知识库时设置 embed_model
- 优化 kb_service 中 Embeddings 操作:
  - 统一加载接口:server.utils.load_embeddings,利用全局缓存避免各处 Embeddings 传参
  - 统一文本嵌入接口:server.knowledge_base.kb_service.base.[embed_texts, embed_documents]
- 重写 normalize 函数,去除对 scikit-learn/scipy 的依赖
164 lines
6.0 KiB
Python
164 lines
6.0 KiB
Python
from configs import CACHED_VS_NUM
|
||
from server.knowledge_base.kb_cache.base import *
|
||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||
from server.utils import load_local_embeddings
|
||
from server.knowledge_base.utils import get_vs_path
|
||
from langchain.vectorstores.faiss import FAISS
|
||
from langchain.schema import Document
|
||
import os
|
||
from langchain.schema import Document
|
||
|
||
class ThreadSafeFaiss(ThreadSafeObject):
    """Cache entry wrapping a FAISS vector store.

    ``save`` and ``clear`` run inside ``self.acquire()``; the wrapped store is
    read from ``self._obj``.
    """

    def __repr__(self) -> str:
        return (
            f"<{type(self).__name__}: key: {self.key}, obj: {self._obj}, "
            f"docs_count: {self.docs_count()}>"
        )

    def docs_count(self) -> int:
        """Return the number of documents held in the wrapped store's docstore."""
        docstore = self._obj.docstore
        return len(docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped store to ``path`` via ``save_local``.

        Args:
            path: Target directory.
            create_path: When true, create ``path`` first if it is not
                already a directory.

        Returns:
            Whatever ``save_local`` returns.
        """
        with self.acquire():
            path_missing = not os.path.isdir(path)
            if path_missing and create_path:
                os.makedirs(path)
            result = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return result

    def clear(self):
        """Delete every document from the wrapped store.

        Returns:
            The result of the store's ``delete`` call, or an empty list when
            the store held no documents.
        """
        deleted = []
        with self.acquire():
            store = self._obj
            doc_ids = [*store.docstore._dict]
            if len(doc_ids) > 0:
                deleted = store.delete(doc_ids)
                # Sanity check: nothing should remain after deleting all ids.
                assert len(store.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return deleted
|
||
|
||
|
||
class _FaissPool(CachePool):
    """Shared helpers for cache pools whose entries wrap FAISS vector stores."""

    def new_vector_store(
        self,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> FAISS:
        """Build and return an empty FAISS store for ``embed_model``.

        ``embed_device`` is accepted but not used by this method.
        """
        # TODO: the overall Embeddings loading logic is somewhat messy; clean it up.
        adapter = EmbeddingsFunAdapter(embed_model)
        # Seed the store with one placeholder document, then delete it again
        # so the returned store contains no documents.
        placeholder = Document(page_content="init", metadata={})
        store = FAISS.from_documents([placeholder], adapter, normalize_L2=True)
        store.delete(list(store.docstore._dict.keys()))
        return store

    def save_vector_store(self, kb_name: str, path: str = None):
        """Save the cached store registered under ``kb_name`` to ``path``.

        Returns the result of the entry's ``save``; returns ``None`` when no
        entry is cached under ``kb_name``.
        """
        # NOTE(review): entries are looked up by ``kb_name`` as given; pools
        # that key entries differently must pass that key here - confirm
        # against callers.
        cached = self.get(kb_name)
        if not cached:
            return None
        return cached.save(path)

    def unload_vector_store(self, kb_name: str):
        """Remove the entry cached under ``kb_name`` from the pool, if any."""
        if not self.get(kb_name):
            return
        self.pop(kb_name)
        logger.info(f"成功释放向量库:{kb_name}")
|
||
|
||
|
||
class KBFaissPool(_FaissPool):
    """Pool of on-disk FAISS stores, keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
        self,
        kb_name: str,
        vector_name: str = None,
        create: bool = True,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for a knowledge base, loading it if needed.

        Args:
            kb_name: Knowledge base name.
            vector_name: Name of the vector store inside the knowledge base;
                falls back to ``embed_model`` when empty.
            create: When no ``index.faiss`` exists at the store path, create
                and save an empty store instead of failing.
            embed_model: Embedding model used for a newly created store and
                passed as the default when loading an existing one.
            embed_device: Device forwarded to the embeddings loader.

        Returns:
            The pool entry for ``(kb_name, vector_name)``.

        Raises:
            RuntimeError: If no index exists on disk and ``create`` is false.
        """
        self.atomic.acquire()
        # Tracks whether this call still owns ``self.atomic`` so the
        # ``finally`` below releases it exactly once on every path.
        locked = True
        try:
            vector_name = vector_name or embed_model
            key = (kb_name, vector_name)  # a tuple is nicer than a joined string
            cache = self.get(key)
            if cache is None:
                item = ThreadSafeFaiss(key, pool=self)
                self.set(key, item)
                try:
                    with item.acquire(msg="初始化"):
                        # The placeholder is registered and locked, so the
                        # pool-wide lock can be dropped while loading.
                        self.atomic.release()
                        locked = False
                        logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                        vs_path = get_vs_path(kb_name, vector_name)

                        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                            embeddings = self.load_kb_embeddings(
                                kb_name=kb_name,
                                embed_device=embed_device,
                                default_embed_model=embed_model,
                            )
                            vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                        elif create:
                            # create an empty vector store
                            if not os.path.exists(vs_path):
                                os.makedirs(vs_path)
                            vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                            vector_store.save_local(vs_path)
                        else:
                            raise RuntimeError(f"knowledge base {kb_name} not exist.")
                        item.obj = vector_store
                        item.finish_loading()
                except Exception:
                    # Loading failed: drop the half-initialised placeholder so
                    # a later call retries instead of finding an entry that
                    # never finishes loading.
                    self.pop(key)
                    # NOTE(review): presumably finish_loading() is what wakes
                    # callers already waiting on this entry - confirm against
                    # kb_cache.base.
                    item.finish_loading()
                    raise
        finally:
            if locked:
                self.atomic.release()
        return self.get(key)
|
||
|
||
|
||
class MemoFaissPool(_FaissPool):
    """Pool of FAISS stores that are created empty, keyed by ``kb_name``."""

    def load_vector_store(
        self,
        kb_name: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for ``kb_name``, creating an empty one if absent.

        Args:
            kb_name: Cache key of the store.
            embed_model: Embedding model used when a new store is created.
            embed_device: Device forwarded to ``new_vector_store``.

        Returns:
            The pool entry registered under ``kb_name``.
        """
        self.atomic.acquire()
        # Tracks whether this call still owns ``self.atomic`` so the
        # ``finally`` below releases it exactly once on every path.
        locked = True
        try:
            cache = self.get(kb_name)
            if cache is None:
                item = ThreadSafeFaiss(kb_name, pool=self)
                self.set(kb_name, item)
                try:
                    with item.acquire(msg="初始化"):
                        # The placeholder is registered and locked, so the
                        # pool-wide lock can be dropped while building.
                        self.atomic.release()
                        locked = False
                        logger.info(f"loading vector store in '{kb_name}' to memory.")
                        # create an empty vector store
                        vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                        item.obj = vector_store
                        item.finish_loading()
                except Exception:
                    # Creation failed: drop the half-initialised placeholder
                    # so a later call retries instead of finding an entry
                    # that never finishes loading.
                    self.pop(kb_name)
                    # NOTE(review): presumably finish_loading() is what wakes
                    # callers already waiting on this entry - confirm against
                    # kb_cache.base.
                    item.finish_loading()
                    raise
        finally:
            if locked:
                self.atomic.release()
        return self.get(kb_name)
|
||
|
||
|
||
# Module-level pool instances.
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool()


if __name__ == "__main__":
    # Manual stress demo: several threads randomly add to, search, or clear
    # the same vector store through the pool.
    import random
    import threading
    import time
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        """Run one randomly chosen operation against the "samples" store."""
        # The requested store name is overridden, so every worker contends
        # on the same store.
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)

        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
            if r == 1:  # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2:  # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3:  # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            # KBFaissPool stores entries under (kb_name, vector_name), so a
            # lookup by the bare name finds nothing; go through
            # load_vector_store to resolve the entry.
            kb_faiss_pool.load_vector_store(vs_name).clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(
            target=worker,
            kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
            daemon=True,
        )
        t.start()
        threads.append(t)

    for t in threads:
        t.join()
|