Langchain-Chatchat/server/embeddings_api.py
liunux4odoo 9c525b7fa5
publish 0.2.10 (#2797)
新功能:
- 优化 PDF 文件的 OCR,过滤无意义的小图片 by @liunux4odoo #2525
- 支持 Gemini 在线模型 by @yhfgyyf #2630
- 支持 GLM4 在线模型 by @zRzRzRzRzRzRzR
- elasticsearch更新https连接 by @xldistance #2390
- 增强对PPT、DOC知识库文件的OCR识别 by @596192804 #2013
- 更新 Agent 对话功能 by @zRzRzRzRzRzRzR
- 每次创建对象时从连接池获取连接,避免每次执行方法时都新建连接 by @Lijia0 #2480
- 实现 ChatOpenAI 判断token有没有超过模型的context上下文长度 by @glide-the
- 更新运行数据库报错和项目里程碑 by @zRzRzRzRzRzRzR #2659
- 更新配置文件/文档/依赖 by @imClumsyPanda @zRzRzRzRzRzRzR
- 添加日文版 readme by @eltociear #2787

修复:
- langchain 更新后,PGVector 向量库连接错误 by @HALIndex #2591
- Minimax's model worker 错误 by @xyhshen 
- ES库无法向量检索.添加mappings创建向量索引 by MSZheng20 #2688
2024-01-26 06:58:49 +08:00

97 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
online_embed_models = list_online_embed_models()
def embed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化。返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=embeddings.embed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
config = get_model_worker_config(embed_model)
worker_class = config.get("worker_class")
embed_model = config.get("embed_model")
worker = worker_class()
if worker_class.can_embedding():
params = ApiEmbeddingsParams(texts=texts, to_query=to_query, embed_model=embed_model)
resp = worker.do_embeddings(params)
return BaseResponse(**resp)
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化。返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
return await run_in_threadpool(embed_texts,
texts=texts,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL,
description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"),
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
'''
对文本进行向量化,返回 BaseResponse(data=List[List[float]])
'''
return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
def embed_documents(
docs: List[Document],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> Dict:
"""
将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
"""
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
if embeddings is not None:
return {
"texts": texts,
"embeddings": embeddings,
"metadatas": metadatas,
}