Mirror of https://github.com/aimingmed/aimingmed-ai.git (synced 2026-02-08 00:03:15 +08:00)
Commit 890d949778
@@ -7,4 +7,4 @@ services:
     ports:
       - "8501:8501"
     volumes:
-      - ./llmops/src/rag_cot/chroma_db:/app/llmops/src/rag_cot/chroma_db
+      - ./llmops/src/rag_cot_evaluation/chroma_db:/app/llmops/src/rag_cot_evaluation/chroma_db
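This volume entry bind-mounts the pipeline's persisted ChromaDB directory into the container at the mirrored path, so the Streamlit app reads the same collection the llmops pipeline writes on the host. The change simply renames both sides of the mount to track the rag_cot → rag_cot_evaluation step rename that runs through the rest of this commit.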
@@ -11,6 +11,7 @@ from langchain_deepseek import ChatDeepSeek
 from langchain_community.llms.moonshot import Moonshot
 import sys
 
+
 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()
 
@@ -8,7 +8,8 @@ _steps = [
    "get_documents",
    "etl_chromadb_pdf",
    "etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
-   "rag_cot",
+   "rag_cot_evaluation",
+   "test_rag_cot"
 ]
 
 
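The renamed step plugs into the usual Hydra-plus-MLflow dispatch, where a steps override selects a subset of _steps. The hunks below show only the changed lines, so as a rough sketch (the config key main.steps and the exact parsing are assumptions, not visible in this diff), the surrounding go() presumably begins along these lines:

    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_name="config")
    def go(config: DictConfig):
        # Hypothetical step selection: run everything by default, or only the
        # steps named in a comma-separated `main.steps` override.
        steps_par = config["main"]["steps"]
        active_steps = steps_par.split(",") if steps_par != "all" else _steps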
@@ -100,7 +101,7 @@ def go(config: DictConfig):
                "embedding_model": config["etl"]["embedding_model"]
            },
        )
-   if "rag_cot" in active_steps:
+   if "rag_cot_evaluation" in active_steps:
 
        if config["prompt_engineering"]["run_id_chromadb"] == "None":
            # Look for run_id that has artifact logged as documents
@@ -120,7 +121,7 @@ def go(config: DictConfig):
            run_id = config["prompt_engineering"]["run_id_chromadb"]
 
        _ = mlflow.run(
-           os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
+           os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
            "main",
            parameters={
                "query": config["prompt_engineering"]["query"],
@@ -138,7 +139,7 @@ def go(config: DictConfig):
            "main",
            parameters={
                "query": config["prompt_engineering"]["query"],
-               "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot", "chroma_db"),
+               "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
                "embedding_model": config["etl"]["embedding_model"],
                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
            },
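The two mlflow.run hunks above only re-point the launched MLproject directory and the persisted chroma_db path at src/rag_cot_evaluation. For reference, mlflow.run(uri, entry_point, parameters=...) executes the MLproject found at the local path uri, binding parameters to the entry point's declared parameters. A minimal standalone call in the same shape, assuming the entry point declares a query parameter and defaults for everything else (the query string here is an invented example):

    import os
    import mlflow

    # Launch the "main" entry point of the rag_cot_evaluation MLproject.
    _ = mlflow.run(
        os.path.join("src", "rag_cot_evaluation"),
        "main",
        parameters={
            "query": "What is the first-line treatment for hypertension?",  # example only
        },
    )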
@@ -8,14 +8,21 @@ from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_deepseek import ChatDeepSeek
 from langchain_community.llms.moonshot import Moonshot
 
+import torch
+torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
+
+# # # or simply:
+# torch.classes.__path__ = []
+
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
-DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
-MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
-CHAT_MODEL_PROVIDER = config("CHAT_MODEL_PROVIDER", cast=str)
-INPUT_CHROMADB_LOCAL = config("INPUT_CHROMADB_LOCAL", cast=str)
-EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str)
-COLLECTION_NAME = config("COLLECTION_NAME", cast=str)
+GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str, default="123456")
+DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str, default="123456")
+MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str, default="123456")
+CHAT_MODEL_PROVIDER = config("CHAT_MODEL_PROVIDER", cast=str, default="gemini")
+INPUT_CHROMADB_LOCAL = config("INPUT_CHROMADB_LOCAL", cast=str, default="../llmops/src/rag_cot_evaluation/chroma_db")
+EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
+COLLECTION_NAME = config("COLLECTION_NAME", cast=str, default="rag_experiment")
 
 st.title("💬 RAG AI for Medical Guideline")
 st.caption(f"🚀 A RAG AI for Medical Guideline powered by {CHAT_MODEL_PROVIDER}")
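Two unrelated fixes land in this hunk. The torch.classes.__path__ reassignment is the widely used workaround for Streamlit's file watcher, which raises a RuntimeError when it tries to introspect torch.classes.__path__ on recent PyTorch builds; pointing it at a concrete path, or emptying it per the commented alternative, keeps the watcher away from it. Separately, every python-decouple lookup gains a default= so the app can boot without a populated .env, e.g. in tests or CI; the "123456" values are placeholder keys, not working credentials. (The DEEKSEEK spelling is carried over unchanged from the existing variable name.)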
@@ -24,12 +31,15 @@ if "messages" not in st.session_state:
 for msg in st.session_state.messages:
     st.chat_message(msg["role"]).write(msg["content"])
 
+print('i am here1')
 # Load data from ChromaDB
 chroma_client = chromadb.PersistentClient(path=INPUT_CHROMADB_LOCAL)
 collection = chroma_client.get_collection(name=COLLECTION_NAME)
+print('i am here2')
 
 # Initialize embedding model
 model = SentenceTransformer(EMBEDDING_MODEL)
+print('i am here3')
 
 if CHAT_MODEL_PROVIDER == "deepseek":
     # Initialize DeepSeek model
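The added print('i am here…') lines read as leftover startup tracing. With the client, collection, and embedding model initialized as above, the retrieval step that presumably follows (not shown in this hunk; the query text is an invented example) would look roughly like:

    # Embed the user's question and pull the top-k matching chunks from Chroma.
    query = "What is the first-line treatment for hypertension?"
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=3)
    context = "\n\n".join(results["documents"][0])  # chunks to feed the prompt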
@@ -78,6 +88,7 @@ Provide the answer with language that is similar to the question asked.
 """
 answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
 answer_chain = answer_prompt | llm
+print('i am here4')
 
 if prompt := st.chat_input():
 
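answer_prompt | llm is LangChain's runnable composition: the template formats its inputs, then pipes the rendered prompt into the model. Inside the chat handler the chain is presumably driven along these lines (a sketch; cot stands for the chain-of-thought text produced earlier, which this hunk does not show):

    # Fill the template and call the model in one step.
    response = answer_chain.invoke({"cot": cot, "question": prompt})
    # Chat models return a message object; plain LLMs (e.g. Moonshot) return a string.
    answer = response.content if hasattr(response, "content") else response
    st.chat_message("assistant").write(answer)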
@@ -9,6 +9,10 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY Chatbot.py .
 COPY .env .
 
+# Run python to initialize download of SentenceTransformer model
+COPY initialize_sentence_transformer.py .
+RUN python initialize_sentence_transformer.py
+
 EXPOSE 8501
 
 ENTRYPOINT ["streamlit", "run", "Chatbot.py"]
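Running initialize_sentence_transformer.py at build time pulls the embedding weights into the image's model cache as a layer, so containers start without a network fetch and the download cost is paid once per build rather than on every cold start. The script itself is the new file at the end of this commit.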
@@ -1,52 +1,7 @@
-import datetime
 from unittest.mock import patch
 from streamlit.testing.v1 import AppTest
-from openai.types.chat import ChatCompletionMessage
-from openai.types.chat.chat_completion import ChatCompletion, Choice
 
 
-# See https://github.com/openai/openai-python/issues/715#issuecomment-1809203346
-def create_chat_completion(response: str, role: str = "assistant") -> ChatCompletion:
-    return ChatCompletion(
-        id="foo",
-        model="gpt-3.5-turbo",
-        object="chat.completion",
-        choices=[
-            Choice(
-                finish_reason="stop",
-                index=0,
-                message=ChatCompletionMessage(
-                    content=response,
-                    role=role,
-                ),
-            )
-        ],
-        created=int(datetime.datetime.now().timestamp()),
-    )
-
-
-# @patch("langchain_deepseek.ChatDeepSeek.__call__")
-# @patch("langchain_google_genai.ChatGoogleGenerativeAI.invoke")
-# @patch("langchain_community.llms.moonshot.Moonshot.__call__")
-# def test_Chatbot(moonshot_llm, gemini_llm, deepseek_llm):
-#     at = AppTest.from_file("Chatbot.py").run()
-#     assert not at.exception
-
-#     QUERY = "What is the best treatment for hypertension?"
-#     RESPONSE = "The best treatment for hypertension is..."
-
-#     deepseek_llm.return_value.content = RESPONSE
-#     gemini_llm.return_value.content = RESPONSE
-#     moonshot_llm.return_value = RESPONSE
-
-#     at.chat_input[0].set_value(QUERY).run()
-
-#     assert any(mock.called for mock in [deepseek_llm, gemini_llm, moonshot_llm])
-#     assert at.chat_message[1].markdown[0].value == QUERY
-#     assert at.chat_message[2].markdown[0].value == RESPONSE
-#     assert at.chat_message[2].avatar == "assistant"
-#     assert not at.exception
-
 
 @patch("langchain.llms.OpenAI.__call__")
 def test_Langchain_Quickstart(langchain_llm):
@@ -59,3 +14,4 @@ def test_Langchain_Quickstart(langchain_llm):
     at.button[0].set_value(True).run()
     print(at)
     assert at.info[0].value == RESPONSE
+
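Net effect of the two test hunks: the unused OpenAI ChatCompletion factory (along with its now-dead datetime and openai imports) and the fully commented-out test_Chatbot are deleted, leaving only test_Langchain_Quickstart, which patches langchain.llms.OpenAI.__call__, drives the app through Streamlit's AppTest, clicks the first button, and asserts that the mocked RESPONSE surfaces in the info element.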
app/streamlit/initialize_sentence_transformer.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+from decouple import config
+from sentence_transformers import SentenceTransformer
+
+EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
+
+# Initialize embedding model
+model = SentenceTransformer(EMBEDDING_MODEL)
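Instantiating the model is enough to trigger the download into the local Hugging Face cache, which is why the script has no other side effects. A quick sanity check after the build (hypothetical, not part of the commit):

    from sentence_transformers import SentenceTransformer

    # Loads from the cache baked into the image; no network access needed.
    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
    vector = model.encode("blood pressure management")
    print(vector.shape)  # (768,) for this model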