2024-03-06 13:47:56 +08:00

95 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
def embed_texts(
    texts: List[str],
    endpoint_host: str,
    endpoint_host_key: str,
    endpoint_host_proxy: str,
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> BaseResponse:
    """
    Vectorize a list of texts and wrap the result in a ``BaseResponse``.

    The model is resolved in two steps: first against the local embedding
    models, then against the online models reported for the given endpoint.
    On success ``data`` holds the value returned by ``embed_documents(texts)``.

    Args:
        texts: The texts to embed.
        endpoint_host: Base URL of the online endpoint; the literal string
            ``"None"`` is passed to the client when this is falsy.
        endpoint_host_key: API key for the online endpoint; the literal string
            ``"None"`` is passed to the client when this is falsy.
        endpoint_host_proxy: Proxy for the online endpoint; ``None`` is passed
            to the client when this is falsy.
        embed_model: Name of the embedding model to use.
        to_query: Accepted for interface compatibility; not read by this
            function.

    Returns:
        ``BaseResponse(data=...)`` on success, or ``BaseResponse(code=500,
        msg=...)`` when the model is not supported or any exception is raised.

    TODO: consider adding a cache to reduce token consumption.
    """
    try:
        # Local embedding models take precedence over the online API.
        if embed_model in list_embed_models():
            from server.utils import load_local_embeddings

            local_embedder = load_local_embeddings(model=embed_model)
            local_vectors = local_embedder.embed_documents(texts)
            return BaseResponse(data=local_vectors)

        # Otherwise the model must be offered by the online endpoint.
        online_models = list_online_embed_models(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
        )
        if embed_model not in online_models:
            return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")

        from langchain.embeddings.openai import OpenAIEmbeddings

        remote_embedder = OpenAIEmbeddings(
            model=embed_model,
            openai_api_key=endpoint_host_key or "None",
            openai_api_base=endpoint_host or "None",
            openai_proxy=endpoint_host_proxy or None,
            chunk_size=CHUNK_SIZE,
        )
        remote_vectors = remote_embedder.embed_documents(texts)
        return BaseResponse(data=remote_vectors)
    except Exception as e:
        # Any failure is logged and reported to the caller as a 500 response.
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
    texts: List[str],
    endpoint_host: str,
    endpoint_host_key: str,
    endpoint_host_proxy: str,
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> BaseResponse:
    """
    Asynchronously vectorize a list of texts and wrap the result in a
    ``BaseResponse``.

    Local models are awaited through ``aembed_documents``. Online models are
    handled by running the synchronous ``embed_texts`` in a thread pool.

    Args:
        texts: The texts to embed.
        endpoint_host: Base URL of the online endpoint.
        endpoint_host_key: API key for the online endpoint.
        endpoint_host_proxy: Proxy for the online endpoint.
        embed_model: Name of the embedding model to use.
        to_query: Forwarded unchanged to ``embed_texts`` on the online path.

    Returns:
        ``BaseResponse(data=...)`` on success, or ``BaseResponse(code=500,
        msg=...)`` when the model is not supported or any exception is raised.
    """
    try:
        # Local embedding models take precedence over the online API.
        if embed_model in list_embed_models():
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # Online API: delegate to the synchronous implementation in a worker
        # thread.
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # Neither a local nor an online model: report it explicitly, exactly
        # as embed_texts does, instead of implicitly returning None.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        # Any failure is logged and reported to the caller as a 500 response.
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
    texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
    endpoint_host: str = Body(None, description="接入点地址"),
    endpoint_host_key: str = Body(None, description="接入点key"),
    endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
    embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
    to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    """
    API entry point: vectorize the request-body texts via ``embed_texts``.

    Every body field is passed through unchanged, and the ``BaseResponse``
    produced by ``embed_texts`` is returned as-is.
    """
    # Collect the request fields once, then hand them over by keyword.
    request_args = {
        "texts": texts,
        "endpoint_host": endpoint_host,
        "endpoint_host_key": endpoint_host_key,
        "endpoint_host_proxy": endpoint_host_proxy,
        "embed_model": embed_model,
        "to_query": to_query,
    }
    return embed_texts(**request_args)