Optimize the ES knowledge base

- Developer
  - Change the default value of local_wrap in get_OpenAIClient to False, so that other features (such as Embeddings) are not blocked when the API service has not been started
  - Modify the ES knowledge base service:
    - Switch the retrieval strategy to ApproxRetrievalStrategy
    - Set timeout to 60 to avoid ConnectionTimeout errors when there are too many documents
  - Modify LocalAIEmbeddings to run embed_texts in multiple threads; the effect is not obvious, and the bottleneck is probably mainly the server providing the embeddings
This commit is contained in:
parent c839a1791a
commit 51301dfe6a
@@ -35,6 +35,7 @@ def create_models_from_config(configs, callbacks, stream):
             max_tokens=params.get('max_tokens', 1000),
             callbacks=callbacks,
             streaming=stream,
+            local_wrap=True,
         )
         models[model_type] = model_instance
         prompt_name = params.get('prompt_name', 'default')
@@ -37,7 +37,8 @@ async def completion(query: str = Body(..., description="user input", examples
         temperature=temperature,
         max_tokens=max_tokens,
         callbacks=[callback],
-        echo=echo
+        echo=echo,
+        local_wrap=True,
     )

     prompt_template = get_prompt_template("completion", prompt_name)
@@ -124,6 +124,7 @@ async def file_chat(query: str = Body(..., description="user input", examples=
         temperature=temperature,
         max_tokens=max_tokens,
         callbacks=[callback],
+        local_wrap=True,
     )
     embed_func = get_Embeddings()
     embeddings = await embed_func.aembed_query(query)
@@ -2,13 +2,13 @@ from typing import List
 import os
 import shutil
 from langchain.schema import Document
-from langchain.vectorstores.elasticsearch import ElasticsearchStore
+from langchain_community.vectorstores.elasticsearch import ElasticsearchStore, ApproxRetrievalStrategy
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
 from server.utils import get_Embeddings
-from elasticsearch import Elasticsearch,BadRequestError
-from configs import logger
-from configs import kbs_config
+from elasticsearch import Elasticsearch, BadRequestError
+from configs import logger, kbs_config, KB_ROOT_PATH


 class ESKBService(KBService):

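The store import moves from langchain to langchain_community here. On a codebase that must tolerate both package layouts, a guarded import along these lines (an illustration, not part of this commit) would keep older installs working:

    try:
        from langchain_community.vectorstores.elasticsearch import (
            ApproxRetrievalStrategy,
            ElasticsearchStore,
        )
    except ImportError:
        # Older langchain layout; ApproxRetrievalStrategy may not be available there.
        from langchain.vectorstores.elasticsearch import ElasticsearchStore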
@@ -53,27 +53,24 @@ class ESKBService(KBService):

         try:
             # langchain ES connection, index creation
-            if self.user != "" and self.password != "":
-                self.db_init = ElasticsearchStore(
+            params = dict(
                 es_url=f"http://{self.IP}:{self.PORT}",
                 index_name=self.index_name,
                 query_field="context",
                 vector_query_field="dense_vector",
                 embedding=self.embeddings_model,
-                es_user=self.user,
-                es_password=self.password
+                strategy=ApproxRetrievalStrategy(),
+                es_params={
+                    "timeout": 60,
+                }
             )
-            else:
-                logger.warning("ES is not configured with a username and password")
-                self.db_init = ElasticsearchStore(
-                    es_url=f"http://{self.IP}:{self.PORT}",
-                    index_name=self.index_name,
-                    query_field="context",
-                    vector_query_field="dense_vector",
-                    embedding=self.embeddings_model,
-                )
+            if self.user != "" and self.password != "":
+                params.update(
+                    es_user=self.user,
+                    es_password=self.password
+                )
+            self.db = ElasticsearchStore(**params)
         except ConnectionError:
-            print("### Failed to initialize Elasticsearch!")
             logger.error("### Failed to initialize Elasticsearch!")
             raise ConnectionError
         except Exception as e:
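For reference, a hedged standalone sketch of the new connection setup (the URL, index and field names are placeholders, and es_user/es_password stand for optional credentials): ApproxRetrievalStrategy switches similarity search to Elasticsearch's approximate kNN, and es_params is forwarded to the underlying Elasticsearch client, so the 60-second timeout addresses the ConnectionTimeout errors mentioned in the commit message.

    from langchain_community.vectorstores.elasticsearch import (
        ApproxRetrievalStrategy,
        ElasticsearchStore,
    )
    from server.utils import get_Embeddings

    # Placeholders: URL, index and field names are illustrative, not this repo's config.
    params = dict(
        es_url="http://127.0.0.1:9200",
        index_name="my_kb_index",
        query_field="context",
        vector_query_field="dense_vector",
        embedding=get_Embeddings(),          # any langchain Embeddings instance works
        strategy=ApproxRetrievalStrategy(),  # approximate kNN instead of exact scoring
        es_params={"timeout": 60},           # passed through to the Elasticsearch client
    )
    es_user, es_password = "", ""            # fill in if the cluster requires auth
    if es_user and es_password:
        params.update(es_user=es_user, es_password=es_password)
    db = ElasticsearchStore(**params)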
@@ -81,7 +78,7 @@ class ESKBService(KBService):
             raise e
         try:
             # Try to create the index via db_init
-            self.db_init._create_index_if_not_exists(
+            self.db._create_index_if_not_exists(
                 index_name=self.index_name,
                 dims_length=self.dims_length
             )
@@ -90,8 +87,6 @@ class ESKBService(KBService):
             logger.error(e)
             # raise e
-
-

     @staticmethod
     def get_kb_path(knowledge_base_name: str):
         return os.path.join(KB_ROOT_PATH, knowledge_base_name)
@@ -110,46 +105,9 @@ class ESKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ES

-    def _load_es(self, docs, embed_model):
-        # Write the docs into ES
-        try:
-            # Connect and write the documents at the same time
-            if self.user != "" and self.password != "":
-                self.db = ElasticsearchStore.from_documents(
-                        documents=docs,
-                        embedding=embed_model,
-                        es_url= f"http://{self.IP}:{self.PORT}",
-                        index_name=self.index_name,
-                        distance_strategy="COSINE",
-                        query_field="context",
-                        vector_query_field="dense_vector",
-                        verify_certs=False,
-                        es_user=self.user,
-                        es_password=self.password
-                )
-            else:
-                self.db = ElasticsearchStore.from_documents(
-                        documents=docs,
-                        embedding=embed_model,
-                        es_url= f"http://{self.IP}:{self.PORT}",
-                        index_name=self.index_name,
-                        distance_strategy="COSINE",
-                        query_field="context",
-                        vector_query_field="dense_vector",
-                        verify_certs=False)
-        except ConnectionError as ce:
-            print(ce)
-            print("Failed to connect to Elasticsearch!")
-            logger.error("Failed to connect to Elasticsearch!")
-        except Exception as e:
-            logger.error(f"Error occurred: {e}")
-            print(e)
-
-
-
     def do_search(self, query:str, top_k: int, score_threshold: float):
         # Text similarity search
-        docs = self.db_init.similarity_search_with_score(query=query,
+        docs = self.db.similarity_search_with_score(query=query,
                                                     k=top_k)
         return docs

@@ -200,15 +158,17 @@ class ESKBService(KBService):
         except Exception as e:
             logger.error(f"ES Docs Delete Error! {e}")

-        # self.db_init.delete(ids=delete_list)
+        # self.db.delete(ids=delete_list)
         #self.es_client_python.indices.refresh(index=self.index_name)


     def do_add_doc(self, docs: List[Document], **kwargs):
         '''Add files to the knowledge base'''

         print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc: length of the input docs is {len(docs)}")
         print("*"*100)
-        self._load_es(docs=docs, embed_model=self.embeddings_model)
+        self.db.add_documents(documents=docs)
         # Get id and source, format: [{"id": str, "metadata": dict}, ...]
         print("Data written successfully.")
         print("*"*100)
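Since the store is now built once in __init__, do_add_doc no longer reconnects through ElasticsearchStore.from_documents on every call. A minimal sketch of the new ingestion path, reusing the db built in the sketch above:

    from langchain.schema import Document

    # `db` is the ElasticsearchStore constructed in __init__ (see the earlier sketch).
    docs = [Document(page_content="hello world", metadata={"source": "demo.txt"})]
    ids = db.add_documents(documents=docs)  # embeds via the store's configured embedding
    print(f"wrote {len(ids)} documents")    # add_documents returns the new document ids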
@@ -229,8 +189,8 @@ class ESKBService(KBService):
         search_results = self.es_client_python.search(body=query, size=50)
         if len(search_results["hits"]["hits"]) == 0:
             raise ValueError("The number of recalled elements is 0")
         info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]]
         return info_docs


     def do_clear_vs(self):
@@ -51,11 +51,13 @@ def recreate_summary_vector_store(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     reduce_llm = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     # Text summarization adapter
     summary = SummaryAdapter.form_summary(llm=llm,
@@ -131,11 +133,13 @@ def summary_file_to_vector_store(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     reduce_llm = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     # Text summarization adapter
     summary = SummaryAdapter.form_summary(llm=llm,
@@ -196,11 +200,13 @@ def summary_doc_ids_to_vector_store(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     reduce_llm = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        local_wrap=True,
     )
     # Text summarization adapter
     summary = SummaryAdapter.form_summary(llm=llm,
@@ -27,6 +27,8 @@ from tenacity import (
     stop_after_attempt,
     wait_exponential,
 )
+from server.utils import run_in_thread_pool
+

 logger = logging.getLogger(__name__)

@@ -316,8 +318,14 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        # call _embedding_func for each text
-        return [self._embedding_func(text, engine=self.deployment) for text in texts]
+        # call _embedding_func for each text in multiple threads
+        def task(seq, text):
+            return (seq, self._embedding_func(text, engine=self.deployment))
+
+        params = [{"seq": i, "text": text} for i, text in enumerate(texts)]
+        result = list(run_in_thread_pool(func=task, params=params))
+        result = sorted(result, key=lambda x: x[0])
+        return [x[1] for x in result]

     async def aembed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
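run_in_thread_pool is this project's helper in server/utils; it yields results as workers finish, in completion order, which is why each task carries its seq index and the list is re-sorted afterwards. A self-contained sketch of the same pattern using only the standard library (embed_one stands in for self._embedding_func and is assumed thread-safe):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def embed_texts_threaded(texts, embed_one, max_workers=8):
        # embed_one: thread-safe callable mapping one text to one vector.
        def task(seq, text):
            return seq, embed_one(text)

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = [pool.submit(task, i, t) for i, t in enumerate(texts)]
            # Results arrive in completion order, so each one carries its
            # sequence number and the list is re-sorted to restore input order.
            results = [f.result() for f in as_completed(futures)]
        results.sort(key=lambda pair: pair[0])
        return [vec for _, vec in results]

    # Usage with a dummy "embedding" function:
    print(embed_texts_threaded(["a", "bb", "ccc"], lambda t: [float(len(t))]))

As the commit message notes, the gain was modest: the bottleneck appears to be the embedding server itself, so the threads mostly hide per-request latency rather than add throughput.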
@@ -120,7 +120,7 @@ def get_ChatOpenAI(
     streaming: bool = True,
     callbacks: List[Callable] = [],
     verbose: bool = True,
-    local_wrap: bool = True,  # use local wrapped api
+    local_wrap: bool = False,  # use local wrapped api
     **kwargs: Any,
 ) -> ChatOpenAI:
     model_info = get_model_info(model_name)
@@ -160,7 +160,7 @@ def get_OpenAI(
     echo: bool = True,
     callbacks: List[Callable] = [],
     verbose: bool = True,
-    local_wrap: bool = True,  # use local wrapped api
+    local_wrap: bool = False,  # use local wrapped api
     **kwargs: Any,
 ) -> OpenAI:
     # TODO: get model info from the API
@@ -196,7 +196,7 @@ def get_OpenAI(

 def get_Embeddings(
     embed_model: str = DEFAULT_EMBEDDING_MODEL,
-    local_wrap: bool = True,  # use local wrapped api
+    local_wrap: bool = False,  # use local wrapped api
 ) -> Embeddings:
     from langchain_community.embeddings.openai import OpenAIEmbeddings
     from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154
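With the default flipped to False, these helpers no longer assume the locally wrapped API service is running; call sites that want it (as in the hunks above) opt in explicitly. A minimal sketch, assuming the server.utils helpers from this diff and a placeholder model name:

    from server.utils import get_ChatOpenAI, get_Embeddings

    # Opt in to the local wrapped API only where the API server is known to be up.
    llm = get_ChatOpenAI(model_name="my-model", temperature=0.7, local_wrap=True)

    # Default local_wrap=False: per the commit message, Embeddings keep working
    # even when the local API service has not been started.
    embeddings = get_Embeddings()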