mirror of https://github.com/aimingmed/aimingmed-ai.git

Commit 3a4d59c0e3 ("done"), parent 7399b56fa1
@@ -10,5 +10,4 @@ build_dependencies:
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1
-  - wandb==0.16.0
   - git+https://github.com/udacity/nd0821-c2-build-model-workflow-starter.git#egg=wandb-utils&subdirectory=components
@@ -5,33 +5,33 @@ This script download a URL to a local destination
import argparse
import logging
import os

import wandb

from wandb_utils.log_artifact import log_artifact
import mlflow
import shutil

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()


def go(args):

    zip_path = os.path.join(args.path_document_folder, f"{args.document_folder}.zip")
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', args.path_document_folder, args.document_folder)

    run = wandb.init(job_type="get_documents", entity='aimingmed')
    run.config.update(args)
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id) as run:

        logger.info(f"Uploading {args.artifact_name} to Weights & Biases")
        log_artifact(
            args.artifact_name,
            args.artifact_type,
            args.artifact_description,
            zip_path,
            run,
        )
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'artifact_description' not in existing_params:
            mlflow.log_param('artifact_description', args.artifact_description)
        if 'artifact_types' not in existing_params:
            mlflow.log_param('artifact_types', args.artifact_type)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.artifact_name,
        })

        logger.info(f"Uploading {args.artifact_name} to MLFlow")
        mlflow.log_artifact(zip_path, args.artifact_name)


if __name__ == "__main__":
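In the hunk above, shutil.make_archive is handed zip_path.replace('.zip', '') because make_archive expects the archive name without the extension and appends it itself. A minimal sketch of the same call in isolation (the paths are illustrative, not from the repo):

import os
import shutil

path_document_folder = "../data"   # folder that contains the documents folder
document_folder = "documents"      # folder to be zipped

# make_archive takes the target name WITHOUT ".zip" and returns the
# full path of the archive it actually wrote.
base_name = os.path.join(path_document_folder, document_folder)
archive = shutil.make_archive(base_name, "zip", path_document_folder, document_folder)
print(archive)  # ../data/documents.zip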
@@ -7,8 +7,10 @@ etl:
   input_artifact_name: documents
   document_folder: documents
   path_document_folder: "../../../../data"
+  run_id_documents: None
   embedding_model: paraphrase-multilingual-mpnet-base-v2
 prompt_engineering:
-  chat_model_provider: kimi
+  run_id_chromadb: None
+  chat_model_provider: moonshot
   query: "怎么治疗有kras的肺癌?"
@@ -9,9 +9,9 @@ from decouple import config

 _steps = [
     "get_documents",
-    "etl_chromdb_pdf",
-    "etl_chromdb_scanned_pdf", # the performance for scanned pdf may not be good
-    "chain_of_thought"
+    "etl_chromadb_pdf",
+    "etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
+    "rag_cot",
 ]
@@ -19,9 +19,8 @@ _steps = [
 @hydra.main(config_name='config')
 def go(config: DictConfig):

-    # Setup the wandb experiment. All runs will be grouped under this name
-    os.environ["WANDB_PROJECT"] = config["main"]["project_name"]
-    os.environ["WANDB_RUN_GROUP"] = config["main"]["experiment_name"]
+    # Setup the MLflow experiment. All runs will be grouped under this name
+    mlflow.set_experiment(config["main"]["experiment_name"])

     # Steps to execute
     steps_par = config['main']['steps']
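The diff never shows how steps_par turns into active_steps. In the wandb-utils starter this pipeline borrows from, the usual pattern is a comma-separated override with "all" meaning every step; a sketch under that assumption (not part of this commit):

# Assumed pattern, e.g. `mlflow run . -P steps=get_documents,rag_cot`
active_steps = steps_par.split(",") if steps_par != "all" else _steps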
@@ -43,37 +42,92 @@ def go(config: DictConfig):
                 "artifact_description": "Raw file as downloaded"
             },
         )
-    if "etl_chromdb_pdf" in active_steps:
+    if "etl_chromadb_pdf" in active_steps:
+        if config["etl"]["run_id_documents"] == "None":
+            # Look for run_id that has artifact logged as documents
+            run_id = None
+            client = mlflow.tracking.MlflowClient()
+            for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
+                for artifact in client.list_artifacts(run.info.run_id):
+                    if artifact.path == "documents":
+                        run_id = run.info.run_id
+                        break
+                if run_id:
+                    break
+
+            if run_id is None:
+                raise ValueError("No run_id found with artifact logged as documents")
+        else:
+            run_id = config["etl"]["run_id_documents"]
+
         _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_pdf"),
+            os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_pdf"),
             "main",
             parameters={
-                "input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
-                "output_artifact": "chromdb.zip",
-                "output_type": "chromdb",
+                "input_artifact": f'runs:/{run_id}/documents/documents.zip',
+                "output_artifact": "chromadb",
+                "output_type": "chromadb",
                 "output_description": "Documents in pdf to be read and stored in chromdb",
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
-    if "etl_chromdb_scanned_pdf" in active_steps:
+    if "etl_chromadb_scanned_pdf" in active_steps:
+
+        if config["etl"]["run_id_documents"] == "None":
+            # Look for run_id that has artifact logged as documents
+            run_id = None
+            client = mlflow.tracking.MlflowClient()
+            for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
+                for artifact in client.list_artifacts(run.info.run_id):
+                    if artifact.path == "documents":
+                        run_id = run.info.run_id
+                        break
+                if run_id:
+                    break
+
+            if run_id is None:
+                raise ValueError("No run_id found with artifact logged as documents")
+        else:
+            run_id = config["etl"]["run_id_documents"]
+
         _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_scanned_pdf"),
+            os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_scanned_pdf"),
             "main",
             parameters={
-                "input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
-                "output_artifact": "chromdb.zip",
-                "output_type": "chromdb",
+                "input_artifact": f'runs:/{run_id}/documents/documents.zip',
+                "output_artifact": "chromadb",
+                "output_type": "chromadb",
                 "output_description": "Scanned Documents in pdf to be read and stored in chromdb",
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
-    if "chain_of_thought" in active_steps:
+    if "rag_cot" in active_steps:
+
+        if config["prompt_engineering"]["run_id_chromadb"] == "None":
+            # Look for run_id that has artifact logged as chromadb
+            run_id = None
+            client = mlflow.tracking.MlflowClient()
+            for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
+                for artifact in client.list_artifacts(run.info.run_id):
+                    if artifact.path == "chromadb":
+                        run_id = run.info.run_id
+                        break
+                if run_id:
+                    break
+
+            if run_id is None:
+                raise ValueError("No run_id found with artifact logged as chromadb")
+        else:
+            run_id = config["prompt_engineering"]["run_id_chromadb"]
+
         _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
+            os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
             "main",
             parameters={
                 "query": config["prompt_engineering"]["query"],
-                "input_chromadb_artifact": "chromdb.zip:latest",
+                "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
                 "embedding_model": config["etl"]["embedding_model"],
                 "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
             },
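The run-id lookup now appears three times with only the artifact path ("documents" vs "chromadb") and the config fallback changing. A sketch of a shared helper the three branches could call; find_run_with_artifact is a hypothetical name, not in the diff:

import mlflow

def find_run_with_artifact(experiment_name: str, artifact_path: str) -> str:
    """Return the run_id of the first run in the experiment that logged
    an artifact with the given path; raise if none is found."""
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    for run in client.search_runs(experiment_ids=[experiment.experiment_id]):
        for artifact in client.list_artifacts(run.info.run_id):
            if artifact.path == artifact_path:
                return run.info.run_id
    raise ValueError(f"No run_id found with artifact logged as {artifact_path}")

# usage: run_id = find_run_with_artifact(config["main"]["experiment_name"], "documents")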
@@ -1,144 +0,0 @@
import os
import logging
import argparse
import wandb
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)


def go(args):
    run = wandb.init(job_type="chain_of_thought", entity='aimingmed')
    run.config.update(args)

    logger.info("Downloading chromadb artifact")
    artifact_chromadb_local_path = run.use_artifact(args.input_chromadb_artifact).file()

    # unzip the artifact
    logger.info("Unzipping the artifact")
    shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")

    # Load data from ChromaDB
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    collection = chroma_client.get_collection(name=collection_name)

    # Formulate a question
    question = args.query

    if args.chat_model_provider == "deepseek":
        # Initialize DeepSeek model
        llm = ChatDeepSeek(
            model="deepseek-chat",
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=DEEKSEEK_API_KEY
        )

    elif args.chat_model_provider == "gemini":
        # Initialize Gemini model
        llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            google_api_key=GEMINI_API_KEY,
            temperature=0,
            max_retries=3
        )

    elif args.chat_model_provider == "moonshot":
        # Initialize Moonshot model
        llm = Moonshot(
            model="moonshot-v1-128k",
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=MOONSHOT_API_KEY
        )

    # Chain of Thought Prompt
    cot_template = """Let's think step by step.
    Given the following document in text: {documents_text}
    Question: {question}
    Reply with language that is similar to the language used with asked question.
    """
    cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
    cot_chain = cot_prompt | llm

    # Initialize embedding model (do this ONCE)
    model = SentenceTransformer(args.embedding_model)

    # Query (prompt)
    query_embedding = model.encode(question)  # Embed the query using the SAME model

    # Search ChromaDB
    documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

    # Generate chain of thought
    cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
    print("Chain of Thought: ", cot_output)

    # Answer Prompt
    answer_template = """Given the chain of thought: {cot}
    Provide a concise answer to the question: {question}
    Provide the answer with language that is similar to the question asked.
    """
    answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
    answer_chain = answer_prompt | llm

    # Generate answer
    answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
    print("Answer: ", answer_output)

    run.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
        help="Fully-qualified name for the chromadb artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
@@ -1,4 +1,4 @@
-name: etl_chromdb_pdf
+name: etl_chromadb_pdf
 python_env: python_env.yml

 entry_points:
@@ -12,5 +12,4 @@ build_dependencies:
   - sentence_transformers
 # Dependencies required to run the project.
 dependencies:
-  - mlflow==2.8.1
-  - wandb==0.16.0
+  - mlflow==2.8.1
app/llmops/src/etl_chromadb_pdf/run.py (new file, 178 lines)

@@ -0,0 +1,178 @@
#!/usr/bin/env python
"""
Download the documents artifact from MLflow, extract text from the PDFs and store the chunk embeddings in a ChromaDB collection, exporting the result as a new artifact
"""
import argparse
import logging
import os
import mlflow
import shutil

import chromadb
import io
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def extract_chinese_text_from_pdf(pdf_path):
    """
    Extracts Chinese text from a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The extracted Chinese text, or None if an error occurs.
    """
    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(resource_manager, fake_file_handle)
    page_interpreter = PDFPageInterpreter(resource_manager, converter)

    try:
        with open(pdf_path, 'rb') as fh:
            for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
                page_interpreter.process_page(page)

            text = fake_file_handle.getvalue()

        return text

    except FileNotFoundError:
        print(f"Error: PDF file not found at {pdf_path}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        converter.close()
        fake_file_handle.close()


def go(args):
    """
    Run the etl for chromadb with pdf
    """

    # Start an MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'output_description' not in existing_params:
            mlflow.log_param('output_description', args.output_description)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.input_artifact,
            "output_artifact": args.output_artifact,
            "output_type": args.output_type,
            "embedding_model": args.embedding_model
        })

        # Initialize embedding model (do this ONCE)
        model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model

        # Create database, delete the database directory if it exists
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
        os.makedirs(db_path)

        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        db = chroma_client.create_collection(name=collection_name)

        logger.info("Downloading artifact")
        artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)

        logger.info("Reading data")

        # unzip the downloaded artifact
        import zipfile
        with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        # show the unzipped folder
        documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        for root, _dir, files in os.walk(f"./{documents_folder}"):
            for file in files:
                if file.endswith(".pdf"):
                    read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
                    document = Document(page_content=read_text)
                    all_splits = text_splitter.split_documents([document])

                    for i, split in enumerate(all_splits):
                        db.add(documents=[split.page_content],
                               metadatas=[{"filename": file}],
                               ids=[f'{file[:-4]}-{str(i)}'],
                               embeddings=[model_embedding.encode(split.page_content)]
                               )

        logger.info("Logging artifact with mlflow")
        shutil.make_archive(db_path, 'zip', db_path)
        mlflow.log_artifact(db_path + '.zip', args.output_artifact)

        # clean up
        os.remove(db_path + '.zip')
        shutil.rmtree(db_path)
        shutil.rmtree(documents_folder)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ETL: read PDF documents into a ChromaDB collection")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    args = parser.parse_args()

    go(args)
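Before the zip is logged, the collection this step builds can be sanity-checked by querying it with the same embedding model; a minimal sketch reusing the names from run.py above (the probe string is illustrative):

import chromadb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
client = chromadb.PersistentClient(path="chroma_db")
collection = client.get_collection(name="rag_experiment")

# Embed a probe with the SAME model used at index time, then ask for
# the closest chunks and the files they came from.
probe = model.encode("KRAS 肺癌 治疗")
hits = collection.query(query_embeddings=[probe], n_results=3)
for doc, meta in zip(hits["documents"][0], hits["metadatas"][0]):
    print(meta["filename"], doc[:80])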
@@ -1,4 +1,4 @@
-name: etl_chromdb_scanned_pdf
+name: etl_chromadb_scanned_pdf
 python_env: python_env.yml

 entry_points:
@@ -14,4 +14,3 @@ build_dependencies:
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1
-  - wandb==0.16.0
app/llmops/src/etl_chromadb_scanned_pdf/run.py (new file, 162 lines)

@@ -0,0 +1,162 @@
#!/usr/bin/env python
"""
Download the documents artifact from MLflow, OCR the scanned PDFs and store the chunk embeddings in a ChromaDB collection, exporting the result as a new artifact
"""
import argparse
import logging
import os
import mlflow
import shutil

import chromadb
# from openai import OpenAI
from typing import List
import numpy as np
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def extract_text_from_pdf_ocr(pdf_path):
    try:
        images = convert_from_path(pdf_path)  # Convert PDF pages to images
        extracted_text = ""
        for image in images:
            text = pt.image_to_string(image, lang="chi_sim+eng")  # chi_sim for Simplified Chinese, chi_tra for Traditional
            extracted_text += text + "\n"
        return extracted_text

    except ImportError:
        print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
        return ""
    except Exception as e:
        print(f"OCR failed: {e}")
        return ""


def go(args):
    """
    Run the etl for chromadb with scanned pdf
    """

    # Start an MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_scanned_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'output_description' not in existing_params:
            mlflow.log_param('output_description', args.output_description)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.input_artifact,
            "output_artifact": args.output_artifact,
            "output_type": args.output_type,
            "embedding_model": args.embedding_model
        })

        # Initialize embedding model
        model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model

        # Create database, delete the database directory if it exists
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
        os.makedirs(db_path)

        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        db = chroma_client.create_collection(name=collection_name)

        logger.info("Downloading artifact")
        artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)

        logger.info("Reading data")

        # unzip the downloaded artifact
        import zipfile
        with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        # show the unzipped folder
        documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        for root, _dir, files in os.walk(f"./{documents_folder}"):
            for file in files:
                if file.endswith(".pdf"):
                    read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
                    document = Document(page_content=read_text)
                    all_splits = text_splitter.split_documents([document])

                    for i, split in enumerate(all_splits):
                        db.add(documents=[split.page_content],
                               metadatas=[{"filename": file}],
                               ids=[f'{file[:-4]}-{str(i)}'],
                               embeddings=[model_embedding.encode(split.page_content)]
                               )

        logger.info("Uploading artifact to MLFlow")
        shutil.make_archive(db_path, 'zip', db_path)
        mlflow.log_artifact(db_path + '.zip', args.output_artifact)

        # clean up
        os.remove(db_path + '.zip')
        shutil.rmtree(db_path)
        shutil.rmtree(documents_folder)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ETL: OCR scanned PDF documents into a ChromaDB collection")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    args = parser.parse_args()

    go(args)
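The OCR path has two system-level prerequisites that pip alone does not satisfy: pdf2image shells out to poppler, and pytesseract needs a tesseract install with the chi_sim traineddata that the lang="chi_sim+eng" call asks for. A small probe (sketch; sample.pdf is a placeholder) to verify both before running the step:

import pytesseract
from pdf2image import convert_from_path

# Fails fast if the tesseract binary is missing.
print("tesseract:", pytesseract.get_tesseract_version())

# "chi_sim" must appear here, or image_to_string(lang="chi_sim+eng") errors out.
print("languages:", pytesseract.get_languages())

# Fails fast if poppler (pdftoppm) is missing from PATH.
pages = convert_from_path("sample.pdf", first_page=1, last_page=1)
print("rendered", len(pages), "page(s)")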
@@ -1,184 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil

import chromadb
# from openai import OpenAI
import io
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def extract_chinese_text_from_pdf(pdf_path):
    """
    Extracts Chinese text from a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The extracted Chinese text, or None if an error occurs.
    """
    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(resource_manager, fake_file_handle)
    page_interpreter = PDFPageInterpreter(resource_manager, converter)

    try:
        with open(pdf_path, 'rb') as fh:
            for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
                page_interpreter.process_page(page)

            text = fake_file_handle.getvalue()

        return text

    except FileNotFoundError:
        print(f"Error: PDF file not found at {pdf_path}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        converter.close()
        fake_file_handle.close()


def go(args):
    """
    Run the etl for chromdb with scanned pdf
    """

    run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
    run.config.update(args)

    # Initialize embedding model (do this ONCE)
    model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model

    # Create database, delete the database directory if it exists
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    if os.path.exists(db_path):
        shutil.rmtree(db_path)
    os.makedirs(db_path)

    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    db = chroma_client.create_collection(name=collection_name)

    logger.info("Downloading artifact")
    artifact_local_path = run.use_artifact(args.input_artifact).file()

    logger.info("Reading data")

    # unzip the downloaded artifact
    import zipfile
    with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
        zip_ref.extractall(".")
    os.remove(artifact_local_path)

    # show the unzipped folder
    documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for root, _dir, files in os.walk(f"./{documents_folder}"):
        for file in files:
            if file.endswith(".pdf"):
                read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
                document = Document(page_content=read_text)
                all_splits = text_splitter.split_documents([document])

                for i, split in enumerate(all_splits):
                    db.add(documents=[split.page_content],
                           metadatas=[{"filename": file}],
                           ids=[f'{file[:-4]}-{str(i)}'],
                           embeddings=[model_embedding.encode(split.page_content)]
                           )

    # Create a new artifact
    artifact = wandb.Artifact(
        args.output_artifact,
        type=args.output_type,
        description=args.output_description
    )

    # zip the database folder first
    shutil.make_archive(db_path, 'zip', db_path)

    # Add the database to the artifact
    artifact.add_file(db_path + '.zip')

    # Log the artifact
    run.log_artifact(artifact)

    # Finish the run
    run.finish()

    # clean up
    os.remove(db_path + '.zip')
    os.remove(db_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="A very basic data cleaning")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    args = parser.parse_args()

    go(args)
@@ -1,173 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil

import chromadb
# from openai import OpenAI
from typing import List
import numpy as np
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def extract_text_from_pdf_ocr(pdf_path):
    try:
        images = convert_from_path(pdf_path)  # Convert PDF pages to images
        extracted_text = ""
        for image in images:
            text = pt.image_to_string(image, lang="chi_sim+eng")  # chi_sim for Simplified Chinese, chi_tra for Traditional
            extracted_text += text + "\n"
        return extracted_text

    except ImportError:
        print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
        return ""
    except Exception as e:
        print(f"OCR failed: {e}")
        return ""


def go(args):
    """
    Run the etl for chromdb with scanned pdf
    """

    run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
    run.config.update(args)

    # Setup the Gemini client
    # client = OpenAI(
    #     api_key=args.gemini_api_key,
    #     base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
    # )

    # def get_google_embedding(text: str) -> List[float]:
    #     response = client.embeddings.create(
    #         model="text-embedding-004",
    #         input=text
    #     )
    #     return response.data[0].embedding

    # class GeminiEmbeddingFunction(object):
    #     def __init__(self, api_key: str, base_url: str, model_name: str):
    #         self.client = OpenAI(
    #             api_key=args.gemini_api_key,
    #             base_url=base_url
    #         )
    #         self.model_name = model_name

    #     def __call__(self, input: List[str]) -> List[List[float]]:
    #         all_embeddings = []
    #         for text in input:
    #             response = self.client.embeddings.create(input=text, model=self.model_name)
    #             embeddings = [record.embedding for record in response.data]
    #             all_embeddings.append(np.array(embeddings[0]))
    #         return all_embeddings

    # Initialize embedding model (do this ONCE)
    model_embedding = SentenceTransformer('all-mpnet-base-v2')  # Or a multilingual model

    # Create database, delete the database directory if it exists
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    if os.path.exists(db_path):
        shutil.rmtree(db_path)
    os.makedirs(db_path)

    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    db = chroma_client.create_collection(name=collection_name)

    logger.info("Downloading artifact")
    artifact_local_path = run.use_artifact(args.input_artifact).file()

    logger.info("Reading data")

    # unzip the downloaded artifact
    import zipfile
    with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
        zip_ref.extractall(".")
    os.remove(artifact_local_path)

    # show the unzipped folder
    documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for root, _dir, files in os.walk(f"./{documents_folder}"):
        for file in files:
            if file.endswith(".pdf"):
                read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
                document = Document(page_content=read_text)
                all_splits = text_splitter.split_documents([document])

                for i, split in enumerate(all_splits):
                    db.add(documents=[split.page_content],
                           metadatas=[{"filename": file}],
                           ids=[f'{file[:-4]}-{str(i)}'],
                           embeddings=[model_embedding.encode(split.page_content)]
                           )


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="A very basic data cleaning")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    args = parser.parse_args()

    go(args)
@@ -1,4 +1,4 @@
-name: chain_of_thought
+name: rag_cot
 python_env: python_env.yml

 entry_points:
@@ -14,5 +14,4 @@ build_dependencies:
   - langchain-community
 # Dependencies required to run the project.
 dependencies:
-  - mlflow==2.8.1
-  - wandb==0.16.0
+  - mlflow==2.8.1
app/llmops/src/rag_cot/run.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import os
import logging
import argparse
import mlflow
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)


def go(args):

    # start a new MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="rag_cot"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'query' not in existing_params:
            mlflow.log_param('query', args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_artifact": args.input_chromadb_artifact,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider
        })

        logger.info("Downloading chromadb artifact")
        artifact_chromadb_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_chromadb_artifact)

        # unzip the artifact
        logger.info("Unzipping the artifact")
        shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")

        # Load data from ChromaDB
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        collection = chroma_client.get_collection(name=collection_name)

        # Formulate a question
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=DEEKSEEK_API_KEY
            )

        elif args.chat_model_provider == "gemini":
            # Initialize Gemini model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=GEMINI_API_KEY,
                temperature=0,
                max_retries=3
            )

        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=MOONSHOT_API_KEY
            )

        # Chain of Thought Prompt
        cot_template = """Let's think step by step.
        Given the following document in text: {documents_text}
        Question: {question}
        Reply with language that is similar to the language used with asked question.
        """
        cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
        cot_chain = cot_prompt | llm

        # Initialize embedding model (do this ONCE)
        model = SentenceTransformer(args.embedding_model)

        # Query (prompt)
        query_embedding = model.encode(question)  # Embed the query using the SAME model

        # Search ChromaDB
        documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

        # Generate chain of thought
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        print("Chain of Thought: ", cot_output)

        # Answer Prompt
        answer_template = """Given the chain of thought: {cot}
        Provide a concise answer to the question: {question}
        Provide the answer with language that is similar to the question asked.
        """
        answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
        answer_chain = answer_prompt | llm

        # Generate answer
        answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
        print("Answer: ", answer_output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
        help="Fully-qualified name for the chromadb artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
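One rough edge in run.py above: collection.query returns a dict (ids, distances, documents, metadatas), and that whole dict is interpolated into the prompt as {documents_text}. A sketch of flattening just the retrieved text first, which hands the model a much cleaner context (a drop-in change, not what the commit does):

# collection.query returns {"ids": [[...]], "distances": [[...]],
#  "documents": [[chunk, ...]], "metadatas": [[{...}, ...]]}
results = collection.query(query_embeddings=[query_embedding], n_results=5)

# Keep only the chunk texts for the first (and only) query.
documents_text = "\n\n".join(results["documents"][0])

cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})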