mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
新功能:
- 支持在线 Embeddings:zhipu-api, qwen-api, minimax-api, qianfan-api
- API 增加 /other/embed_texts 接口
- init_database.py 增加 --embed-model 参数,可以指定使用的嵌入模型(本地或在线均可)
- 对于 FAISS 知识库,支持多向量库,默认位置:{KB_PATH}/vector_store/{embed_model}
- Lite 模式支持所有知识库相关功能。此模式下最主要的限制是:
- 不能使用本地 LLM 和 Embeddings 模型
- 知识库不支持 PDF 文件
- init_database.py 重建知识库时不再默认清空数据库表,增加 --clear-tables 参数手动控制。
- API 和 WEBUI 中 score_threshold 参数范围改为 [0, 2],以更好地适应在线嵌入模型
问题修复:
- API 中 list_config_models 会删除 ONLINE_LLM_MODEL 中的敏感信息,导致第二轮API请求错误
开发者:
- 统一向量库的识别:以(kb_name,embed_model)为判断向量库唯一性的依据,避免 FAISS 知识库缓存加载逻辑错误
- KBServiceFactory.get_service_by_name 中添加 default_embed_model 参数,用于在构建新知识库时设置 embed_model
- 优化 kb_service 中 Embeddings 操作:
- 统一加载接口: server.utils.load_embeddings,利用全局缓存避免各处 Embeddings 传参
- 统一文本嵌入接口:server.knowledge_base.kb_service.base.[embed_texts, embed_documents]
- 重写 normalize 函数,去除对 scikit-learn/scipy 的依赖
161 lines
5.9 KiB
Python
161 lines
5.9 KiB
Python
from langchain.embeddings.base import Embeddings
|
|
import threading
|
|
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
|
logger, log_verbose)
|
|
from server.utils import embedding_device, get_model_path, list_online_embed_models
|
|
from contextlib import contextmanager
|
|
from collections import OrderedDict
|
|
from typing import List, Any, Union, Tuple
|
|
|
|
|
|
class ThreadSafeObject:
    """Wrap an arbitrary object with a reentrant lock and a load-completion event.

    Instances are meant to live inside a `CachePool`: `acquire()` hands out the
    wrapped object under the lock (bumping the entry's recency in the owning
    pool), while `start_loading`/`finish_loading`/`wait_for_loading` let one
    thread build the object while others block until it is ready.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._key = key                     # identity of this entry (e.g. (model, device))
        self._pool = pool                   # owning CachePool, if any
        self._obj = obj                     # the wrapped payload
        self._lock = threading.RLock()      # guards access handed out via acquire()
        self._loaded = threading.Event()    # set once the payload is fully built

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """Identity of this cache entry."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Context manager: yield the wrapped object while holding the lock.

        Also moves this entry to the MRU end of the owning pool and, when
        verbose logging is on, logs entry/exit with `owner` and `msg`.
        """
        if not owner:
            owner = f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                # Touch the entry so LRU eviction sees it as recently used.
                self._pool._cache.move_to_end(self.key)
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        """Mark the payload as not ready; waiters will block."""
        self._loaded.clear()

    def finish_loading(self):
        """Mark the payload as ready; release all waiters."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until another thread calls `finish_loading()`."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped payload."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
