Mirror of https://github.com/aimingmed/aimingmed-ai.git (synced 2026-02-06 23:35:28 +08:00)

Merge pull request #5 from aimingmed/feature/front-end: Feature/front end

This commit is contained in: commit 19e4a151d7
@@ -1,11 +1,10 @@
 version: "3.9"
-services:
-  chroma:
-    image: ghcr.io/chroma-core/chroma:latest
-    ports:
-      - "8000:8000"
-    volumes:
-      - chroma_data:/chroma
-
-volumes:
-  chroma_data:
+services:
+  streamlit:
+    build: ./streamlit
+    ports:
+      - "8501:8501"
+    volumes:
+      - ./llmops/src/rag_cot/chroma_db:/app/llmops/src/rag_cot/chroma_db
+
+
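The compose change drops the standalone chroma service (and its named volume) in favor of a streamlit service that bind-mounts the pipeline's persisted Chroma directory, so the app reads the store from disk rather than over HTTP on port 8000. A minimal sketch (not part of the PR) for verifying the mounted store from inside the container, assuming the "rag_experiment" collection name used by run.py later in this diff:

import chromadb

# Open the bind-mounted Chroma store at the container path from the compose file.
client = chromadb.PersistentClient(path="/app/llmops/src/rag_cot/chroma_db")
collection = client.get_collection(name="rag_experiment")
print(collection.count())  # number of embedded chunks available to the chatbot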
app/llmops/components/test_rag_cot/MLproject (new file, 29 lines)
@@ -0,0 +1,29 @@
name: test_rag_cot
python_env: python_env.yml

entry_points:
  main:
    parameters:

      query:
        description: Query to run
        type: string

      input_chromadb_local:
        description: path to input chromadb local
        type: string

      embedding_model:
        description: Fully-qualified name for the embedding model
        type: string

      chat_model_provider:
        description: Fully-qualified name for the chat model provider
        type: string

    command: >-
      python run.py --query {query} \
        --input_chromadb_local {input_chromadb_local} \
        --embedding_model {embedding_model} \
        --chat_model_provider {chat_model_provider}
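This entry point is invoked through the MLflow Projects API by the pipeline's main.py (see the main.py hunk later in this diff). A standalone invocation might look like the sketch below; the parameter values mirror the config shown later in this PR, and the paths are illustrative:

import mlflow

# Sketch: run the test_rag_cot component directly via the MLflow Projects API.
mlflow.run(
    "app/llmops/components/test_rag_cot",
    "main",
    parameters={
        "query": "怎么治疗有kras的肺癌?",
        "input_chromadb_local": "app/llmops/src/rag_cot/chroma_db",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        "chat_model_provider": "gemini",
    },
)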
app/llmops/components/test_rag_cot/python_env.yml (new file, 17 lines)
@@ -0,0 +1,17 @@
# Python version required to run the project.
python: "3.11.11"
# Dependencies required to build packages. This field is optional.
build_dependencies:
  - pip==23.3.1
  - setuptools
  - wheel==0.37.1
  - chromadb
  - langchain
  - sentence_transformers
  - python-decouple
  - langchain_google_genai
  - langchain-deepseek
  - langchain-community
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
app/llmops/components/test_rag_cot/run.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import os
import logging
import argparse
import mlflow
import chromadb
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
import sys

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)


def stream_output(text):
    for char in text:
        print(char, end="")
        sys.stdout.flush()


def go(args):

    # start a new MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'query' not in existing_params:
            mlflow.log_param('query', args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_local": args.input_chromadb_local,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider
        })

        # Load data from ChromaDB
        db_path = args.input_chromadb_local
        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        collection = chroma_client.get_collection(name=collection_name)

        # Formulate a question
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=DEEKSEEK_API_KEY
            )

        elif args.chat_model_provider == "gemini":
            # Initialize Gemini model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=GEMINI_API_KEY,
                temperature=0,
                max_retries=3
            )

        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=MOONSHOT_API_KEY
            )

        # Chain of Thought Prompt
        cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply with language that is similar to the language used with asked question.
"""
        cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
        cot_chain = cot_prompt | llm

        # Initialize embedding model (do this ONCE)
        model = SentenceTransformer(args.embedding_model)

        # Query (prompt)
        query_embedding = model.encode(question)  # Embed the query using the SAME model

        # Search ChromaDB
        documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

        # Generate chain of thought
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        print("Chain of Thought: ", end="")
        stream_output(cot_output.content)
        print()

        # Answer Prompt
        answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer with language that is similar to the question asked.
"""
        answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
        answer_chain = answer_prompt | llm

        # Generate answer
        answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
        print("Answer: ", end="")
        stream_output(answer_output.content)
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_local",
        type=str,
        help="Path to input chromadb local directory",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
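Note that collection.query returns a dict containing ids, distances, metadatas and documents, and run.py interpolates that entire dict into the prompt. A hypothetical refinement (not in the PR) would pass only the retrieved passages:

# Hypothetical refinement (not in this PR): feed the prompt only the
# retrieved text, not the full Chroma result dict (ids, distances, ...).
results = collection.query(query_embeddings=[query_embedding], n_results=5)
documents_text = "\n\n".join(results["documents"][0])  # top-5 chunks for this single query
cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})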
@@ -11,6 +11,6 @@ etl:
   embedding_model: paraphrase-multilingual-mpnet-base-v2
 prompt_engineering:
   run_id_chromadb: None
-  chat_model_provider: moonshot
+  chat_model_provider: gemini
   query: "怎么治疗有kras的肺癌?"
 
(The query asks: "How should lung cancer with KRAS be treated?")
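main.py receives this file as a Hydra DictConfig; a minimal sketch of how the keys are read, assuming this is the llmops config.yaml that Hydra loads (filename assumed for illustration):

from omegaconf import OmegaConf

# Load the pipeline config the same way the values are keyed in main.py.
config = OmegaConf.load("config.yaml")  # assumed filename
print(config["prompt_engineering"]["chat_model_provider"])  # "gemini" after this change
print(config["etl"]["embedding_model"])  # "paraphrase-multilingual-mpnet-base-v2"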
@@ -120,7 +120,7 @@ def go(config: DictConfig):
         if run_id is None:
             raise ValueError("No run_id found with artifact logged as documents")
         else:
-            run_id = config["etl"]["run_id_documents"]
+            run_id = config["prompt_engineering"]["run_id_chromadb"]
 
         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
@@ -133,5 +133,20 @@ def go(config: DictConfig):
             },
         )
+
+
+    if "test_rag_cot" in active_steps:
+
+        _ = mlflow.run(
+            os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
+            "main",
+            parameters={
+                "query": config["prompt_engineering"]["query"],
+                "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot", "chroma_db"),
+                "embedding_model": config["etl"]["embedding_model"],
+                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
+            },
+        )
+
 
 if __name__ == "__main__":
     go()
app/streamlit/Chatbot.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import os
import subprocess
import streamlit as st
import chromadb
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
CHAT_MODEL_PROVIDER = config("CHAT_MODEL_PROVIDER", cast=str)
INPUT_CHROMADB_LOCAL = config("INPUT_CHROMADB_LOCAL", cast=str)
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str)
COLLECTION_NAME = config("COLLECTION_NAME", cast=str)

st.title("💬 RAG AI for Medical Guideline")
st.caption(f"🚀 A RAG AI for Medical Guideline powered by {CHAT_MODEL_PROVIDER}")
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Load data from ChromaDB
chroma_client = chromadb.PersistentClient(path=INPUT_CHROMADB_LOCAL)
collection = chroma_client.get_collection(name=COLLECTION_NAME)

# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL)

if CHAT_MODEL_PROVIDER == "deepseek":
    # Initialize DeepSeek model
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=DEEKSEEK_API_KEY
    )

elif CHAT_MODEL_PROVIDER == "gemini":
    # Initialize Gemini model
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=GEMINI_API_KEY,
        temperature=0,
        max_retries=3
    )

elif CHAT_MODEL_PROVIDER == "moonshot":
    # Initialize Moonshot model
    llm = Moonshot(
        model="moonshot-v1-128k",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=MOONSHOT_API_KEY
    )

# Chain of Thought Prompt
cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply with language that is similar to the language used with asked question.
"""
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
cot_chain = cot_prompt | llm

# Answer Prompt
answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer with language that is similar to the question asked.
"""
answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
answer_chain = answer_prompt | llm

if prompt := st.chat_input():

    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    # Query (prompt)
    query_embedding = model.encode(prompt)  # Embed the query using the SAME model

    # Search ChromaDB
    documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

    # Generate chain of thought
    cot_output = cot_chain.invoke({"documents_text": documents_text, "question": prompt})

    # response = client.chat.completions.create(model="gpt-3.5-turbo", messages=st.session_state.messages)
    msg = cot_output.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)

    # Generate answer
    answer_output = answer_chain.invoke({"cot": cot_output, "question": prompt})
    msg = answer_output.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)
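One caveat: Moonshot here is langchain_community's completion-style LLM, whose invoke() returns a plain string, while ChatDeepSeek and ChatGoogleGenerativeAI return a chat message carrying a .content attribute, so the cot_output.content accesses above assume one of the two chat providers. A small hypothetical helper (not in the PR) could normalize both shapes:

# Hypothetical helper (not in this PR): normalize chat-model messages
# (which carry .content) and completion-style LLM output (a plain str).
def output_text(result):
    return result if isinstance(result, str) else result.content

msg = output_text(cot_output)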
app/streamlit/Dockerfile (new file, 14 lines)
@@ -0,0 +1,14 @@
FROM python:3.11-slim

WORKDIR /app/streamlit

COPY Pipfile Pipfile.lock ./

RUN pip install pipenv && pipenv install --system --deploy

COPY Chatbot.py .
COPY .env .

EXPOSE 8501

ENTRYPOINT ["streamlit", "run", "Chatbot.py"]
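Compose builds this image through build: ./streamlit; for a one-off build outside compose, a sketch using the docker Python SDK (docker-py, which is not a dependency of this repo; the tag name is hypothetical):

import docker

# Build the streamlit image from the app/streamlit context and run it,
# publishing the port the Dockerfile EXPOSEs. Tag name is illustrative.
client = docker.from_env()
image, _ = client.images.build(path="app/streamlit", tag="aimingmed-streamlit")
client.containers.run("aimingmed-streamlit", ports={"8501/tcp": 8501}, detach=True)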
app/streamlit/Pipfile
@@ -5,15 +5,25 @@ name = "pypi"
 
 [packages]
 streamlit = "==1.28"
-langchain = "==0.0.217"
-openai = "==1.2"
+langchain = "*"
 duckduckgo-search = "*"
-anthropic = "==0.3.0"
+anthropic = "*"
-trubrics = "==1.4.3"
+trubrics = "*"
 streamlit-feedback = "*"
 langchain-community = "*"
+watchdog = "*"
+mlflow = "==2.8.1"
+python-decouple = "*"
+langchain_google_genai = "*"
+langchain-deepseek = "*"
+sentence_transformers = "*"
+chromadb = "*"
 
 [dev-packages]
+pytest = "==8.0.0"
+pytest-cov = "==4.1.0"
+pytest-mock = "==3.10.0"
+pytest-asyncio = "*"
 
 [requires]
 python_version = "3.11"
app/streamlit/Pipfile.lock (generated, 2469 lines)
File diff suppressed because it is too large
app/streamlit/tests/test_chatbot.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import pytest
import streamlit as st
from unittest.mock import patch

# add app/streamlit to sys.path
import sys
sys.path.insert(0, "/Users/leehongkai/projects/aimingmed/aimingmed-ai/app/streamlit")

from unittest.mock import patch, MagicMock


def test_title():
    with patch("streamlit.title") as mock_title, \
         patch("streamlit.session_state", new_callable=MagicMock) as mock_session_state:
        import Chatbot
        st.session_state["messages"] = []
        mock_title.assert_called_once_with("💬 RAG AI for Medical Guideline")


def test_caption():
    with patch("streamlit.caption") as mock_caption, \
         patch("streamlit.session_state", new_callable=MagicMock) as mock_session_state:
        import Chatbot
        st.session_state["messages"] = []
        mock_caption.assert_called()


def test_chat_input():
    with patch("streamlit.chat_input", return_value="test_prompt") as mock_chat_input, \
         patch("streamlit.session_state", new_callable=MagicMock) as mock_session_state:
        import Chatbot
        st.session_state["messages"] = []
        mock_chat_input.assert_called_once()


def test_chat_message():
    with patch("streamlit.chat_message") as mock_chat_message, \
         patch("streamlit.session_state", new_callable=MagicMock) as mock_session_state:
        with patch("streamlit.chat_input", return_value="test_prompt"):
            import Chatbot
            st.session_state["messages"] = []
            mock_chat_message.assert_called()
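These tests pin an absolute developer-machine path into sys.path, so they only run on that machine. A portable sketch (hypothetical, not in the PR) would derive app/streamlit from the test file's own location:

import sys
from pathlib import Path

# Resolve .../app/streamlit relative to this file (tests/ sits one level below it).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))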