mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
修复获取reranker_model_path出错的问题 (#3458)
This commit is contained in:
parent
341ee9db44
commit
f29ab1e67f
@ -7,18 +7,16 @@ from configs import (LLM_MODELS,
|
||||
TEMPERATURE,
|
||||
USE_RERANKER,
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
MODEL_PATH)
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
RERANKER_MAX_LENGTH)
|
||||
from server.utils import wrap_done, get_ChatOpenAI, get_model_path
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable, List, Optional
|
||||
import asyncio
|
||||
import asyncio, json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.reranker.reranker import LangchainReranker
|
||||
@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
|
||||
# 加入reranker
|
||||
if USE_RERANKER:
|
||||
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
|
||||
print("-----------------model path------------------")
|
||||
print(reranker_model_path)
|
||||
reranker_model_path = get_model_path(RERANKER_MODEL)
|
||||
reranker_model = LangchainReranker(top_n=top_k,
|
||||
device=embedding_device(),
|
||||
max_length=RERANKER_MAX_LENGTH,
|
||||
model_name_or_path=reranker_model_path
|
||||
)
|
||||
print("-------------before rerank-----------------")
|
||||
print(docs)
|
||||
docs = reranker_model.compress_documents(documents=docs,
|
||||
query=query)
|
||||
print("---------after rerank------------------")
|
||||
print("------------after rerank------------------")
|
||||
print(docs)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
await task
|
||||
|
||||
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user