修复获取reranker_model_path出错的问题 (#3458)

This commit is contained in:
saliven1970 2024-04-15 21:53:06 +08:00 committed by GitHub
parent 341ee9db44
commit f29ab1e67f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,18 +7,16 @@ from configs import (LLM_MODELS,
TEMPERATURE, TEMPERATURE,
USE_RERANKER, USE_RERANKER,
RERANKER_MODEL, RERANKER_MODEL,
RERANKER_MAX_LENGTH, RERANKER_MAX_LENGTH)
MODEL_PATH) from server.utils import wrap_done, get_ChatOpenAI, get_model_path
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional from typing import AsyncIterable, List, Optional
import asyncio import asyncio, json
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker from server.reranker.reranker import LangchainReranker
@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
# 加入reranker # 加入reranker
if USE_RERANKER: if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large") reranker_model_path = get_model_path(RERANKER_MODEL)
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=top_k, reranker_model = LangchainReranker(top_n=top_k,
device=embedding_device(), device=embedding_device(),
max_length=RERANKER_MAX_LENGTH, max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path model_name_or_path=reranker_model_path
) )
print("-------------before rerank-----------------")
print(docs) print(docs)
docs = reranker_model.compress_documents(documents=docs, docs = reranker_model.compress_documents(documents=docs,
query=query) query=query)
print("---------after rerank------------------") print("------------after rerank------------------")
print(docs) print(docs)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
await task await task
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name)) return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))