mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 15:38:27 +08:00
修复获取reranker_model_path出错的问题 (#3458)
This commit is contained in:
parent
341ee9db44
commit
f29ab1e67f
@ -7,18 +7,16 @@ from configs import (LLM_MODELS,
|
|||||||
TEMPERATURE,
|
TEMPERATURE,
|
||||||
USE_RERANKER,
|
USE_RERANKER,
|
||||||
RERANKER_MODEL,
|
RERANKER_MODEL,
|
||||||
RERANKER_MAX_LENGTH,
|
RERANKER_MAX_LENGTH)
|
||||||
MODEL_PATH)
|
from server.utils import wrap_done, get_ChatOpenAI, get_model_path
|
||||||
from server.utils import wrap_done, get_ChatOpenAI
|
|
||||||
from server.utils import BaseResponse, get_prompt_template
|
from server.utils import BaseResponse, get_prompt_template
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||||
from typing import AsyncIterable, List, Optional
|
from typing import AsyncIterable, List, Optional
|
||||||
import asyncio
|
import asyncio, json
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
from server.chat.utils import History
|
from server.chat.utils import History
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
import json
|
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from server.knowledge_base.kb_doc_api import search_docs
|
from server.knowledge_base.kb_doc_api import search_docs
|
||||||
from server.reranker.reranker import LangchainReranker
|
from server.reranker.reranker import LangchainReranker
|
||||||
@ -86,18 +84,17 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||||||
|
|
||||||
# 加入reranker
|
# 加入reranker
|
||||||
if USE_RERANKER:
|
if USE_RERANKER:
|
||||||
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
|
reranker_model_path = get_model_path(RERANKER_MODEL)
|
||||||
print("-----------------model path------------------")
|
|
||||||
print(reranker_model_path)
|
|
||||||
reranker_model = LangchainReranker(top_n=top_k,
|
reranker_model = LangchainReranker(top_n=top_k,
|
||||||
device=embedding_device(),
|
device=embedding_device(),
|
||||||
max_length=RERANKER_MAX_LENGTH,
|
max_length=RERANKER_MAX_LENGTH,
|
||||||
model_name_or_path=reranker_model_path
|
model_name_or_path=reranker_model_path
|
||||||
)
|
)
|
||||||
|
print("-------------before rerank-----------------")
|
||||||
print(docs)
|
print(docs)
|
||||||
docs = reranker_model.compress_documents(documents=docs,
|
docs = reranker_model.compress_documents(documents=docs,
|
||||||
query=query)
|
query=query)
|
||||||
print("---------after rerank------------------")
|
print("------------after rerank------------------")
|
||||||
print(docs)
|
print(docs)
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
|
|
||||||
@ -144,4 +141,3 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||||||
await task
|
await task
|
||||||
|
|
||||||
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
|
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user