diff --git a/app/llmops/config.yaml b/app/llmops/config.yaml
index 016e5a3..a4451ac 100644
--- a/app/llmops/config.yaml
+++ b/app/llmops/config.yaml
@@ -9,5 +9,6 @@ etl:
   path_document_folder: "../../../../data"
   embedding_model: paraphrase-multilingual-mpnet-base-v2
 
 prompt_engineering:
+  chat_model_provider: deepseek
   query: "怎么治疗肺癌?"
\ No newline at end of file
diff --git a/app/llmops/main.py b/app/llmops/main.py
index 39029ab..b8260d1 100644
--- a/app/llmops/main.py
+++ b/app/llmops/main.py
@@ -3,7 +3,6 @@ import json
 import mlflow
 import tempfile
 import os
-import wandb
 import hydra
 from omegaconf import DictConfig
 from decouple import config
@@ -15,9 +14,6 @@ _steps = [
     "chain_of_thought"
 ]
 
-GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
-
-
 
 # This automatically reads in the configuration
 @hydra.main(config_name='config')
@@ -71,7 +67,6 @@ def go(config: DictConfig):
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
-
     if "chain_of_thought" in active_steps:
         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
@@ -80,6 +75,7 @@ def go(config: DictConfig):
                 "query": config["prompt_engineering"]["query"],
                 "input_chromadb_artifact": "chromdb.zip:latest",
                 "embedding_model": config["etl"]["embedding_model"],
+                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
             },
         )
 
diff --git a/app/llmops/src/chain_of_thought/MLproject b/app/llmops/src/chain_of_thought/MLproject
index a1b776d..d317b94 100644
--- a/app/llmops/src/chain_of_thought/MLproject
+++ b/app/llmops/src/chain_of_thought/MLproject
@@ -17,8 +17,13 @@ entry_points:
         description: Fully-qualified name for the embedding model
         type: string
 
+      chat_model_provider:
+        description: Chat model provider to use (gemini or deepseek)
+        type: string
+
 
     command: >-
       python run.py --query {query} \
        --input_chromadb_artifact {input_chromadb_artifact} \
-       --embedding_model {embedding_model}
\ No newline at end of file
+       --embedding_model {embedding_model} \
+       --chat_model_provider {chat_model_provider}
\ No newline at end of file
diff --git a/app/llmops/src/chain_of_thought/python_env.yml b/app/llmops/src/chain_of_thought/python_env.yml
index b5708cb..1c6198e 100644
--- a/app/llmops/src/chain_of_thought/python_env.yml
+++ b/app/llmops/src/chain_of_thought/python_env.yml
@@ -10,6 +10,7 @@ build_dependencies:
 - sentence_transformers
 - python-decouple
 - langchain_google_genai
+- langchain-deepseek
 # Dependencies required to run the project.
 dependencies:
 - mlflow==2.8.1
diff --git a/app/llmops/src/chain_of_thought/run.py b/app/llmops/src/chain_of_thought/run.py
index 50727ce..50df938 100644
--- a/app/llmops/src/chain_of_thought/run.py
+++ b/app/llmops/src/chain_of_thought/run.py
@@ -8,12 +8,14 @@ from decouple import config
 from langchain.prompts import PromptTemplate
 from sentence_transformers import SentenceTransformer
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_deepseek import ChatDeepSeek
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
+DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)
 
 
 def go(args):
@@ -37,13 +39,32 @@ def go(args):
 
     # Formulate a question
     question = args.query
 
-    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GEMINI_API_KEY)
-
+    if args.chat_model_provider == "deepseek":
+        # Initialize DeepSeek model
+        llm = ChatDeepSeek(
+            model="deepseek-chat",
+            temperature=0,
+            max_tokens=None,
+            timeout=None,
+            max_retries=2,
+            api_key=DEEPSEEK_API_KEY
+        )
+
+    elif args.chat_model_provider == "gemini":
+        # Initialize Gemini model
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            google_api_key=GEMINI_API_KEY,
+            temperature=0,
+            max_retries=3
+        )
+
     # Chain of Thought Prompt
     cot_template = """Let's think step by step.
     Given the following document in text: {documents_text}
     Question: {question}
+    Reply in the same language as the question.
     """
     cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
     cot_chain = cot_prompt | llm
@@ -99,6 +120,13 @@ if __name__ == "__main__":
         help="Sentence Transformer model name"
     )
 
+    parser.add_argument(
+        "--chat_model_provider",
+        type=str,
+        default="gemini",
+        help="Chat model provider (gemini or deepseek)"
+    )
+
     args = parser.parse_args()
 
     go(args)
\ No newline at end of file
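
A minimal usage sketch for the new chat_model_provider switch, assuming the chain_of_thought step is invoked standalone as declared in its MLproject command and that the API keys are visible to python-decouple via the environment or a .env file; the query, artifact, and model names below mirror the values already present in config.yaml and main.py and are illustrative only.

    # hypothetical standalone invocation of the chain_of_thought step
    export DEEPSEEK_API_KEY=<your-deepseek-key>   # read by config("DEEPSEEK_API_KEY") in run.py
    export GOOGLE_API_KEY=<your-google-key>       # still read unconditionally at import time of run.py
    python run.py \
        --query "怎么治疗肺癌?" \
        --input_chromadb_artifact chromdb.zip:latest \
        --embedding_model paraphrase-multilingual-mpnet-base-v2 \
        --chat_model_provider deepseek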