mirror of https://github.com/aimingmed/aimingmed-ai.git
synced 2026-02-08 08:13:20 +08:00

commit 04e2764903 (parent 43ab6883b9)
update for final working main

.gitignore (vendored)
@@ -205,3 +205,4 @@ data/*
 **/.env
 **/llm-template2/*
 **/llmops/outputs/*
+**/*.zip

app/llmops/components/combine_chromadb/run.py (new file)
@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+import argparse
+import logging
+import os
+import wandb
+import shutil
+import chromadb
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
+logger = logging.getLogger()
+
+
+def combine_chromadb(chromadb_pdf_path, chromadb_scanned_pdf_path, output_path):
+    """
+    Combines two ChromaDB instances into a single ChromaDB.
+    """
+
+    # Load the ChromaDB instances
+    chromadb_pdf_client = chromadb.PersistentClient(path=chromadb_pdf_path)
+    chromadb_scanned_pdf_client = chromadb.PersistentClient(path=chromadb_scanned_pdf_path)
+
+    # Get the collections
+    collection_name = "rag_experiment"
+    try:
+        chromadb_pdf_collection = chromadb_pdf_client.get_collection(name=collection_name)
+    except ValueError as e:
+        raise ValueError(f"Collection '{collection_name}' not found in ChromaDB at '{chromadb_pdf_path}'. Ensure the etl_chromdb_pdf step was run successfully.") from e
+    try:
+        chromadb_scanned_pdf_collection = chromadb_scanned_pdf_client.get_collection(name=collection_name)
+    except ValueError as e:
+        raise ValueError(f"Collection '{collection_name}' not found in ChromaDB at '{chromadb_scanned_pdf_path}'. Ensure the etl_chromdb_scanned_pdf step was run successfully.") from e
+
+    # Get all data from the collections
+    chromadb_pdf_data = chromadb_pdf_collection.get(include=["documents", "metadatas", "embeddings"])
+    chromadb_scanned_pdf_data = chromadb_scanned_pdf_collection.get(include=["documents", "metadatas", "embeddings"])
+
+    # Create a new ChromaDB instance
+    combined_chromadb_client = chromadb.PersistentClient(path=output_path)
+    combined_chromadb_collection = combined_chromadb_client.create_collection(name=collection_name)
+
+    # Add the data to the combined ChromaDB
+    combined_chromadb_collection.add(
+        documents=chromadb_pdf_data["documents"] + chromadb_scanned_pdf_data["documents"],
+        metadatas=chromadb_pdf_data["metadatas"] + chromadb_scanned_pdf_data["metadatas"],
+        ids=chromadb_pdf_data["ids"] + chromadb_scanned_pdf_data["ids"],
+        embeddings=chromadb_pdf_data["embeddings"] + chromadb_scanned_pdf_data["embeddings"],
+    )
+
+    logger.info(f"Combined ChromaDB created at {output_path}")
+
+
+def go(args):
+    """
+    Run the combine chromadb component.
+    """
+    run = wandb.init(job_type="combine_chromadb", entity='aimingmed')
+    run.config.update(args)
+
+    # Download the ChromaDB artifacts; each is a zip of a chroma_db folder,
+    # and use_artifact(...).file() returns the path to the downloaded zip,
+    # so unpack it before treating it as a database directory
+    logger.info("Downloading chromadb_pdf artifact")
+    chromadb_pdf_zip = run.use_artifact(args.chromadb_pdf_artifact).file()
+    shutil.unpack_archive(chromadb_pdf_zip, "chroma_db_pdf")
+    chromadb_pdf_path = os.path.join(os.getcwd(), "chroma_db_pdf")
+
+    logger.info("Downloading chromadb_scanned_pdf artifact")
+    chromadb_scanned_pdf_zip = run.use_artifact(args.chromadb_scanned_pdf_artifact).file()
+    shutil.unpack_archive(chromadb_scanned_pdf_zip, "chroma_db_scanned_pdf")
+    chromadb_scanned_pdf_path = os.path.join(os.getcwd(), "chroma_db_scanned_pdf")
+
+    # Create the output directory
+    output_folder = "combined_chromadb"
+    output_path = os.path.join(os.getcwd(), output_folder)
+    if os.path.exists(output_path):
+        shutil.rmtree(output_path)
+    os.makedirs(output_path)
+
+    # Combine the ChromaDB instances
+    combine_chromadb(chromadb_pdf_path, chromadb_scanned_pdf_path, output_path)
+
+    # Create a new artifact
+    artifact = wandb.Artifact(
+        args.output_artifact,
+        type=args.output_type,
+        description=args.output_description
+    )
+
+    # Zip the database folder first
+    shutil.make_archive(output_path, 'zip', output_path)
+
+    # Add the database to the artifact
+    artifact.add_file(output_path + '.zip')
+
+    # Log the artifact
+    run.log_artifact(artifact)
+
+    # Finish the run
+    run.finish()
+
+    # clean up - remove zip
+    os.remove(output_path + '.zip')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Combine two ChromaDB instances into one.")
+
+    parser.add_argument(
+        "--chromadb_pdf_artifact",
+        type=str,
+        required=True,
+        help="Fully-qualified name for the ChromaDB PDF artifact",
+    )
+    parser.add_argument(
+        "--chromadb_scanned_pdf_artifact",
+        type=str,
+        required=True,
+        help="Fully-qualified name for the ChromaDB Scanned PDF artifact",
+    )
+    parser.add_argument(
+        "--output_artifact",
+        type=str,
+        required=True,
+        help="Name for the output artifact",
+    )
+    parser.add_argument(
+        "--output_type",
+        type=str,
+        required=True,
+        help="Type for the output artifact",
+    )
+    parser.add_argument(
+        "--output_description",
+        type=str,
+        required=True,
+        help="Description for the output artifact",
+    )

+    args = parser.parse_args()
+    go(args)
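
Note: the merge in combine_chromadb relies on collection.get() returning parallel lists and collection.add() accepting their concatenation. A minimal self-contained sketch of that pattern, assuming a recent chromadb release that provides EphemeralClient; the ids, documents, and two-dimensional dummy vectors are illustrative only, and the list() conversion guards against newer releases that return embeddings as a NumPy array, where + would mean elementwise addition rather than concatenation:

import chromadb

# Two throwaway in-memory collections standing in for the pdf / scanned-pdf stores
a = chromadb.EphemeralClient().create_collection("rag_experiment")
b = chromadb.EphemeralClient().create_collection("rag_experiment")
a.add(ids=["pdf-1"], documents=["text from a born-digital PDF"], embeddings=[[0.1, 0.2]])
b.add(ids=["scan-1"], documents=["text from a scanned PDF"], embeddings=[[0.3, 0.4]])

# get() always returns ids; documents and embeddings must be requested explicitly
data_a = a.get(include=["documents", "embeddings"])
data_b = b.get(include=["documents", "embeddings"])

# The union is only valid while ids stay unique across both source collections
merged = chromadb.EphemeralClient().create_collection("rag_experiment")
merged.add(
    ids=data_a["ids"] + data_b["ids"],
    documents=data_a["documents"] + data_b["documents"],
    embeddings=list(data_a["embeddings"]) + list(data_b["embeddings"]),
)
assert merged.count() == 2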

@@ -8,4 +8,6 @@ etl:
   document_folder: documents
   path_document_folder: "../../../../data"
   embedding_model: paraphrase-multilingual-mpnet-base-v2
+
+prompt_engineering:
+  query: "怎么治疗肺癌?"  # "How is lung cancer treated?"

@@ -12,16 +12,7 @@ _steps = [
     "get_documents",
     "etl_chromdb_pdf",
     "etl_chromdb_scanned_pdf",
-    "data_check",
-    "data_split",
-    "train_random_forest_propensity",
-    "train_random_forest_revenue",
-    "train_lasso_revenue",
-    # NOTE: We do not include this in the steps so it is not run by mistake.
-    # You first need to promote a model export to "prod" before you can run this,
-    # then you need to run this step explicitly
-    "test_model",
-    "test_production"
+    "chain_of_thought"
 ]

 GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)

@@ -81,135 +72,14 @@ def go(config: DictConfig):
             },
         )
 
-    if "data_check" in active_steps:
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "data_check"),
-            "main",
-            parameters={
-                "csv": f"{config['data_check']['csv_to_check']}:latest",
-                "ref": "clean_sample.csv:reference",
-                "kl_threshold": config['data_check']['kl_threshold'],
-                "min_age": config['etl']['min_age'],
-                "max_age": config['etl']['max_age'],
-                "min_tenure": config['etl']['min_tenure'],
-                "max_tenure": config['etl']['max_tenure']
-            },
-        )
-
-    if "data_split" in active_steps:
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "components", "train_val_test_split"),
-            "main",
-            parameters={
-                "input": "clean_sample.csv:latest",
-                "test_size": config['modeling']['test_size'],
-                "random_seed": config['modeling']['random_seed'],
-                "stratify_by": config['modeling']['stratify_by'],
-            },
-        )
-
-    if "train_random_forest_propensity" in active_steps:
-
-        # NOTE: we need to serialize the random forest configuration into JSON
-        rf_config = os.path.abspath("rf_config.json")
-        with open(rf_config, "w+") as fp:
-            json.dump(dict(config["modeling"]["random_forest_classifier_propensity"].items()), fp)  # DO NOT TOUCH
-
-        # NOTE: use the rf_config we created as the rf_config parameter for the train_random_forest
-        # step
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "train_random_forest_propensity"),
-            "main",
-            parameters={
-                "trainval_artifact": "trainval_data.csv:latest",
-                "val_size": config['modeling']['val_size'],
-                "random_seed": config['modeling']['random_seed'],
-                "ls_output_columns": ','.join(config['modeling']['ls_output_columns']),
-                "product": config['modeling']['product_to_train'],
-                "stratify_by": config['modeling']['stratify_by'],
-                "n_folds": config['modeling']['n_folds'],
-                "rf_config": rf_config,
-                "output_artifact": "random_forest_export",
-            },
-        )
-
-    if "train_random_forest_revenue" in active_steps:
-
-        # NOTE: we need to serialize the random forest configuration into JSON
-        rf_config = os.path.abspath("rf_config_revenue.json")
-        with open(rf_config, "w+") as fp:
-            json.dump(dict(config["modeling"]["random_forest_regression_revenue"].items()), fp)
-
-        # NOTE: use the rf_config we created as the rf_config parameter for the train_random_forest
-        # step
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "train_random_forest_revenue"),
-            "main",
-            parameters={
-                "trainval_artifact": "trainval_data.csv:latest",
-                "val_size": config['modeling']['val_size'],
-                "random_seed": config['modeling']['random_seed'],
-                "ls_output_columns": ','.join(config['modeling']['ls_output_columns']),
-                "product": config['modeling']['product_to_train'],
-                "stratify_by": config['modeling']['stratify_by'],
-                "n_folds": config['modeling']['n_folds'],
-                "rf_config": rf_config,
-                "output_artifact": "random_forest_export",
-            },
-        )
-
-    if "train_lasso_revenue" in active_steps:
-
-        # NOTE: use the lasso_config we created as the lasso_config parameter for the train_lasso
-        lasso_config = os.path.abspath("lasso_config.json")
-        with open(lasso_config, "w+") as fp:
-            json.dump(dict(config["modeling"]["lasso_regression_revenue"].items()), fp)
-
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "src", "train_lasso_revenue"),
-            "main",
-            parameters={
-                "trainval_artifact": "trainval_data.csv:latest",
-                "val_size": config['modeling']['val_size'],
-                "random_seed": config['modeling']['random_seed'],
-                "ls_output_columns": ','.join(config['modeling']['ls_output_columns']),
-                "product": config['modeling']['product_to_train'],
-                "stratify_by": config['modeling']['stratify_by'],
-                "n_folds": config['modeling']['n_folds'],
-                "lasso_config": lasso_config,
-                "output_artifact": "lasso_export",
-            },
-        )
-
-    if "test_model" in active_steps:
-
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "components", "test_model"),
-            "main",
-            parameters={
-                "model_propensity_cc": config['best_model_propensity']['propensity_cc'],
-                "model_propensity_cl": config['best_model_propensity']['propensity_cl'],
-                "model_propensity_mf": config['best_model_propensity']['propensity_mf'],
-                "model_revenue_cc": config['best_model_revenue']['revenue_cc'],
-                "model_revenue_cl": config['best_model_revenue']['revenue_cl'],
-                "model_revenue_mf": config['best_model_revenue']['revenue_mf'],
-                "test_dataset": "test_data.csv:latest",
-            },
-        )
-
-    if "test_production" in active_steps:
-
-        _ = mlflow.run(
-            os.path.join(hydra.utils.get_original_cwd(), "components", "test_production"),
-            "main",
-            parameters={
-                "model_propensity_cc": config['best_model_propensity']['propensity_cc'],
-                "model_propensity_cl": config['best_model_propensity']['propensity_cl'],
-                "model_propensity_mf": config['best_model_propensity']['propensity_mf'],
-                "model_revenue_cc": config['best_model_revenue']['revenue_cc'],
-                "model_revenue_cl": config['best_model_revenue']['revenue_cl'],
-                "model_revenue_mf": config['best_model_revenue']['revenue_mf'],
-                "test_dataset": f"{config['production']['test_csv']}:latest",
+    if "chain_of_thought" in active_steps:
+        _ = mlflow.run(
+            os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
+            "main",
+            parameters={
+                "query": config["prompt_engineering"]["query"],
+                "input_chromadb_artifact": "chromdb.zip:latest",
+                "embedding_model": config["etl"]["embedding_model"],
             },
         )
 
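
The removed tabular-modeling steps collapse into a single MLflow Projects call. For orientation, the new branch of go() is equivalent to invoking the step directly through the projects API; a sketch using the same parameter values main.py now passes, assuming it is launched from app/llmops so the relative project path resolves:

import mlflow

# Mirrors the new chain_of_thought branch of go(): same entry point and parameters
# ("chromdb.zip:latest" is spelled exactly as in main.py)
mlflow.run(
    "src/chain_of_thought",
    "main",
    parameters={
        "query": "怎么治疗肺癌?",
        "input_chromadb_artifact": "chromdb.zip:latest",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
    },
)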

app/llmops/src/chain_of_thought/MLproject (new file)
@@ -0,0 +1,24 @@
+name: chain_of_thought
+python_env: python_env.yml
+
+entry_points:
+  main:
+    parameters:
+
+      query:
+        description: Query to run
+        type: string
+
+      input_chromadb_artifact:
+        description: Fully-qualified name for the input artifact
+        type: string
+
+      embedding_model:
+        description: Sentence Transformer model name
+        type: string
+
+
+    command: >-
+      python run.py --query {query} \
+        --input_chromadb_artifact {input_chromadb_artifact} \
+        --embedding_model {embedding_model}

app/llmops/src/chain_of_thought/python_env.yml (new file)
@@ -0,0 +1,16 @@
+# Python version required to run the project.
+python: "3.11.11"
+# Dependencies required to build packages. This field is optional.
+build_dependencies:
+  - pip==23.3.1
+  - setuptools
+  - wheel==0.37.1
+  - chromadb
+  - langchain
+  - sentence_transformers
+  - python-decouple
+  - langchain_google_genai
+# Dependencies required to run the project.
+dependencies:
+  - mlflow==2.8.1
+  - wandb==0.16.0

app/llmops/src/chain_of_thought/run.py (new file)
@@ -0,0 +1,104 @@
+import os
+import logging
+import argparse
+import wandb
+import chromadb
+import shutil
+from decouple import config
+from langchain.prompts import PromptTemplate
+from sentence_transformers import SentenceTransformer
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
+logger = logging.getLogger()
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
+
+
+def go(args):
+    run = wandb.init(job_type="chain_of_thought", entity='aimingmed')
+    run.config.update(args)
+
+    logger.info("Downloading chromadb artifact")
+    artifact_chromadb_local_path = run.use_artifact(args.input_chromadb_artifact).file()
+
+    # unzip the artifact
+    logger.info("Unzipping the artifact")
+    shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")
+
+    # Load data from ChromaDB
+    db_folder = "chroma_db"
+    db_path = os.path.join(os.getcwd(), db_folder)
+    chroma_client = chromadb.PersistentClient(path=db_path)
+    collection_name = "rag_experiment"
+    collection = chroma_client.get_collection(name=collection_name)
+
+    # Formulate a question
+    question = args.query
+
+    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GEMINI_API_KEY)
+
+    # Chain of Thought Prompt
+    cot_template = """Let's think step by step.
+    Given the following document in text: {documents_text}
+    Question: {question}
+    """
+    cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
+    cot_chain = cot_prompt | llm
+
+    # Initialize embedding model (do this ONCE)
+    model = SentenceTransformer(args.embedding_model)
+
+    # Query (prompt)
+    query_embedding = model.encode(question)  # Embed the query using the SAME model
+
+    # Search ChromaDB
+    documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)
+
+    # Generate chain of thought
+    cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
+    print("Chain of Thought: ", cot_output)
+
+    # Answer Prompt
+    answer_template = """Given the chain of thought: {cot}
+    Provide a concise answer to the question: {question}
+    Provide the answer with language that is similar to the question asked.
+    """
+    answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
+    answer_chain = answer_prompt | llm
+
+    # Generate answer
+    answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
+    print("Answer: ", answer_output)
+
+    run.finish()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Chain of Thought RAG")
+
+    parser.add_argument(
+        "--query",
+        type=str,
+        help="Question to ask the model",
+        required=True
+    )
+
+    parser.add_argument(
+        "--input_chromadb_artifact",
+        type=str,
+        help="Fully-qualified name for the chromadb artifact",
+        required=True
+    )
+
+    parser.add_argument(
+        "--embedding_model",
+        type=str,
+        default="paraphrase-multilingual-mpnet-base-v2",
+        help="Sentence Transformer model name"
+    )
+
+    args = parser.parse_args()
+
+    go(args)
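
One practical note on the two chains above: with langchain's pipe composition, ChatGoogleGenerativeAI returns AIMessage objects, so print(cot_output) dumps metadata along with the text, and the full message repr is what gets interpolated into the answer template. A small helper, sketched against langchain_core (the function name is illustrative, not part of this diff):

from langchain_core.messages import AIMessage

def message_text(msg) -> str:
    # AIMessage.content carries the model's text; fall back to str() otherwise
    return msg.content if isinstance(msg, AIMessage) else str(msg)

# e.g. answer_chain.invoke({"cot": message_text(cot_output), "question": question})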

@@ -70,7 +70,7 @@ def go(args):
 
 
     # Initialize embedding model (do this ONCE)
-    model_embedding = SentenceTransformer('all-mpnet-base-v2')  # Or a multilingual model
+    model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model
 
 
     # Create database, delete the database directory if it exists
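
Threading args.embedding_model through here matters because the vectors written at index time and the query vectors produced later in chain_of_thought must come from the same model; even same-dimension models such as all-mpnet-base-v2 and the multilingual default embed into incompatible spaces. A quick consistency sketch using the default model name from this diff (the document sentence is illustrative):

from sentence_transformers import SentenceTransformer

# Same name as the pipeline default; index-time and query-time encoders must match
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
doc_vec = model.encode("肺癌的治疗方法包括手术、放疗和化疗。")  # illustrative document sentence
query_vec = model.encode("怎么治疗肺癌?")                       # the configured query
assert doc_vec.shape == query_vec.shape                          # (768,) for this model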

@@ -115,6 +115,30 @@ def go(args):
             embeddings=[model_embedding.encode(split.page_content)]
         )
 
+    # Create a new artifact
+    artifact = wandb.Artifact(
+        args.output_artifact,
+        type=args.output_type,
+        description=args.output_description
+    )
+
+    # zip the database folder first
+    shutil.make_archive(db_path, 'zip', db_path)
+
+    # Add the database to the artifact
+    artifact.add_file(db_path + '.zip')
+
+    # Log the artifact
+    run.log_artifact(artifact)
+
+    # Finish the run
+    run.finish()
+
+    # clean up
+    os.remove(db_path + '.zip')
+    shutil.rmtree(db_path)  # db_path is a directory; os.remove() would raise IsADirectoryError
+
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="A very basic data cleaning")
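
This block establishes the artifact convention the rest of the pipeline depends on: each producer zips its chroma_db directory and logs the zip to wandb, and each consumer downloads and unpacks it before opening the database. A round-trip sketch of that convention outside wandb (the stand-in directory and file names are illustrative):

import os
import shutil

# Stand-in for a populated chroma_db directory
os.makedirs("chroma_db", exist_ok=True)
with open(os.path.join("chroma_db", "marker.txt"), "w") as fp:
    fp.write("stand-in for the sqlite and index files")

db_path = os.path.join(os.getcwd(), "chroma_db")

# Producer side (this diff): folder -> chroma_db.zip, then artifact.add_file(...)
shutil.make_archive(db_path, 'zip', db_path)

# Consumer side (chain_of_thought): use_artifact(...).file() -> unpack before opening
shutil.unpack_archive(db_path + '.zip', "chroma_db_restored")
assert os.path.exists(os.path.join("chroma_db_restored", "marker.txt"))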