working with deepseek

leehk 2025-02-27 17:09:15 +08:00
parent 04e2764903
commit a0ac1fd961
5 changed files with 39 additions and 8 deletions

View File

@@ -9,5 +9,6 @@ etl:
   path_document_folder: "../../../../data"
   embedding_model: paraphrase-multilingual-mpnet-base-v2
 prompt_engineering:
+  chat_model_provider: deepseek
   query: "怎么治疗肺癌?"
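
The new prompt_engineering.chat_model_provider key (the sample query is Chinese for "How is lung cancer treated?") picks the chat backend that run.py initializes. Because this file is loaded through @hydra.main, the provider can presumably also be switched per run with a standard Hydra command-line override instead of editing the YAML — a minimal sketch, assuming main.py is the Hydra entry point:

    python main.py prompt_engineering.chat_model_provider=gemini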

View File

@@ -3,7 +3,6 @@ import json
 import mlflow
 import tempfile
 import os
-import wandb
 import hydra
 from omegaconf import DictConfig
 from decouple import config
@@ -15,9 +14,6 @@ _steps = [
     "chain_of_thought"
 ]
-GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
 # This automatically reads in the configuration
 @hydra.main(config_name='config')
@@ -71,7 +67,6 @@ def go(config: DictConfig):
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
     if "chain_of_thought" in active_steps:
         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
@@ -80,6 +75,7 @@ def go(config: DictConfig):
                 "query": config["prompt_engineering"]["query"],
                 "input_chromadb_artifact": "chromdb.zip:latest",
                 "embedding_model": config["etl"]["embedding_model"],
+                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
             },
         )
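
The orchestrator change above only forwards one more key into the step's parameter dict. For a quick sanity check, the same step can be launched directly with mlflow.run — a sketch using the values from config.yaml (the step path and the default "main" entry point are assumptions):

    import mlflow

    mlflow.run(
        "src/chain_of_thought",   # directory containing the step's MLproject
        entry_point="main",       # assumed default entry point name
        parameters={
            "query": "怎么治疗肺癌?",
            "input_chromadb_artifact": "chromdb.zip:latest",
            "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
            "chat_model_provider": "deepseek",
        },
    )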

View File

@@ -17,8 +17,13 @@ entry_points:
       description: Fully-qualified name for the embedding model
       type: string
+    chat_model_provider:
+      description: Name of the chat model provider (gemini or deepseek)
+      type: string
   command: >-
     python run.py --query {query} \
       --input_chromadb_artifact {input_chromadb_artifact} \
-      --embedding_model {embedding_model}
+      --embedding_model {embedding_model} \
+      --chat_model_provider {chat_model_provider}
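
With the parameters substituted, the command template renders to a plain CLI call, so the step is also smoke-testable standalone (argument values taken from config.yaml):

    python run.py --query "怎么治疗肺癌?" \
      --input_chromadb_artifact chromdb.zip:latest \
      --embedding_model paraphrase-multilingual-mpnet-base-v2 \
      --chat_model_provider deepseek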

View File

@@ -10,6 +10,7 @@ build_dependencies:
   - sentence_transformers
   - python-decouple
   - langchain_google_genai
+  - langchain-deepseek
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1

View File

@@ -8,12 +8,14 @@ from decouple import config
 from langchain.prompts import PromptTemplate
 from sentence_transformers import SentenceTransformer
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_deepseek import ChatDeepSeek

 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
+DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)


 def go(args):
@@ -37,13 +39,32 @@ def go(args):
     # Formulate a question
     question = args.query

-    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GEMINI_API_KEY)
+    if args.chat_model_provider == "deepseek":
+        # Initialize DeepSeek model
+        llm = ChatDeepSeek(
+            model="deepseek-chat",
+            temperature=0,
+            max_tokens=None,
+            timeout=None,
+            max_retries=2,
+            api_key=DEEPSEEK_API_KEY
+        )
+    elif args.chat_model_provider == "gemini":
+        # Initialize Gemini model
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            google_api_key=GEMINI_API_KEY,
+            temperature=0,
+            max_retries=3
+        )

     # Chain of Thought Prompt
     cot_template = """Let's think step by step.
     Given the following document in text: {documents_text}
     Question: {question}
+    Reply in the same language as the question.
     """
     cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
     cot_chain = cot_prompt | llm
@@ -99,6 +120,13 @@ if __name__ == "__main__":
         help="Sentence Transformer model name"
     )
+    parser.add_argument(
+        "--chat_model_provider",
+        type=str,
+        default="gemini",
+        help="Chat model provider (gemini or deepseek)"
+    )
     args = parser.parse_args()

     go(args)
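
One gap the diff leaves open: if --chat_model_provider is anything other than "deepseek" or "gemini", the if/elif chain never binds llm, and the chain construction below it fails with an UnboundLocalError. A minimal defensive sketch of the same selection logic with an explicit failure (the helper name build_llm is hypothetical):

    from decouple import config
    from langchain_deepseek import ChatDeepSeek
    from langchain_google_genai import ChatGoogleGenerativeAI

    def build_llm(provider: str):
        # Same models and env vars as the commit; unknown providers fail fast.
        if provider == "deepseek":
            return ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                api_key=config("DEEPSEEK_API_KEY", cast=str),
            )
        if provider == "gemini":
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=config("GOOGLE_API_KEY", cast=str),
                temperature=0,
                max_retries=3,
            )
        raise ValueError(f"Unsupported chat_model_provider: {provider!r}")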