mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-03 05:03:12 +08:00
34 lines
996 B
Python
34 lines
996 B
Python
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
|
from langchain.vectorstores import VectorStore
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
|
|
class VectorstoreRetrieverService(BaseRetrieverService):
    """Retriever service that answers queries through one vector-store retriever."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the retriever and the result cap on the instance.

        Args:
            retriever: Retriever used by ``get_related_documents``. It may be
                left as ``None`` here, but it must be set before querying.
            top_k: Maximum number of documents ``get_related_documents``
                returns.
        """
        # Not used anywhere in this class; kept so the attribute still exists
        # for any external code that reads it.
        self.vs = None
        self.top_k = top_k
        # Keep the retriever that was passed in; overwriting it with None
        # would make every later query fail.
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        score_threshold: float,
    ):
        """Build a service from a vector store.

        Args:
            vectorstore: Store whose ``as_retriever`` method creates the
                underlying retriever.
            top_k: Passed to the retriever as ``k`` and also used as the
                service's own result cap.
            score_threshold: Passed to the retriever as ``score_threshold``
                for the ``similarity_score_threshold`` search type. Integer
                values are accepted as well.

        Returns:
            A ``VectorstoreRetrieverService`` wrapping the new retriever.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # Forward top_k as well: otherwise the service falls back to its
        # default cap and truncates results when a larger top_k was requested.
        # NOTE(review): assumes the base-class constructor forwards keyword
        # arguments to do_init, as the existing ``retriever=`` keyword
        # already relies on.
        return VectorstoreRetrieverService(retriever=retriever, top_k=top_k)

    def get_related_documents(self, query: str):
        """Return at most ``top_k`` documents the retriever finds for ``query``.

        Args:
            query: Text handed to the retriever.

        Returns:
            The retriever's results, truncated to the first ``top_k`` items.
        """
        return self.retriever.get_relevant_documents(query)[: self.top_k]
|