|
|
|
|
|
|
class CachePool:
    """LRU-style registry of `ThreadSafeObject` entries keyed by name.

    `cache_num <= 0` means unbounded; otherwise the least-recently-used
    entries are evicted once the pool grows past `cache_num` items.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num        # max entries; <= 0 disables eviction
        self._cache = OrderedDict()        # recency-ordered store (MRU at the end)
        self.atomic = threading.RLock()    # serializes check-then-create in subclasses

    def keys(self) -> List[str]:
        """Return cached keys, least recently used first."""
        return list(self._cache.keys())

    def _check_count(self):
        # Evict LRU entries until the pool is within its configured bound.
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> "ThreadSafeObject":
        """Return the cached entry, blocking until its loading has finished.

        Returns None when the key is absent.
        """
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: "ThreadSafeObject") -> "ThreadSafeObject":
        """Insert (or replace) an entry, evicting LRU items if over capacity."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> "ThreadSafeObject":
        """Remove and return an entry.

        NOTE: with `key=None` this returns the LRU *(key, value) tuple*
        (from `OrderedDict.popitem`), while with an explicit key it returns
        the value alone (or None if absent) — callers must handle both shapes.
        """
        if key is None:
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Acquire the entry's lock context; raises RuntimeError if absent."""
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = None,
            default_embed_model: str = None,
    ) -> "Embeddings":
        """Resolve the Embeddings object configured for a knowledge base.

        Online models get an `EmbeddingsFunAdapter`; local models are served
        from the global `embeddings_pool` cache.

        Fix: the defaults were previously `embedding_device()` /
        `EMBEDDING_MODEL` evaluated at class-definition time, which pinned
        the device at import and made import order fragile. They are now
        resolved at call time (behavior-compatible for all existing callers).
        """
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        if embed_device is None:
            embed_device = embedding_device()
        if default_embed_model is None:
            default_embed_model = EMBEDDING_MODEL

        kb_detail = get_kb_detail(kb_name)
        embed_model = kb_detail.get("embed_model", default_embed_model)

        if embed_model in list_online_embed_models():
            return EmbeddingsFunAdapter(embed_model)
        else:
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
|
|
|
|
|
|
class EmbeddingsPool(CachePool):
    """Cache pool specialised for Embeddings models, keyed by (model, device)."""

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return a cached Embeddings instance, loading it on first request.

        Lock choreography: `self.atomic` serializes the check-then-create
        below; it is released *before* the slow model load (while the per-item
        lock is held) so other models can be requested concurrently, and
        concurrent requests for the same key block in `self.get(key)` on the
        placeholder's `wait_for_loading()` instead of loading a second copy.

        NOTE(review): if loading raises, `finish_loading()` is never called
        and other waiters on this key block forever — worth confirming and
        hardening upstream.
        """
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        device = device or embedding_device()
        key = (model, device)
        if not self.get(key):
            # Register a placeholder first so concurrent callers wait on it.
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # Safe to let other keys proceed while this one loads.
                self.atomic.release()
                if model == "text-embedding-ada-002":  # OpenAI hosted embedding model
                    from langchain.embeddings.openai import OpenAIEmbeddings
                    # NOTE(review): get_model_path is used to carry the API key here
                    # — presumably the "path" entry stores the key; confirm config.
                    embeddings = OpenAIEmbeddings(model_name=model,
                                                  openai_api_key=get_model_path(model),
                                                  chunk_size=CHUNK_SIZE)
                elif 'bge-' in model:
                    from langchain.embeddings import HuggingFaceBgeEmbeddings
                    if 'zh' in model:
                        # Chinese BGE models require this retrieval instruction prefix.
                        query_instruction = "为这个句子生成表示以用于检索相关文章:"
                    elif 'en' in model:
                        # English BGE models use the English instruction prefix.
                        query_instruction = "Represent this sentence for searching relevant passages:"
                    else:
                        # Possibly a reranker or other variant: no instruction prefix.
                        query_instruction = ""
                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                          model_kwargs={'device': device},
                                                          query_instruction=query_instruction)
                    if model == "bge-large-zh-noinstruct":  # no-instruct variant: drop the prefix
                        embeddings.query_instruction = ""
                else:
                    # Fallback: generic HuggingFace sentence-embedding loader.
                    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
                item.obj = embeddings
                item.finish_loading()
        else:
            # Cache hit: nothing to build, just drop the pool lock.
            self.atomic.release()
        return self.get(key).obj
|
|
|
|
|
|
# Process-wide singleton; cache_num=1 keeps only the most recently used
# (model, device) embedding instance resident in memory.
embeddings_pool = EmbeddingsPool(cache_num=1)
|