diff --git a/README.md b/README.md index bfdee2e9..295dcf37 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) - [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) - [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) +- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct) - [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence) - [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase) - [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual) @@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese) - [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) - [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) +- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings) --- @@ -206,6 +208,7 @@ embedding_model_dict = { "m3e-base": "/Users/xxx/Downloads/m3e-base", } ``` +如果你选择使用OpenAI的Embedding模型,请将模型的```key```写入`embedding_model_dict`中。使用该模型,你需要鞥能够访问OpenAI官的API,或设置代理。 ### 4. 知识库初始化与迁移 diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 0dab00f3..3d9e02f3 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -24,7 +24,9 @@ embedding_model_dict = { "m3e-large": "moka-ai/m3e-large", "bge-small-zh": "BAAI/bge-small-zh", "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh" + "bge-large-zh": "BAAI/bge-large-zh", + "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", + "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY") } # 选用的 Embedding 名称 diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 9fccfa23..b3f5439b 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -13,7 +13,8 @@ from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings from typing import List from langchain.docstore.document import Document from server.utils import torch_gc @@ -21,10 +22,19 @@ from server.utils import torch_gc # make HuggingFaceEmbeddings hashable def _embeddings_hash(self): - return hash(self.model_name) - + if isinstance(self, HuggingFaceEmbeddings): + return hash(self.model_name) + elif isinstance(self, HuggingFaceBgeEmbeddings): + return hash(self.model_name) + elif isinstance(self, OpenAIEmbeddings): + return hash(self.model) HuggingFaceEmbeddings.__hash__ = _embeddings_hash +OpenAIEmbeddings.__hash__ = _embeddings_hash +HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash + +_VECTOR_STORE_TICKS = {} + _VECTOR_STORE_TICKS = {} diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 4f0bad9c..da530495 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,5 +1,7 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.embeddings import HuggingFaceBgeEmbeddings from configs.model_config import ( embedding_model_dict, KB_ROOT_PATH, @@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str): @lru_cache(1) def load_embeddings(model: str, device: str): - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], - model_kwargs={'device': device}) + if model == "text-embedding-ada-002": # openai text-embedding-ada-002 + embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) + elif 'bge-' in model: + embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}, + query_instruction="为这个句子生成表示以用于检索相关文章:") + if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding + embeddings.query_instruction = "" + else: + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) return embeddings + LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', '.rtf', '.txt', '.xml', '.doc', '.docx', '.epub', '.odt', '.pdf',