Commit a0ac1fd961: "working with deepseek"
Parent: 04e2764903
Repository: mirror of https://github.com/aimingmed/aimingmed-ai.git (synced 2026-02-07 15:53:45 +08:00)
File: config.yaml (name inferred from the @hydra.main(config_name='config') decorator below; the diff view dropped the per-file headers)

@@ -9,5 +9,6 @@ etl:
   path_document_folder: "../../../../data"
   embedding_model: paraphrase-multilingual-mpnet-base-v2
 prompt_engineering:
+  chat_model_provider: deepseek
   query: "怎么治疗肺癌?"
 
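The hunk above adds a chat_model_provider key under prompt_engineering (the query value is Chinese for "How is lung cancer treated?"). A minimal sketch of how the new key is read once the config is loaded; the standalone OmegaConf.load call and the file path are illustrative, since in the pipeline Hydra performs this load via @hydra.main:

# Sketch only: load the config directly with OmegaConf to inspect the new key.
# In the pipeline itself, @hydra.main(config_name='config') does the loading.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # illustrative path
provider = cfg.prompt_engineering.chat_model_provider
print(provider)  # -> "deepseek"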
File: main.py (inferred: the Hydra/MLflow pipeline orchestrator; the hunks down through the mlflow.run calls belong to this file)

@@ -3,7 +3,6 @@ import json
 import mlflow
 import tempfile
 import os
-import wandb
 import hydra
 from omegaconf import DictConfig
 from decouple import config
@@ -15,9 +14,6 @@ _steps = [
     "chain_of_thought"
 ]
 
-GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
-
-
 
 # This automatically reads in the configuration
 @hydra.main(config_name='config')
@@ -71,7 +67,6 @@ def go(config: DictConfig):
                 "embedding_model": config["etl"]["embedding_model"]
             },
         )
-
     if "chain_of_thought" in active_steps:
         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
@@ -80,6 +75,7 @@ def go(config: DictConfig):
                 "query": config["prompt_engineering"]["query"],
                 "input_chromadb_artifact": "chromdb.zip:latest",
                 "embedding_model": config["etl"]["embedding_model"],
+                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
             },
         )
 
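The hunk above threads the new config value into the chain_of_thought step. For context, a sketch of the same call with literal values, all taken from the config hunk above; mlflow.run() substitutes each parameters key into the matching {placeholder} of the step's declared command, so every key passed here must also be declared in the MLproject shown next:

# Sketch: the parameters dict must match what the step's MLproject declares.
import mlflow

_ = mlflow.run(
    "src/chain_of_thought",  # project directory, as in the diff
    parameters={
        "query": "怎么治疗肺癌?",  # "How is lung cancer treated?"
        "input_chromadb_artifact": "chromdb.zip:latest",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        "chat_model_provider": "deepseek",
    },
)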
File: src/chain_of_thought/MLproject (path inferred from the mlflow.run target above)

@@ -17,8 +17,13 @@ entry_points:
         description: Fully-qualified name for the embedding model
         type: string
 
+      chat_model_provider:
+        description: Fully-qualified name for the chat model provider
+        type: string
+
 
     command: >-
       python run.py --query {query} \
         --input_chromadb_artifact {input_chromadb_artifact} \
-        --embedding_model {embedding_model}
+        --embedding_model {embedding_model} \
+        --chat_model_provider {chat_model_provider}
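A hedged usage sketch for the updated entry point, driven from Python so the example stays in one language; the -P pairs map onto the parameters declared above, and all values are the ones used elsewhere in this commit:

# Sketch: the CLI equivalent of the mlflow.run() call, via subprocess.
import subprocess

subprocess.run(
    ["mlflow", "run", "src/chain_of_thought",
     "-P", "chat_model_provider=deepseek",
     "-P", "query=怎么治疗肺癌?",
     "-P", "input_chromadb_artifact=chromdb.zip:latest",
     "-P", "embedding_model=paraphrase-multilingual-mpnet-base-v2"],
    check=True,
)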
File: python_env.yaml for the step (format inferred from the build_dependencies/dependencies keys used by MLflow projects)

@@ -10,6 +10,7 @@ build_dependencies:
   - sentence_transformers
   - python-decouple
   - langchain_google_genai
+  - langchain-deepseek
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1
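The new build dependency supplies the ChatDeepSeek class imported in the next hunk. A one-line sanity check, assuming the package installed cleanly:

# Sketch: verify the langchain-deepseek package resolves before running the step.
from langchain_deepseek import ChatDeepSeek  # same import path used by run.py below
print(ChatDeepSeek.__name__)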
File: src/chain_of_thought/run.py (inferred from the python run.py command in the MLproject above)

@@ -8,12 +8,14 @@ from decouple import config
 from langchain.prompts import PromptTemplate
 from sentence_transformers import SentenceTransformer
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_deepseek import ChatDeepSeek
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
+DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
 
 
 def go(args):
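python-decouple resolves these names from the environment or a .env file in the working directory. A minimal sketch of the expected entries; the values are placeholders, and the variable names, including the "DEEKSEEK" spelling, are copied verbatim from the commit:

# Sketch: what python-decouple expects to find. Keys are verbatim from the
# diff (note the "DEEKSEEK" spelling); values are placeholders.
#
# .env
#   GOOGLE_API_KEY=<your Gemini API key>
#   DEEKSEEK_API_KEY=<your DeepSeek API key>
from decouple import config

GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)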
@@ -37,13 +39,32 @@ def go(args):
     # Formulate a question
     question = args.query
 
-    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GEMINI_API_KEY)
+    if args.chat_model_provider == "deepseek":
+        # Initialize DeepSeek model
+        llm = ChatDeepSeek(
+            model="deepseek-chat",
+            temperature=0,
+            max_tokens=None,
+            timeout=None,
+            max_retries=2,
+            api_key=DEEKSEEK_API_KEY
+        )
+
+    elif args.chat_model_provider == "gemini":
+        # Initialize Gemini model
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            google_api_key=GEMINI_API_KEY,
+            temperature=0,
+            max_retries=3
+        )
 
 
     # Chain of Thought Prompt
     cot_template = """Let's think step by step.
     Given the following document in text: {documents_text}
     Question: {question}
+    Reply in language similar to the language used in the question.
     """
     cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
     cot_chain = cot_prompt | llm
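One behavior worth flagging in the hunk above: if --chat_model_provider is anything other than "deepseek" or "gemini", neither branch runs and llm is never bound, so cot_prompt | llm fails with a NameError. A minimal guard, sketched here rather than part of the commit:

# Sketch: fail fast on an unsupported provider instead of hitting a NameError
# later at `cot_chain = cot_prompt | llm`.
if args.chat_model_provider not in ("deepseek", "gemini"):
    raise ValueError(
        f"Unsupported chat_model_provider: {args.chat_model_provider!r}; "
        "expected 'deepseek' or 'gemini'."
    )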
@@ -99,6 +120,13 @@ if __name__ == "__main__":
         help="Sentence Transformer model name"
     )
 
+    parser.add_argument(
+        "--chat_model_provider",
+        type=str,
+        default="gemini",
+        help="Chat model provider"
+    )
+
     args = parser.parse_args()
 
     go(args)
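A hedged usage sketch of the updated script, driven from Python so the example stays self-contained; the argument values mirror what the MLproject command would pass:

# Sketch: invoke run.py with the new flag, mirroring the MLproject command.
import subprocess
import sys

subprocess.run(
    [sys.executable, "run.py",
     "--query", "怎么治疗肺癌?",
     "--input_chromadb_artifact", "chromdb.zip:latest",
     "--embedding_model", "paraphrase-multilingual-mpnet-base-v2",
     "--chat_model_provider", "deepseek"],
    check=True,
)

Because the argument defaults to "gemini", omitting --chat_model_provider preserves the pre-commit Gemini behavior.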