mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
改变 Embeddings 模型改为使用框架 API,不再手动加载,删除自定义 Embeddings Keyword 代码 修改依赖文件,移除 torch transformers 等重依赖 暂时移出对 loom 的集成 后续: 1、优化目录结构 2、检查合并中有无被覆盖的 0.2.10 内容
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
from langchain.embeddings.base import Embeddings
|
|
from langchain.vectorstores.faiss import FAISS
|
|
import threading
|
|
from configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE,
|
|
logger, log_verbose)
|
|
from contextlib import contextmanager
|
|
from collections import OrderedDict
|
|
from typing import List, Any, Union, Tuple
|
|
|
|
|
|
class ThreadSafeObject:
    """Wrap an object so that access to it is serialized by a re-entrant lock.

    An instance may belong to a ``CachePool`` (passed as ``pool``); in that case
    every :meth:`acquire` marks this entry as most recently used in the pool.

    A ``threading.Event`` tracks whether the wrapped object has finished
    loading. It starts cleared, is cleared again by :meth:`start_loading`, and
    is set by :meth:`finish_loading`; :meth:`wait_for_loading` blocks until it
    is set (``CachePool.get`` calls it before returning an entry).
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        """
        Args:
            key: Identifier of this entry (its key in ``pool``, if any).
            obj: The wrapped object; may also be assigned later via :attr:`obj`.
            pool: Optional owning pool whose LRU order :meth:`acquire` updates.
        """
        self._obj = obj
        self._key = key
        self._pool = pool
        # Re-entrant, so the thread holding it may nest acquire() calls.
        self._lock = threading.RLock()
        # Unset until finish_loading() is called.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """The identifier this object was created with."""
        return self._key

    def _touch_pool(self) -> None:
        """Mark this entry as most recently used in its pool, if it is still there."""
        if self._pool is None:
            return
        try:
            self._pool._cache.move_to_end(self.key)
        except KeyError:
            # The entry was evicted or popped from the pool after the caller
            # obtained a reference to it. There is nothing to reorder, and the
            # wrapped object itself is still usable, so this is not an error.
            pass

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> FAISS:
        """Hold this object's lock for the duration of a ``with`` block.

        Args:
            owner: Label used in the log messages; defaults to
                ``"thread <native id>"`` of the calling thread.
            msg: Extra text appended to the log messages.

        Yields:
            The wrapped object. The annotation names ``FAISS``, but the value is
            whatever is stored in :attr:`obj`.
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Take the lock *before* entering ``try`` so that ``finally`` can never
        # release a lock this thread failed to acquire.
        self._lock.acquire()
        try:
            self._touch_pool()
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        """Mark the object as (re)loading; subsequent waiters will block."""
        self._loaded.clear()

    def finish_loading(self):
        """Mark the object as loaded and wake every thread in wait_for_loading()."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until finish_loading() has been called (no timeout)."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object (read without taking the lock)."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
|
|
|
|
|
|
class CachePool:
    """A least-recently-used ordered pool of cached entries.

    Entries live in an ``OrderedDict``: the front holds the least recently used
    entry, the end the most recently used one. When ``cache_num`` is a positive
    int, :meth:`set` evicts entries from the front until at most ``cache_num``
    remain; any other value means the pool is unbounded.

    ``atomic`` is a re-entrant lock exposed for callers that need to make a
    sequence of pool operations atomic; the methods below do not take it.
    """

    def __init__(self, cache_num: int = -1):
        """
        Args:
            cache_num: Maximum number of entries to keep; ``<= 0`` (the default
                ``-1``) or a non-int disables eviction.
        """
        self._cache_num = cache_num
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[Union[str, Tuple]]:
        """Return a snapshot of the cached keys, least recently used first."""
        return list(self._cache.keys())

    def _check_count(self):
        """Evict least recently used entries while the pool exceeds ``cache_num``."""
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: Union[str, Tuple]) -> ThreadSafeObject:
        """Return the entry stored under ``key``, or ``None`` if there is none.

        If the entry is a ``ThreadSafeObject``, block until it has finished
        loading before returning it. Other stored values are returned as-is
        (``acquire`` explicitly supports such values, so they must not be
        asked to ``wait_for_loading``).
        """
        cache = self._cache.get(key)
        if isinstance(cache, ThreadSafeObject):
            cache.wait_for_loading()
        return cache

    def set(self, key: Union[str, Tuple], obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store ``obj`` under ``key`` (as most recently used), evict if over capacity, return ``obj``."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(
        self, key: Union[str, Tuple] = None
    ) -> Union[ThreadSafeObject, Tuple[Any, ThreadSafeObject], None]:
        """Remove an entry from the pool.

        Args:
            key: Key to remove. If ``None``, the least recently used entry is
                removed instead.

        Returns:
            With ``key``: the removed entry, or ``None`` if it was not cached.
            Without ``key``: the removed ``(key, entry)`` pair, or ``None`` if
            the pool is empty.
        """
        if key is None:
            # popitem() raises KeyError on an empty dict; report "nothing to
            # pop" the same way the keyed branch reports a missing key.
            if not self._cache:
                return None
            return self._cache.popitem(last=False)
        return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Look up ``key`` and return a handle for using it.

        Args:
            key: Key of the entry to use.
            owner: Label forwarded to ``ThreadSafeObject.acquire`` for logging.
            msg: Extra text forwarded to ``ThreadSafeObject.acquire`` for logging.

        Returns:
            For a ``ThreadSafeObject`` entry, the context manager from its
            ``acquire`` (the lock is taken when the ``with`` block is entered);
            for any other stored value, the value itself.

        Raises:
            RuntimeError: If ``key`` is not in the pool.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if isinstance(cache, ThreadSafeObject):
            try:
                self._cache.move_to_end(key)
            except KeyError:
                # Another thread evicted/popped the entry after get() returned
                # it; there is nothing to reorder and the object is still usable.
                pass
            return cache.acquire(owner=owner, msg=msg)
        return cache
|
|
|