diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 60956b44..8a9197f5 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -7,18 +7,16 @@ from configs import (LLM_MODELS, TEMPERATURE, USE_RERANKER, RERANKER_MODEL, - RERANKER_MAX_LENGTH, - MODEL_PATH) -from server.utils import wrap_done, get_ChatOpenAI + RERANKER_MAX_LENGTH) +from server.utils import wrap_done, get_ChatOpenAI, get_model_path from server.utils import BaseResponse, get_prompt_template from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable, List, Optional -import asyncio +import asyncio, json from langchain.prompts.chat import ChatPromptTemplate from server.chat.utils import History from server.knowledge_base.kb_service.base import KBServiceFactory -import json from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs from server.reranker.reranker import LangchainReranker @@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", # 加入reranker if USE_RERANKER: - reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large") - print("-----------------model path------------------") - print(reranker_model_path) + reranker_model_path = get_model_path(RERANKER_MODEL) reranker_model = LangchainReranker(top_n=top_k, device=embedding_device(), max_length=RERANKER_MAX_LENGTH, model_name_or_path=reranker_model_path ) + print("-------------before rerank-----------------") print(docs) docs = reranker_model.compress_documents(documents=docs, query=query) - print("---------after rerank------------------") + print("------------after rerank------------------") print(docs) context = "\n".join([doc.page_content for doc in docs]) @@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", await task return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name)) -