Fix the error when fetching reranker_model_path (#3458)

saliven1970 2024-04-15 21:53:06 +08:00 committed by GitHub
parent 341ee9db44
commit f29ab1e67f

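The fix in brief: the knowledge-base chat endpoint resolved the reranker path by indexing MODEL_PATH["reranker"] directly, a lookup that raises a KeyError whenever that section is missing from the config. The diff below delegates the lookup to the get_model_path helper from server.utils, and along the way consolidates the json/asyncio imports and trims debug prints. As context, here is a minimal sketch, assuming a sectioned MODEL_PATH layout, of how a helper of this shape resolves a model name; both the layout and the helper body are illustrative, not code copied from the repository:

    # A minimal sketch, assuming a sectioned MODEL_PATH config of this shape;
    # the helper body is hypothetical, not the repository's implementation.
    from typing import Optional

    MODEL_PATH = {
        "embed_model": {"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5"},
        "reranker": {"bge-reranker-large": "BAAI/bge-reranker-large"},
    }

    def get_model_path(model_name: str, model_type: Optional[str] = None) -> Optional[str]:
        # Search one section when the caller names it, otherwise scan them all,
        # so call sites stop hard-coding MODEL_PATH["reranker"].
        sections = [MODEL_PATH[model_type]] if model_type in MODEL_PATH else MODEL_PATH.values()
        for paths in sections:
            if model_name in paths:
                return paths[model_name]
        return None

    print(get_model_path("bge-reranker-large"))  # BAAI/bge-reranker-large

With the helper in charge, the call site no longer needs to know which config section a model lives in, which is what the one-line replacement in the second hunk buys.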

@@ -7,18 +7,16 @@ from configs import (LLM_MODELS,
                      TEMPERATURE,
                      USE_RERANKER,
                      RERANKER_MODEL,
-                     RERANKER_MAX_LENGTH,
-                     MODEL_PATH)
-from server.utils import wrap_done, get_ChatOpenAI
+                     RERANKER_MAX_LENGTH)
+from server.utils import wrap_done, get_ChatOpenAI, get_model_path
 from server.utils import BaseResponse, get_prompt_template
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
-import asyncio
+import asyncio, json
 from langchain.prompts.chat import ChatPromptTemplate
 from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBServiceFactory
-import json
 from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
 from server.reranker.reranker import LangchainReranker
@@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         # 加入reranker
         if USE_RERANKER:
-            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
-            print("-----------------model path------------------")
-            print(reranker_model_path)
+            reranker_model_path = get_model_path(RERANKER_MODEL)
             reranker_model = LangchainReranker(top_n=top_k,
                                                device=embedding_device(),
                                                max_length=RERANKER_MAX_LENGTH,
                                                model_name_or_path=reranker_model_path
                                                )
             print("-------------before rerank-----------------")
             print(docs)
             docs = reranker_model.compress_documents(documents=docs,
                                                      query=query)
-            print("---------after rerank------------------")
+            print("------------after rerank------------------")
             print(docs)
         context = "\n".join([doc.page_content for doc in docs])
@@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         await task

     return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
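
For readers skimming the second hunk: the rerank step itself is unchanged. compress_documents takes the retrieved documents plus the query and returns only the top-ranked ones, which the endpoint then joins into the prompt context. Below is a toy stand-in for that contract, with a term-overlap scorer as a hypothetical substitute for the cross-encoder that the real LangchainReranker loads from reranker_model_path:

    # Toy stand-in for the compress_documents contract; not the project's
    # LangchainReranker, whose scores come from a cross-encoder model.
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Doc:
        page_content: str

    class ToyReranker:
        def __init__(self, top_n: int):
            self.top_n = top_n

        def compress_documents(self, documents: List[Doc], query: str) -> List[Doc]:
            # Illustrative relevance score: how many query terms a document contains.
            terms = query.lower().split()
            ranked = sorted(documents,
                            key=lambda d: sum(t in d.page_content.lower() for t in terms),
                            reverse=True)
            return ranked[: self.top_n]

    docs = [Doc("rerankers reorder retrieved passages by relevance"),
            Doc("embedding models map text to vectors")]
    top = ToyReranker(top_n=1).compress_documents(documents=docs, query="rerank passages")
    print("\n".join(doc.page_content for doc in top))

Whatever produces the scores, the caller's shape stays the same, which is why this commit could change how the model path is resolved without touching the rerank call itself.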