Commit 320bae36c7 (parent dab06086a0): update
Mirror of https://github.com/aimingmed/aimingmed-ai.git
app/llmops/components/test_rag_cot/MLproject (new file)
@@ -0,0 +1,29 @@
name: test_rag_cot
python_env: python_env.yml

entry_points:
  main:
    parameters:

      query:
        description: Query to run
        type: string

      input_chromadb_local:
        description: Path to the local ChromaDB directory
        type: string

      embedding_model:
        description: Fully-qualified name for the embedding model
        type: string

      chat_model_provider:
        description: Fully-qualified name for the chat model provider
        type: string


    command: >-
      python run.py --query {query}
      --input_chromadb_local {input_chromadb_local}
      --embedding_model {embedding_model}
      --chat_model_provider {chat_model_provider}
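
For a quick local check, this entry point can be driven from Python with mlflow.run, the same API the pipeline's go() step uses later in this commit; every value below is an illustrative placeholder, not a value from the repo:

import mlflow

# Standalone invocation of the new component; the query and ChromaDB
# path are hypothetical placeholders, not values from the project config.
mlflow.run(
    "app/llmops/components/test_rag_cot",
    entry_point="main",
    parameters={
        "query": "What does the document say about dosage?",  # placeholder
        "input_chromadb_local": "./chroma_db",                # placeholder path
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        "chat_model_provider": "gemini",
    },
)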
app/llmops/components/test_rag_cot/python_env.yml (new file)
@@ -0,0 +1,17 @@
# Python version required to run the project.
python: "3.11.11"
# Dependencies required to build packages. This field is optional.
build_dependencies:
  - pip==23.3.1
  - setuptools
  - wheel==0.37.1
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
  - chromadb
  - langchain
  - sentence-transformers
  - python-decouple
  - langchain-google-genai
  - langchain-deepseek
  - langchain-community
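
A note on iteration speed: MLflow builds a fresh environment from this spec on each new run, and mlflow.run accepts an env_manager argument (MLflow 2.x) to reuse the caller's environment instead; a minimal sketch with placeholder parameter values:

import mlflow

# Reuse the active environment rather than building one from
# python_env.yml (useful during local development); placeholders only.
mlflow.run(
    "app/llmops/components/test_rag_cot",
    entry_point="main",
    parameters={"query": "test", "input_chromadb_local": "./chroma_db"},
    env_manager="local",
)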
app/llmops/components/test_rag_cot/run.py (new file)
@@ -0,0 +1,146 @@
import os
import logging
import argparse

import mlflow
import chromadb
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)


def go(args):

    # Start (or resume) an MLflow run under the "development" experiment.
    with mlflow.start_run(
        experiment_id=mlflow.get_experiment_by_name("development").experiment_id,
        run_name="test_rag_cot",
    ):
        # Skip re-logging "query" if it is already recorded on the run.
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if "query" not in existing_params:
            mlflow.log_param("query", args.query)

        # Log the remaining parameters to MLflow.
        mlflow.log_params({
            "input_chromadb_local": args.input_chromadb_local,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider,
        })

        # Load the persisted collection from ChromaDB.
        chroma_client = chromadb.PersistentClient(path=args.input_chromadb_local)
        collection = chroma_client.get_collection(name="rag_experiment")

        # Formulate the question.
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=DEEPSEEK_API_KEY,
            )
        elif args.chat_model_provider == "gemini":
            # Initialize Gemini model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=GEMINI_API_KEY,
                temperature=0,
                max_retries=3,
            )
        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=MOONSHOT_API_KEY,
            )
        else:
            raise ValueError(f"Unsupported chat_model_provider: {args.chat_model_provider}")

        # Chain-of-thought prompt.
        cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply in the same language as the question.
"""
        cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
        cot_chain = cot_prompt | llm

        # Initialize the embedding model once; it must be the same model
        # that was used when the collection was indexed.
        model = SentenceTransformer(args.embedding_model)

        # Embed the query with that same model.
        query_embedding = model.encode(question)

        # Retrieve the top-5 matching documents from ChromaDB.
        results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=5)
        documents_text = results["documents"][0]

        # Generate the chain of thought. Chat models return an AIMessage;
        # the Moonshot LLM returns a plain string, so normalize to text.
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        cot_text = getattr(cot_output, "content", cot_output)
        print("Chain of Thought: ", cot_text)

        # Answer prompt.
        answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Answer in the same language as the question.
"""
        answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
        answer_chain = answer_prompt | llm

        # Generate the final answer.
        answer_output = answer_chain.invoke({"cot": cot_text, "question": question})
        print("Answer: ", getattr(answer_output, "content", answer_output))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_local",
        type=str,
        help="Path to the local ChromaDB directory",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformers model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
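
run.py sources GOOGLE_API_KEY, DEEPSEEK_API_KEY and MOONSHOT_API_KEY through python-decouple, which can read them from the process environment or a local .env file; a minimal .env sketch with placeholder values (real keys should never be committed):

GOOGLE_API_KEY=your-gemini-key-here
DEEPSEEK_API_KEY=your-deepseek-key-here
MOONSHOT_API_KEY=your-moonshot-key-here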
@@ -120,7 +120,7 @@ def go(config: DictConfig):
         if run_id is None:
             raise ValueError("No run_id found with artifact logged as documents")
     else:
-        run_id = config["etl"]["run_id_documents"]
+        run_id = config["prompt_engineering"]["run_id_chromadb"]

     _ = mlflow.run(
         os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
@@ -133,5 +133,20 @@ def go(config: DictConfig):
         },
     )
+
+
+    if "test_rag_cot" in active_steps:
+
+        _ = mlflow.run(
+            os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
+            "main",
+            parameters={
+                "query": config["prompt_engineering"]["query"],
+                "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot", "chroma_db"),
+                "embedding_model": config["etl"]["embedding_model"],
+                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
+            },
+        )
+

 if __name__ == "__main__":
     go()
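
For the new step to resolve its lookups, the Hydra config must expose the keys referenced above; the fragment below is a hypothetical sketch inferred from those config[...] accesses, with placeholder values:

etl:
  embedding_model: paraphrase-multilingual-mpnet-base-v2
  run_id_documents: <run-id placeholder>
prompt_engineering:
  query: "placeholder question"
  chat_model_provider: gemini
  run_id_chromadb: <run-id placeholder>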
app/streamlit/Chatbot.py (new file)
@@ -0,0 +1,29 @@
from openai import OpenAI
import streamlit as st

with st.sidebar:
    openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
    "[View the source code](https://github.com/streamlit/llm-examples/blob/main/Chatbot.py)"
    "[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/streamlit/llm-examples?quickstart=1)"

st.title("💬 Chatbot")
st.caption("🚀 A Streamlit chatbot powered by OpenAI")
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    client = OpenAI(api_key=openai_api_key)
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    response = client.chat.completions.create(model="gpt-3.5-turbo", messages=st.session_state.messages)
    msg = response.choices[0].message.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)
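
With streamlit and openai installed, the app can be started locally via streamlit run app/streamlit/Chatbot.py; the OpenAI API key is entered in the sidebar at runtime rather than read from the environment.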