mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-25 16:23:22 +08:00
95 lines · 4.5 KiB · Python
from langchain.docstore.document import Document
|
||
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
|
||
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
|
||
from fastapi import Body
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from typing import Dict, List
|
||
|
||
|
||
def embed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    Vectorize a list of texts.

    Returns BaseResponse(data=List[List[float]]), one embedding vector per
    input text, or BaseResponse(code=500, ...) on failure / unknown model.
    TODO: a caching layer might reduce token consumption.
    '''
    try:
        # A locally-hosted embedding model takes precedence over online APIs.
        if embed_model in list_embed_models():
            from server.utils import load_local_embeddings

            local_embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=local_embeddings.embed_documents(texts))

        # Otherwise fall back to an OpenAI-compatible online endpoint.
        online_models = list_online_embed_models(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
        )
        if embed_model in online_models:
            from langchain.embeddings.openai import OpenAIEmbeddings

            # "None" placeholders keep the OpenAI client constructor happy
            # when no key/base URL was supplied by the caller.
            client = OpenAIEmbeddings(
                model=embed_model,
                openai_api_key=endpoint_host_key or "None",
                openai_api_base=endpoint_host or "None",
                openai_proxy=endpoint_host_proxy or None,
                chunk_size=CHUNK_SIZE,
            )
            return BaseResponse(data=client.embed_documents(texts))

        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||
|
||
|
||
async def aembed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    Asynchronously vectorize a list of texts.

    Returns BaseResponse(data=List[List[float]]), one embedding vector per
    input text, or BaseResponse(code=500, ...) on failure / unknown model.
    '''
    try:
        # A locally-hosted embedding model takes precedence over online APIs.
        if embed_model in list_embed_models():
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # Online API: delegate the synchronous implementation to a worker
        # thread so the event loop is not blocked.
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # Bug fix: previously the function fell through here and implicitly
        # returned None; now it mirrors embed_texts and reports the error.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||
|
||
|
||
def embed_texts_endpoint(
        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
        endpoint_host: str = Body(None, description="接入点地址"),
        endpoint_host_key: str = Body(None, description="接入点key"),
        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
        embed_model: str = Body(EMBEDDING_MODEL, description="使用的嵌入模型"),
        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    '''
    HTTP entry point: vectorize texts via embed_texts.

    Returns BaseResponse(data=List[List[float]]).
    '''
    # Pure delegation — all validation and model dispatch live in embed_texts.
    return embed_texts(
        texts=texts,
        endpoint_host=endpoint_host,
        endpoint_host_key=endpoint_host_key,
        endpoint_host_proxy=endpoint_host_proxy,
        embed_model=embed_model,
        to_query=to_query,
    )
|