From 51301dfe6a4b6ac9234d50cc4e41bcd4dacd458b Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 7 Mar 2024 11:58:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=20ES=20=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=20-=20=E5=BC=80=E5=8F=91=E8=80=85=20=20=20=20=20-=20g?= =?UTF-8?q?et=5FOpenAIClient=20=E7=9A=84=20local=5Fwrap=20=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=80=BC=E6=94=B9=E4=B8=BA=20False=EF=BC=8C=E9=81=BF?= =?UTF-8?q?=E5=85=8D=20API=20=E6=9C=8D=E5=8A=A1=E6=9C=AA=E5=90=AF=E5=8A=A8?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E5=85=B6=E5=AE=83=E5=8A=9F=E8=83=BD=E5=8F=97?= =?UTF-8?q?=E9=98=BB=EF=BC=88=E5=A6=82Embeddings=EF=BC=89=20=20=20=20=20-?= =?UTF-8?q?=20=E4=BF=AE=E6=94=B9=20ES=20=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=EF=BC=9A=20=09-=20=E6=A3=80=E7=B4=A2?= =?UTF-8?q?=E7=AD=96=E7=95=A5=E6=94=B9=E4=B8=BA=20ApproxRetrievalStrategy?= =?UTF-8?q?=20=09-=20=E8=AE=BE=E7=BD=AE=20timeout=20=E4=B8=BA=2060?= =?UTF-8?q?=EF=BC=8C=20=E9=81=BF=E5=85=8D=E6=96=87=E6=A1=A3=E8=BF=87?= =?UTF-8?q?=E5=A4=9A=E5=AF=BC=E8=87=B4=20ConnecitonTimeout=20Error=20=20?= =?UTF-8?q?=20=20=20-=20=E4=BF=AE=E6=94=B9=20LocalAIEmbeddings=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=A4=9A=E7=BA=BF=E7=A8=8B=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=20=20embed=5Ftexts=EF=BC=8C=E6=95=88=E6=9E=9C=E4=B8=8D?= =?UTF-8?q?=E6=98=8E=E6=98=BE=EF=BC=8C=E7=93=B6=E9=A2=88=E5=8F=AF=E8=83=BD?= =?UTF-8?q?=E4=B8=BB=E8=A6=81=E5=9C=A8=E6=8F=90=E4=BE=9B=20Embedding=20?= =?UTF-8?q?=E7=9A=84=E6=9C=8D=E5=8A=A1=E5=99=A8=E4=B8=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat/chat.py | 1 + server/chat/completion.py | 3 +- server/chat/file_chat.py | 1 + .../kb_service/es_kb_service.py | 84 +++++-------------- server/knowledge_base/kb_summary_api.py | 6 ++ server/localai_embeddings.py | 12 ++- server/utils.py | 6 +- 7 files changed, 45 insertions(+), 68 deletions(-) diff --git a/server/chat/chat.py b/server/chat/chat.py index 59728b43..80e28f2e 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -35,6 +35,7 @@ def create_models_from_config(configs, callbacks, stream): max_tokens=params.get('max_tokens', 1000), callbacks=callbacks, streaming=stream, + local_wrap=True, ) models[model_type] = model_instance prompt_name = params.get('prompt_name', 'default') diff --git a/server/chat/completion.py b/server/chat/completion.py index 05b45740..02092bc7 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -37,7 +37,8 @@ async def completion(query: str = Body(..., description="用户输入", examples temperature=temperature, max_tokens=max_tokens, callbacks=[callback], - echo=echo + echo=echo, + local_wrap=True, ) prompt_template = get_prompt_template("completion", prompt_name) diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index f9b31714..3d243769 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -124,6 +124,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= temperature=temperature, max_tokens=max_tokens, callbacks=[callback], + local_wrap=True, ) embed_func = get_Embeddings() embeddings = await embed_func.aembed_query(query) diff --git a/server/knowledge_base/kb_service/es_kb_service.py b/server/knowledge_base/kb_service/es_kb_service.py index 327633ef..a00dafc9 100644 --- a/server/knowledge_base/kb_service/es_kb_service.py +++ b/server/knowledge_base/kb_service/es_kb_service.py @@ -2,13 +2,13 @@ from typing import List import os import shutil from langchain.schema import Document -from langchain.vectorstores.elasticsearch import ElasticsearchStore +from langchain_community.vectorstores.elasticsearch import ElasticsearchStore, ApproxRetrievalStrategy from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.utils import KnowledgeFile from server.utils import get_Embeddings -from elasticsearch import Elasticsearch,BadRequestError -from configs import logger -from configs import kbs_config +from elasticsearch import Elasticsearch, BadRequestError +from configs import logger, kbs_config, KB_ROOT_PATH + class ESKBService(KBService): @@ -53,27 +53,24 @@ class ESKBService(KBService): try: # langchain ES 连接、创建索引 - if self.user != "" and self.password != "": - self.db_init = ElasticsearchStore( + params = dict( es_url=f"http://{self.IP}:{self.PORT}", index_name=self.index_name, query_field="context", vector_query_field="dense_vector", embedding=self.embeddings_model, - es_user=self.user, - es_password=self.password + strategy=ApproxRetrievalStrategy(), + es_params={ + "timeout": 60, + } ) - else: - logger.warning("ES未配置用户名和密码") - self.db_init = ElasticsearchStore( - es_url=f"http://{self.IP}:{self.PORT}", - index_name=self.index_name, - query_field="context", - vector_query_field="dense_vector", - embedding=self.embeddings_model, + if self.user != "" and self.password != "": + params.update( + es_user=self.user, + es_password=self.password ) + self.db = ElasticsearchStore(**params) except ConnectionError: - print("### 初始化 Elasticsearch 失败!") logger.error("### 初始化 Elasticsearch 失败!") raise ConnectionError except Exception as e: @@ -81,7 +78,7 @@ class ESKBService(KBService): raise e try: # 尝试通过db_init创建索引 - self.db_init._create_index_if_not_exists( + self.db._create_index_if_not_exists( index_name=self.index_name, dims_length=self.dims_length ) @@ -90,8 +87,6 @@ class ESKBService(KBService): logger.error(e) # raise e - - @staticmethod def get_kb_path(knowledge_base_name: str): return os.path.join(KB_ROOT_PATH, knowledge_base_name) @@ -110,46 +105,9 @@ class ESKBService(KBService): def vs_type(self) -> str: return SupportedVSType.ES - def _load_es(self, docs, embed_model): - # 将docs写入到ES中 - try: - # 连接 + 同时写入文档 - if self.user != "" and self.password != "": - self.db = ElasticsearchStore.from_documents( - documents=docs, - embedding=embed_model, - es_url= f"http://{self.IP}:{self.PORT}", - index_name=self.index_name, - distance_strategy="COSINE", - query_field="context", - vector_query_field="dense_vector", - verify_certs=False, - es_user=self.user, - es_password=self.password - ) - else: - self.db = ElasticsearchStore.from_documents( - documents=docs, - embedding=embed_model, - es_url= f"http://{self.IP}:{self.PORT}", - index_name=self.index_name, - distance_strategy="COSINE", - query_field="context", - vector_query_field="dense_vector", - verify_certs=False) - except ConnectionError as ce: - print(ce) - print("连接到 Elasticsearch 失败!") - logger.error("连接到 Elasticsearch 失败!") - except Exception as e: - logger.error(f"Error 发生 : {e}") - print(e) - - - def do_search(self, query:str, top_k: int, score_threshold: float): # 文本相似性检索 - docs = self.db_init.similarity_search_with_score(query=query, + docs = self.db.similarity_search_with_score(query=query, k=top_k) return docs @@ -200,15 +158,17 @@ class ESKBService(KBService): except Exception as e: logger.error(f"ES Docs Delete Error! {e}") - # self.db_init.delete(ids=delete_list) + # self.db.delete(ids=delete_list) #self.es_client_python.indices.refresh(index=self.index_name) def do_add_doc(self, docs: List[Document], **kwargs): '''向知识库添加文件''' + print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}") print("*"*100) - self._load_es(docs=docs, embed_model=self.embeddings_model) + + self.db.add_documents(documents=docs) # 获取 id 和 source , 格式:[{"id": str, "metadata": dict}, ...] print("写入数据成功.") print("*"*100) @@ -229,8 +189,8 @@ class ESKBService(KBService): search_results = self.es_client_python.search(body=query, size=50) if len(search_results["hits"]["hits"]) == 0: raise ValueError("召回元素个数为0") - info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]] - return info_docs + info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]] + return info_docs def do_clear_vs(self): diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 86263e8e..7ee42956 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -51,11 +51,13 @@ def recreate_summary_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, @@ -131,11 +133,13 @@ def summary_file_to_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, @@ -196,11 +200,13 @@ def summary_doc_ids_to_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, diff --git a/server/localai_embeddings.py b/server/localai_embeddings.py index f52681a9..44939319 100644 --- a/server/localai_embeddings.py +++ b/server/localai_embeddings.py @@ -27,6 +27,8 @@ from tenacity import ( stop_after_attempt, wait_exponential, ) +from server.utils import run_in_thread_pool + logger = logging.getLogger(__name__) @@ -316,8 +318,14 @@ class LocalAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - # call _embedding_func for each text - return [self._embedding_func(text, engine=self.deployment) for text in texts] + # call _embedding_func for each text with multithreads + def task(seq, text): + return (seq, self._embedding_func(text, engine=self.deployment)) + + params = [{"seq": i, "text": text} for i, text in enumerate(texts)] + result = list(run_in_thread_pool(func=task, params=params)) + result = sorted(result, key=lambda x: x[0]) + return [x[1] for x in result] async def aembed_documents( self, texts: List[str], chunk_size: Optional[int] = 0 diff --git a/server/utils.py b/server/utils.py index 5365d5f0..c3541504 100644 --- a/server/utils.py +++ b/server/utils.py @@ -120,7 +120,7 @@ def get_ChatOpenAI( streaming: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) @@ -160,7 +160,7 @@ def get_OpenAI( echo: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 @@ -196,7 +196,7 @@ def get_OpenAI( def get_Embeddings( embed_model: str = DEFAULT_EMBEDDING_MODEL, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api ) -> Embeddings: from langchain_community.embeddings.openai import OpenAIEmbeddings from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154