From 8a8337cc5c7dcbc23b570b7da24b6089c3ff3770 Mon Sep 17 00:00:00 2001 From: leehk Date: Fri, 7 Mar 2025 12:35:29 +0800 Subject: [PATCH] update --- app/llmops/components/test_rag_cot/run.py | 1 + app/llmops/main.py | 9 ++--- .../{rag_cot => rag_cot_evaluation}/MLproject | 0 .../python_env.yml | 0 .../{rag_cot => rag_cot_evaluation}/run.py | 0 app/streamlit/Chatbot.py | 11 ++++++ app/streamlit/app_test.py | 35 ++++++++++--------- 7 files changed, 35 insertions(+), 21 deletions(-) rename app/llmops/src/{rag_cot => rag_cot_evaluation}/MLproject (100%) rename app/llmops/src/{rag_cot => rag_cot_evaluation}/python_env.yml (100%) rename app/llmops/src/{rag_cot => rag_cot_evaluation}/run.py (100%) diff --git a/app/llmops/components/test_rag_cot/run.py b/app/llmops/components/test_rag_cot/run.py index b18e6aa..6e404a3 100644 --- a/app/llmops/components/test_rag_cot/run.py +++ b/app/llmops/components/test_rag_cot/run.py @@ -11,6 +11,7 @@ from langchain_deepseek import ChatDeepSeek from langchain_community.llms.moonshot import Moonshot import sys + logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s") logger = logging.getLogger() diff --git a/app/llmops/main.py b/app/llmops/main.py index 670a7fa..4bd19ce 100644 --- a/app/llmops/main.py +++ b/app/llmops/main.py @@ -8,7 +8,8 @@ _steps = [ "get_documents", "etl_chromadb_pdf", "etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good - "rag_cot", + "rag_cot_evaluation", + "test_rag_cot" ] @@ -100,7 +101,7 @@ def go(config: DictConfig): "embedding_model": config["etl"]["embedding_model"] }, ) - if "rag_cot" in active_steps: + if "rag_cot_evaluation" in active_steps: if config["prompt_engineering"]["run_id_chromadb"] == "None": # Look for run_id that has artifact logged as documents @@ -120,7 +121,7 @@ def go(config: DictConfig): run_id = config["prompt_engineering"]["run_id_chromadb"] _ = mlflow.run( - os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"), + os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"), "main", parameters={ "query": config["prompt_engineering"]["query"], @@ -138,7 +139,7 @@ def go(config: DictConfig): "main", parameters={ "query": config["prompt_engineering"]["query"], - "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot", "chroma_db"), + "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"), "embedding_model": config["etl"]["embedding_model"], "chat_model_provider": config["prompt_engineering"]["chat_model_provider"] }, diff --git a/app/llmops/src/rag_cot/MLproject b/app/llmops/src/rag_cot_evaluation/MLproject similarity index 100% rename from app/llmops/src/rag_cot/MLproject rename to app/llmops/src/rag_cot_evaluation/MLproject diff --git a/app/llmops/src/rag_cot/python_env.yml b/app/llmops/src/rag_cot_evaluation/python_env.yml similarity index 100% rename from app/llmops/src/rag_cot/python_env.yml rename to app/llmops/src/rag_cot_evaluation/python_env.yml diff --git a/app/llmops/src/rag_cot/run.py b/app/llmops/src/rag_cot_evaluation/run.py similarity index 100% rename from app/llmops/src/rag_cot/run.py rename to app/llmops/src/rag_cot_evaluation/run.py diff --git a/app/streamlit/Chatbot.py b/app/streamlit/Chatbot.py index 16266d4..3e16323 100644 --- a/app/streamlit/Chatbot.py +++ b/app/streamlit/Chatbot.py @@ -8,6 +8,13 @@ from langchain_google_genai import ChatGoogleGenerativeAI from langchain_deepseek import ChatDeepSeek from langchain_community.llms.moonshot import Moonshot +import torch +torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)] + +# # # or simply: +# torch.classes.__path__ = [] + + os.environ["TOKENIZERS_PARALLELISM"] = "false" GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str) DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str) @@ -24,12 +31,15 @@ if "messages" not in st.session_state: for msg in st.session_state.messages: st.chat_message(msg["role"]).write(msg["content"]) +print('i am here1') # Load data from ChromaDB chroma_client = chromadb.PersistentClient(path=INPUT_CHROMADB_LOCAL) collection = chroma_client.get_collection(name=COLLECTION_NAME) +print('i am here2') # Initialize embedding model model = SentenceTransformer(EMBEDDING_MODEL) +print('i am here3') if CHAT_MODEL_PROVIDER == "deepseek": # Initialize DeepSeek model @@ -78,6 +88,7 @@ Provide the answer with language that is similar to the question asked. """ answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"]) answer_chain = answer_prompt | llm +print('i am here4') if prompt := st.chat_input(): diff --git a/app/streamlit/app_test.py b/app/streamlit/app_test.py index dce15fd..a0b7f6b 100644 --- a/app/streamlit/app_test.py +++ b/app/streamlit/app_test.py @@ -25,27 +25,27 @@ def create_chat_completion(response: str, role: str = "assistant") -> ChatComple ) -# @patch("langchain_deepseek.ChatDeepSeek.__call__") -# @patch("langchain_google_genai.ChatGoogleGenerativeAI.invoke") -# @patch("langchain_community.llms.moonshot.Moonshot.__call__") -# def test_Chatbot(moonshot_llm, gemini_llm, deepseek_llm): -# at = AppTest.from_file("Chatbot.py").run() -# assert not at.exception +@patch("langchain_deepseek.ChatDeepSeek.invoke") +@patch("langchain_google_genai.ChatGoogleGenerativeAI.invoke") +@patch("langchain_community.llms.moonshot.Moonshot.invoke") +def test_Chatbot(moonshot_llm, gemini_llm, deepseek_llm): + at = AppTest.from_file("Chatbot.py").run() + assert not at.exception -# QUERY = "What is the best treatment for hypertension?" -# RESPONSE = "The best treatment for hypertension is..." + QUERY = "What is the best treatment for hypertension?" + RESPONSE = "The best treatment for hypertension is..." -# deepseek_llm.return_value.content = RESPONSE -# gemini_llm.return_value.content = RESPONSE -# moonshot_llm.return_value = RESPONSE + deepseek_llm.return_value.content = RESPONSE + gemini_llm.return_value.content = RESPONSE + moonshot_llm.return_value = RESPONSE -# at.chat_input[0].set_value(QUERY).run() + at.chat_input[0].set_value(QUERY).run() -# assert any(mock.called for mock in [deepseek_llm, gemini_llm, moonshot_llm]) -# assert at.chat_message[1].markdown[0].value == QUERY -# assert at.chat_message[2].markdown[0].value == RESPONSE -# assert at.chat_message[2].avatar == "assistant" -# assert not at.exception + assert any(mock.called for mock in [deepseek_llm, gemini_llm, moonshot_llm]) + assert at.chat_message[1].markdown[0].value == QUERY + assert at.chat_message[2].markdown[0].value == RESPONSE + assert at.chat_message[2].avatar == "assistant" + assert not at.exception @patch("langchain.llms.OpenAI.__call__") @@ -59,3 +59,4 @@ def test_Langchain_Quickstart(langchain_llm): at.button[0].set_value(True).run() print(at) assert at.info[0].value == RESPONSE +