leehk 2025-03-07 12:35:29 +08:00
parent 29d82e7cef
commit 8a8337cc5c
7 changed files with 35 additions and 21 deletions

View File

@@ -11,6 +11,7 @@ from langchain_deepseek import ChatDeepSeek
 from langchain_community.llms.moonshot import Moonshot
+import sys
 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()
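The only change in this hunk is the new `import sys` line; the logging setup around it is unchanged context. For reference, a minimal sketch of what that format string yields (the message text is illustrative):

import logging
import sys  # newly imported above; presumably used elsewhere in the module

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logging.getLogger().info("step finished")
# 2025-03-07 12:35:29,123 step finished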

View File

@@ -8,7 +8,8 @@ _steps = [
     "get_documents",
     "etl_chromadb_pdf",
     "etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
-    "rag_cot",
+    "rag_cot_evaluation",
+    "test_rag_cot"
 ]
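This hunk renames the `rag_cot` step to `rag_cot_evaluation` and registers a new `test_rag_cot` step. How `_steps` becomes the `active_steps` checked below is not part of this diff; a common Hydra/MLflow pattern, sketched here purely as an assumption:

# Hypothetical: derive active_steps from a comma-separated Hydra override
steps_par = config["main"]["steps"]  # e.g. "all" or "get_documents,rag_cot_evaluation"
active_steps = steps_par.split(",") if steps_par != "all" else _steps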
@@ -100,7 +101,7 @@ def go(config: DictConfig):
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
-    if "rag_cot" in active_steps:
+    if "rag_cot_evaluation" in active_steps:
         if config["prompt_engineering"]["run_id_chromadb"] == "None":
             # Look for run_id that has artifact logged as documents
@@ -120,7 +121,7 @@ def go(config: DictConfig):
             run_id = config["prompt_engineering"]["run_id_chromadb"]
         _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
+            os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
             "main",
             parameters={
                 "query": config["prompt_engineering"]["query"],
@@ -138,7 +139,7 @@ def go(config: DictConfig):
             "main",
             parameters={
                 "query": config["prompt_engineering"]["query"],
-                "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot", "chroma_db"),
+                "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
                 "embedding_model": config["etl"]["embedding_model"],
                 "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
             },
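Both call sites now point at the renamed `src/rag_cot_evaluation` project directory. `mlflow.run(uri, entry_point, parameters=...)` looks for an `MLproject` file in that directory and forwards the parameters to its `main` entry point. A self-contained sketch of the call shape with illustrative values (not this project's real config):

import os
import mlflow

submitted = mlflow.run(
    os.path.join(os.getcwd(), "src", "rag_cot_evaluation"),  # directory containing an MLproject file
    "main",
    parameters={
        "query": "What is the best treatment for hypertension?",  # illustrative
        "embedding_model": "all-MiniLM-L6-v2",                    # illustrative
    },
)
print(submitted.run_id)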

View File

@@ -8,6 +8,13 @@ from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_deepseek import ChatDeepSeek
 from langchain_community.llms.moonshot import Moonshot
+import torch
+
+torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
+# or simply:
+# torch.classes.__path__ = []
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
 DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
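The `torch.classes.__path__` assignment is a widely circulated workaround for Streamlit's file watcher, which walks each module's `__path__` and crashes when it probes `torch.classes` (typically `RuntimeError: Tried to instantiate class '__path__._path'`). Pointing the attribute at a concrete path, or emptying it as the comment suggests, keeps the watcher away; `TOKENIZERS_PARALLELISM=false` just silences the Hugging Face tokenizers fork warning. A minimal sketch of the same guard (the try/except is an assumption, not what the diff does):

import os

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

try:
    import torch
    torch.classes.__path__ = []  # stop Streamlit's watcher from introspecting torch.classes
except ImportError:
    pass  # torch absent; nothing to patch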
@@ -24,12 +31,15 @@ if "messages" not in st.session_state:
 for msg in st.session_state.messages:
     st.chat_message(msg["role"]).write(msg["content"])
+print('i am here1')
 # Load data from ChromaDB
 chroma_client = chromadb.PersistentClient(path=INPUT_CHROMADB_LOCAL)
 collection = chroma_client.get_collection(name=COLLECTION_NAME)
+print('i am here2')
 # Initialize embedding model
 model = SentenceTransformer(EMBEDDING_MODEL)
+print('i am here3')
 if CHAT_MODEL_PROVIDER == "deepseek":
     # Initialize DeepSeek model
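The three `print('i am here…')` lines are temporary progress markers bracketing the slow startup steps: opening the persistent ChromaDB store, fetching the collection, and loading the SentenceTransformer. For context, a self-contained sketch of how such a collection is typically queried once loaded (path, collection name, model, and query are illustrative, not this app's constants):

import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="chroma_db")
collection = client.get_collection(name="documents")
model = SentenceTransformer("all-MiniLM-L6-v2")

embedding = model.encode("What is the best treatment for hypertension?")
results = collection.query(query_embeddings=[embedding.tolist()], n_results=3)
for doc in results["documents"][0]:
    print(doc[:80])  # first 80 chars of each retrieved chunk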
@@ -78,6 +88,7 @@ Provide the answer with language that is similar to the question asked.
 """
 answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
 answer_chain = answer_prompt | llm
+print('i am here4')
 if prompt := st.chat_input():
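`answer_prompt | llm` is LangChain's LCEL pipe: it composes the template and the model into one runnable, so a single `.invoke()` fills in the template and passes the rendered prompt to the model. A minimal sketch with a hypothetical template (only fragments of the real one are visible above):

from langchain_core.prompts import PromptTemplate

# Hypothetical stand-in for the answer_template defined above.
answer_template = """Chain of thought: {cot}
Question: {question}
Provide the answer with language that is similar to the question asked."""

answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
answer_chain = answer_prompt | llm  # llm: whichever provider was initialized above
response = answer_chain.invoke({"cot": "...", "question": "..."})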

View File

@@ -25,27 +25,27 @@ def create_chat_completion(response: str, role: str = "assistant") -> ChatCompletion:
     )
-# @patch("langchain_deepseek.ChatDeepSeek.__call__")
-# @patch("langchain_google_genai.ChatGoogleGenerativeAI.invoke")
-# @patch("langchain_community.llms.moonshot.Moonshot.__call__")
-# def test_Chatbot(moonshot_llm, gemini_llm, deepseek_llm):
-# at = AppTest.from_file("Chatbot.py").run()
-# assert not at.exception
+@patch("langchain_deepseek.ChatDeepSeek.invoke")
+@patch("langchain_google_genai.ChatGoogleGenerativeAI.invoke")
+@patch("langchain_community.llms.moonshot.Moonshot.invoke")
+def test_Chatbot(moonshot_llm, gemini_llm, deepseek_llm):
+    at = AppTest.from_file("Chatbot.py").run()
+    assert not at.exception
-# QUERY = "What is the best treatment for hypertension?"
-# RESPONSE = "The best treatment for hypertension is..."
+    QUERY = "What is the best treatment for hypertension?"
+    RESPONSE = "The best treatment for hypertension is..."
-# deepseek_llm.return_value.content = RESPONSE
-# gemini_llm.return_value.content = RESPONSE
-# moonshot_llm.return_value = RESPONSE
+    deepseek_llm.return_value.content = RESPONSE
+    gemini_llm.return_value.content = RESPONSE
+    moonshot_llm.return_value = RESPONSE
-# at.chat_input[0].set_value(QUERY).run()
+    at.chat_input[0].set_value(QUERY).run()
-# assert any(mock.called for mock in [deepseek_llm, gemini_llm, moonshot_llm])
-# assert at.chat_message[1].markdown[0].value == QUERY
-# assert at.chat_message[2].markdown[0].value == RESPONSE
-# assert at.chat_message[2].avatar == "assistant"
-# assert not at.exception
+    assert any(mock.called for mock in [deepseek_llm, gemini_llm, moonshot_llm])
+    assert at.chat_message[1].markdown[0].value == QUERY
+    assert at.chat_message[2].markdown[0].value == RESPONSE
+    assert at.chat_message[2].avatar == "assistant"
+    assert not at.exception
 @patch("langchain.llms.OpenAI.__call__")
@@ -59,3 +59,4 @@ def test_Langchain_Quickstart(langchain_llm):
     at.button[0].set_value(True).run()
+    print(at)
     assert at.info[0].value == RESPONSE