diff --git a/server/chat/chat.py b/server/chat/chat.py
index 59728b43..80e28f2e 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -35,6 +35,7 @@ def create_models_from_config(configs, callbacks, stream):
             max_tokens=params.get('max_tokens', 1000),
             callbacks=callbacks,
             streaming=stream,
+            local_wrap=True,
         )
         models[model_type] = model_instance
         prompt_name = params.get('prompt_name', 'default')
diff --git a/server/chat/completion.py b/server/chat/completion.py
index 05b45740..02092bc7 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -37,7 +37,8 @@ async def completion(query: str = Body(..., description="user input", examples
         temperature=temperature,
         max_tokens=max_tokens,
         callbacks=[callback],
-        echo=echo
+        echo=echo,
+        local_wrap=True,
     )

     prompt_template = get_prompt_template("completion", prompt_name)
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index f9b31714..3d243769 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -124,6 +124,7 @@ async def file_chat(query: str = Body(..., description="user input", examples=
             temperature=temperature,
             max_tokens=max_tokens,
             callbacks=[callback],
+            local_wrap=True,
         )
         embed_func = get_Embeddings()
         embeddings = await embed_func.aembed_query(query)
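All three call sites above now opt in to the locally wrapped OpenAI-compatible API explicitly, since the final hunks of this patch flip the `local_wrap` default in `server/utils.py` to `False`. A minimal sketch of what a caller looks like after this change (the model name is an assumption, purely for illustration):

```python
# Hypothetical caller after this patch: local_wrap must now be passed
# explicitly; the model name here is illustrative, not from the patch.
from server.utils import get_ChatOpenAI

model = get_ChatOpenAI(
    model_name="chatglm3-6b",  # assumed model name
    temperature=0.7,
    max_tokens=1000,
    local_wrap=True,           # opt in to the locally wrapped API
)
```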
diff --git a/server/knowledge_base/kb_service/es_kb_service.py b/server/knowledge_base/kb_service/es_kb_service.py
index 327633ef..a00dafc9 100644
--- a/server/knowledge_base/kb_service/es_kb_service.py
+++ b/server/knowledge_base/kb_service/es_kb_service.py
@@ -2,13 +2,13 @@ from typing import List
 import os
 import shutil
 from langchain.schema import Document
-from langchain.vectorstores.elasticsearch import ElasticsearchStore
+from langchain_community.vectorstores.elasticsearch import ElasticsearchStore, ApproxRetrievalStrategy
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
 from server.utils import get_Embeddings
-from elasticsearch import Elasticsearch,BadRequestError
-from configs import logger
-from configs import kbs_config
+from elasticsearch import Elasticsearch, BadRequestError
+from configs import logger, kbs_config, KB_ROOT_PATH
+

 class ESKBService(KBService):
@@ -53,27 +53,24 @@
         try:
             # langchain ES: connect and create the index
-            if self.user != "" and self.password != "":
-                self.db_init = ElasticsearchStore(
+            params = dict(
                 es_url=f"http://{self.IP}:{self.PORT}",
                 index_name=self.index_name,
                 query_field="context",
                 vector_query_field="dense_vector",
                 embedding=self.embeddings_model,
-                es_user=self.user,
-                es_password=self.password
+                strategy=ApproxRetrievalStrategy(),
+                es_params={
+                    "timeout": 60,
+                }
             )
-            else:
-                logger.warning("ES username and password are not configured")
-                self.db_init = ElasticsearchStore(
-                    es_url=f"http://{self.IP}:{self.PORT}",
-                    index_name=self.index_name,
-                    query_field="context",
-                    vector_query_field="dense_vector",
-                    embedding=self.embeddings_model,
+            if self.user != "" and self.password != "":
+                params.update(
+                    es_user=self.user,
+                    es_password=self.password
                 )
+            self.db = ElasticsearchStore(**params)
         except ConnectionError:
-            print("### Failed to initialize Elasticsearch!")
             logger.error("### Failed to initialize Elasticsearch!")
             raise ConnectionError
         except Exception as e:
@@ -81,7 +78,7 @@
             raise e
         try:
             # try to create the index via db_init
-            self.db_init._create_index_if_not_exists(
+            self.db._create_index_if_not_exists(
                 index_name=self.index_name,
                 dims_length=self.dims_length
             )
@@ -90,8 +87,6 @@
             logger.error(e)
             # raise e

-
-
     @staticmethod
     def get_kb_path(knowledge_base_name: str):
         return os.path.join(KB_ROOT_PATH, knowledge_base_name)
@@ -110,46 +105,9 @@
     def vs_type(self) -> str:
         return SupportedVSType.ES

-    def _load_es(self, docs, embed_model):
-        # write docs into ES
-        try:
-            # connect and write the documents in one step
-            if self.user != "" and self.password != "":
-                self.db = ElasticsearchStore.from_documents(
-                    documents=docs,
-                    embedding=embed_model,
-                    es_url=f"http://{self.IP}:{self.PORT}",
-                    index_name=self.index_name,
-                    distance_strategy="COSINE",
-                    query_field="context",
-                    vector_query_field="dense_vector",
-                    verify_certs=False,
-                    es_user=self.user,
-                    es_password=self.password
-                )
-            else:
-                self.db = ElasticsearchStore.from_documents(
-                    documents=docs,
-                    embedding=embed_model,
-                    es_url=f"http://{self.IP}:{self.PORT}",
-                    index_name=self.index_name,
-                    distance_strategy="COSINE",
-                    query_field="context",
-                    vector_query_field="dense_vector",
-                    verify_certs=False)
-        except ConnectionError as ce:
-            print(ce)
-            print("Failed to connect to Elasticsearch!")
-            logger.error("Failed to connect to Elasticsearch!")
-        except Exception as e:
-            logger.error(f"Error occurred: {e}")
-            print(e)
-
-
     def do_search(self, query: str, top_k: int, score_threshold: float):
         # text similarity search
-        docs = self.db_init.similarity_search_with_score(query=query,
+        docs = self.db.similarity_search_with_score(query=query,
                                                          k=top_k)
         return docs

@@ -200,15 +158,17 @@
         except Exception as e:
             logger.error(f"ES Docs Delete Error! {e}")

-        # self.db_init.delete(ids=delete_list)
+        # self.db.delete(ids=delete_list)
         #self.es_client_python.indices.refresh(index=self.index_name)

     def do_add_doc(self, docs: List[Document], **kwargs):
         '''Add files to the knowledge base'''
+        print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc: length of the input docs is {len(docs)}")
         print("*"*100)
-        self._load_es(docs=docs, embed_model=self.embeddings_model)
+
+        self.db.add_documents(documents=docs)

         # get id and source; format: [{"id": str, "metadata": dict}, ...]
         print("Data written successfully.")
         print("*"*100)
@@ -229,8 +189,8 @@
         search_results = self.es_client_python.search(body=query, size=50)
         if len(search_results["hits"]["hits"]) == 0:
             raise ValueError("Number of recalled elements is 0")
-         info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]]
-         return info_docs
+        info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]]
+        return info_docs

     def do_clear_vs(self):
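The Elasticsearch refactor above collapses the two duplicated `ElasticsearchStore` constructions into a single `params` dict, pins `ApproxRetrievalStrategy` and a 60-second timeout explicitly, and drops `_load_es`, so `do_add_doc` reuses the one connection via `add_documents` instead of reconnecting through `ElasticsearchStore.from_documents` on every write. A minimal, self-contained sketch of the consolidated flow, assuming a placeholder host, index, and credentials and a stand-in embedder:

```python
# Sketch only: placeholder host/index/credentials; FakeEmbeddings stands in
# for the real embedding model. Requires a reachable Elasticsearch server.
from langchain.schema import Document
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.elasticsearch import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
)

user, password = "elastic", ""         # placeholder credentials

params = dict(
    es_url="http://127.0.0.1:9200",    # placeholder address
    index_name="test_index",           # placeholder index name
    query_field="context",
    vector_query_field="dense_vector",
    embedding=FakeEmbeddings(size=768),
    strategy=ApproxRetrievalStrategy(),  # the ANN strategy, now explicit
    es_params={"timeout": 60},
)
if user != "" and password != "":      # credentials stay optional
    params.update(es_user=user, es_password=password)

db = ElasticsearchStore(**params)
db.add_documents(documents=[Document(page_content="hello")])  # reuses the connection
```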
print("写入数据成功.") print("*"*100) @@ -229,8 +189,8 @@ class ESKBService(KBService): search_results = self.es_client_python.search(body=query, size=50) if len(search_results["hits"]["hits"]) == 0: raise ValueError("召回元素个数为0") - info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]] - return info_docs + info_docs = [{"id":hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]] + return info_docs def do_clear_vs(self): diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 86263e8e..7ee42956 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -51,11 +51,13 @@ def recreate_summary_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, @@ -131,11 +133,13 @@ def summary_file_to_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, @@ -196,11 +200,13 @@ def summary_doc_ids_to_vector_store( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) reduce_llm = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, + local_wrap=True, ) # 文本摘要适配器 summary = SummaryAdapter.form_summary(llm=llm, diff --git a/server/localai_embeddings.py b/server/localai_embeddings.py index f52681a9..44939319 100644 --- a/server/localai_embeddings.py +++ b/server/localai_embeddings.py @@ -27,6 +27,8 @@ from tenacity import ( stop_after_attempt, wait_exponential, ) +from server.utils import run_in_thread_pool + logger = logging.getLogger(__name__) @@ -316,8 +318,14 @@ class LocalAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. 
""" - # call _embedding_func for each text - return [self._embedding_func(text, engine=self.deployment) for text in texts] + # call _embedding_func for each text with multithreads + def task(seq, text): + return (seq, self._embedding_func(text, engine=self.deployment)) + + params = [{"seq": i, "text": text} for i, text in enumerate(texts)] + result = list(run_in_thread_pool(func=task, params=params)) + result = sorted(result, key=lambda x: x[0]) + return [x[1] for x in result] async def aembed_documents( self, texts: List[str], chunk_size: Optional[int] = 0 diff --git a/server/utils.py b/server/utils.py index 5365d5f0..c3541504 100644 --- a/server/utils.py +++ b/server/utils.py @@ -120,7 +120,7 @@ def get_ChatOpenAI( streaming: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) @@ -160,7 +160,7 @@ def get_OpenAI( echo: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 @@ -196,7 +196,7 @@ def get_OpenAI( def get_Embeddings( embed_model: str = DEFAULT_EMBEDDING_MODEL, - local_wrap: bool = True, # use local wrapped api + local_wrap: bool = False, # use local wrapped api ) -> Embeddings: from langchain_community.embeddings.openai import OpenAIEmbeddings from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154