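"""Chain-of-Thought RAG over a local ChromaDB store.

Embeds a user query with a SentenceTransformer model, retrieves the most
similar document chunks from a persistent ChromaDB collection, and prompts a
chat model (Gemini, DeepSeek, or Moonshot) to reason step by step before
producing a concise answer. Run parameters are logged to MLflow.
"""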
import argparse
import logging
import os
import sys

import chromadb
import mlflow
from decouple import config
from langchain.prompts import PromptTemplate
from langchain_community.llms.moonshot import Moonshot
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

# Disable tokenizers parallelism to silence fork-related Hugging Face warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)

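# The keys above are read by python-decouple from a `.env` file or the process
# environment; an illustrative `.env` (values are placeholders) would look like:
#
#   GOOGLE_API_KEY=...
#   DEEPSEEK_API_KEY=...
#   MOONSHOT_API_KEY=...
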
def stream_output(text):
    """Print text one character at a time, flushing stdout so output streams."""
    for char in text:
        print(char, end="")
        sys.stdout.flush()

def go(args):

    # Start a new MLflow run under the "development" experiment
    # (assumes the experiment already exists; get_experiment_by_name
    # returns None otherwise and the attribute access would fail).
    with mlflow.start_run(
        experiment_id=mlflow.get_experiment_by_name("development").experiment_id,
        run_name="etl_chromadb_pdf",
    ):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if "query" not in existing_params:
            mlflow.log_param("query", args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_local": args.input_chromadb_local,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider,
        })

        # Load data from ChromaDB
        db_path = args.input_chromadb_local
        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        collection = chroma_client.get_collection(name=collection_name)

        # Formulate a question
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek chat model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=DEEPSEEK_API_KEY,
            )
        elif args.chat_model_provider == "gemini":
            # Initialize Gemini chat model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=GEMINI_API_KEY,
                temperature=0,
                max_retries=3,
            )
        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot completion model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=MOONSHOT_API_KEY,
            )
        else:
            raise ValueError(f"Unknown chat model provider: {args.chat_model_provider}")

# Chain of Thought Prompt
|
|
cot_template = """Let's think step by step.
|
|
Given the following document in text: {documents_text}
|
|
Question: {question}
|
|
Reply with language that is similar to the language used with asked question.
|
|
"""
|
|
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
|
|
cot_chain = cot_prompt | llm
|
|
|
|
        # Initialize the embedding model once; queries must be embedded with
        # the same model that was used at document-ingestion time.
        model = SentenceTransformer(args.embedding_model)

        # Embed the query (prompt)
        query_embedding = model.encode(question)

        # Search ChromaDB for the top-5 most similar chunks and keep only the
        # document texts (the raw query result also carries ids and distances).
        results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=5)
        documents_text = results["documents"][0]

        # Generate the chain of thought. Chat models return a message object
        # with a .content attribute; plain LLMs (e.g. Moonshot) return a string.
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        cot_text = getattr(cot_output, "content", cot_output)
        print("Chain of Thought: ", end="")
        stream_output(cot_text)
        print()

# Answer Prompt
|
|
answer_template = """Given the chain of thought: {cot}
|
|
Provide a concise answer to the question: {question}
|
|
Provide the answer with language that is similar to the question asked.
|
|
"""
|
|
answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
|
|
answer_chain = answer_prompt | llm
|
|
|
|
        # Generate the final answer from the chain of thought
        answer_output = answer_chain.invoke({"cot": cot_text, "question": question})
        answer_text = getattr(answer_output, "content", answer_output)
        print("Answer: ", end="")
        stream_output(answer_text)
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True,
    )

    parser.add_argument(
        "--input_chromadb_local",
        type=str,
        help="Path to input chromadb local directory",
        required=True,
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name",
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        choices=["gemini", "deepseek", "moonshot"],
        help="Chat model provider",
    )

    args = parser.parse_args()

    go(args)
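
# Example invocation (a sketch; the script name, store path, and query are
# illustrative, and the ChromaDB store must have been built with the same
# embedding model):
#
#   python chain_of_thought_rag.py \
#       --query "What are the key findings of the report?" \
#       --input_chromadb_local ./chromadb_store \
#       --embedding_model paraphrase-multilingual-mpnet-base-v2 \
#       --chat_model_provider gemini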