From f29ab1e67f89a9e47772dc23bdfb23026baf2337 Mon Sep 17 00:00:00 2001
From: saliven1970 <37821954+saliven1970@users.noreply.github.com>
Date: Mon, 15 Apr 2024 21:53:06 +0800
Subject: [PATCH] Fix error when getting reranker_model_path (#3458)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/chat/knowledge_base_chat.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 60956b44..8a9197f5 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -7,18 +7,16 @@ from configs import (LLM_MODELS,
                      TEMPERATURE,
                      USE_RERANKER,
                      RERANKER_MODEL,
-                     RERANKER_MAX_LENGTH,
-                     MODEL_PATH)
-from server.utils import wrap_done, get_ChatOpenAI
+                     RERANKER_MAX_LENGTH)
+from server.utils import wrap_done, get_ChatOpenAI, get_model_path
 from server.utils import BaseResponse, get_prompt_template
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
-import asyncio
+import asyncio, json
 from langchain.prompts.chat import ChatPromptTemplate
 from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBServiceFactory
-import json
 from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
 from server.reranker.reranker import LangchainReranker
@@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
 
         # 加入reranker
         if USE_RERANKER:
-            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
-            print("-----------------model path------------------")
-            print(reranker_model_path)
+            reranker_model_path = get_model_path(RERANKER_MODEL)
             reranker_model = LangchainReranker(top_n=top_k,
                                             device=embedding_device(),
                                             max_length=RERANKER_MAX_LENGTH,
                                             model_name_or_path=reranker_model_path
                                             )
+            print("-------------before rerank-----------------")
             print(docs)
             docs = reranker_model.compress_documents(documents=docs,
                                                      query=query)
-            print("---------after rerank------------------")
+            print("------------after rerank------------------")
             print(docs)
         context = "\n".join([doc.page_content for doc in docs])
 
@@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         await task
 
     return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
-
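
For context, the call-site change above replaces the hand-rolled lookup
MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large") with the
shared helper server.utils.get_model_path(RERANKER_MODEL). The snippet below is
a minimal, self-contained sketch of that kind of resolution, assuming the
resolver prefers a locally downloaded copy of the weights and falls back to the
Hugging Face id; the helper name resolve_reranker_path and the "models"
directory are illustrative only, not the project's actual implementation.

    # Illustrative sketch only -- not the actual server.utils.get_model_path.
    # Assumes: a MODEL_PATH-style mapping from model name to Hugging Face id,
    # and a local "models" directory that is preferred when it exists.
    from pathlib import Path

    MODEL_PATH = {
        "reranker": {"bge-reranker-large": "BAAI/bge-reranker-large"},
    }

    def resolve_reranker_path(model_name: str, model_root: str = "models") -> str:
        """Return a local directory for the reranker weights if present, else the hub id."""
        hub_id = MODEL_PATH["reranker"].get(model_name, "BAAI/bge-reranker-large")
        local_dir = Path(model_root) / hub_id.split("/")[-1]
        return str(local_dir) if local_dir.is_dir() else hub_id

    print(resolve_reranker_path("bge-reranker-large"))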