mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-19 13:23:23 +08:00
working with deepseek
This commit is contained in:
parent
04e2764903
commit
a0ac1fd961
@ -9,5 +9,6 @@ etl:
|
||||
path_document_folder: "../../../../data"
|
||||
embedding_model: paraphrase-multilingual-mpnet-base-v2
|
||||
prompt_engineering:
|
||||
chat_model_provider: deepseek
|
||||
query: "怎么治疗肺癌?"
|
||||
|
||||
@ -3,7 +3,6 @@ import json
|
||||
import mlflow
|
||||
import tempfile
|
||||
import os
|
||||
import wandb
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
from decouple import config
|
||||
@ -15,9 +14,6 @@ _steps = [
|
||||
"chain_of_thought"
|
||||
]
|
||||
|
||||
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
|
||||
|
||||
|
||||
|
||||
# This automatically reads in the configuration
|
||||
@hydra.main(config_name='config')
|
||||
@ -71,7 +67,6 @@ def go(config: DictConfig):
|
||||
"embedding_model": config["etl"]["embedding_model"]
|
||||
},
|
||||
)
|
||||
|
||||
if "chain_of_thought" in active_steps:
|
||||
_ = mlflow.run(
|
||||
os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
|
||||
@ -80,6 +75,7 @@ def go(config: DictConfig):
|
||||
"query": config["prompt_engineering"]["query"],
|
||||
"input_chromadb_artifact": "chromdb.zip:latest",
|
||||
"embedding_model": config["etl"]["embedding_model"],
|
||||
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -17,8 +17,13 @@ entry_points:
|
||||
description: Fully-qualified name for the embedding model
|
||||
type: string
|
||||
|
||||
chat_model_provider:
|
||||
description: Fully-qualified name for the chat model provider
|
||||
type: string
|
||||
|
||||
|
||||
command: >-
|
||||
python run.py --query {query} \
|
||||
--input_chromadb_artifact {input_chromadb_artifact} \
|
||||
--embedding_model {embedding_model}
|
||||
--embedding_model {embedding_model} \
|
||||
--chat_model_provider {chat_model_provider}
|
||||
@ -10,6 +10,7 @@ build_dependencies:
|
||||
- sentence_transformers
|
||||
- python-decouple
|
||||
- langchain_google_genai
|
||||
- langchain-deepseek
|
||||
# Dependencies required to run the project.
|
||||
dependencies:
|
||||
- mlflow==2.8.1
|
||||
|
||||
@ -8,12 +8,14 @@ from decouple import config
|
||||
from langchain.prompts import PromptTemplate
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
||||
logger = logging.getLogger()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
|
||||
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
|
||||
|
||||
|
||||
def go(args):
|
||||
@ -37,13 +39,32 @@ def go(args):
|
||||
# Formulate a question
|
||||
question = args.query
|
||||
|
||||
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GEMINI_API_KEY)
|
||||
|
||||
if args.chat_model_provider == "deepseek":
|
||||
# Initialize DeepSeek model
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=DEEKSEEK_API_KEY
|
||||
)
|
||||
|
||||
elif args.chat_model_provider == "gemini":
|
||||
# Initialize Gemini model
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-1.5-flash",
|
||||
google_api_key=GEMINI_API_KEY,
|
||||
temperature=0,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
|
||||
# Chain of Thought Prompt
|
||||
cot_template = """Let's think step by step.
|
||||
Given the following document in text: {documents_text}
|
||||
Question: {question}
|
||||
Reply with language that is similar to the language used with asked question.
|
||||
"""
|
||||
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
|
||||
cot_chain = cot_prompt | llm
|
||||
@ -99,6 +120,13 @@ if __name__ == "__main__":
|
||||
help="Sentence Transformer model name"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--chat_model_provider",
|
||||
type=str,
|
||||
default="gemini",
|
||||
help="Chat model provider"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
go(args)
|
||||
Loading…
x
Reference in New Issue
Block a user