diff --git a/demo/langchain_demo.py b/demo/langchain_demo.py index cb2e79d..7d903c6 100644 --- a/demo/langchain_demo.py +++ b/demo/langchain_demo.py @@ -29,27 +29,35 @@ import re import gradio as gr parser = ArgumentParser() + # 大语言模型参数设置 parser.add_argument( "--cpm_model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16", + help="MiniCPM模型路径或者huggingface id" ) parser.add_argument( - "--cpm_device", type=str, default="cuda:0", choices=["auto", "cuda:0"] + "--cpm_device", type=str, default="cuda:0", choices=["auto", "cuda:0"], + help="MiniCPM模型所在设备,默认为cuda:0" +) +parser.add_argument("--backend", type=str, default="torch", choices=["torch", "vllm"], + help="使用torch还是vllm后端,默认为torch" ) -parser.add_argument("--backend", type=str, default="torch", choices=["torch", "vllm"]) # 嵌入模型参数设置 parser.add_argument( - "--encode_model", type=str, default="BAAI/bge-base-zh" + "--encode_model", type=str, default="BAAI/bge-base-zh", + help="用于召回编码的embedding模型,默认为BAAI/bge-base-zh,可输入本地地址" ) parser.add_argument( - "--encode_model_device", type=str, default="cpu", choices=["cpu", "cuda:0"] + "--encode_model_device", type=str, default="cpu", choices=["cpu", "cuda:0"], + help="用于召回编码的embedding模型所在设备,默认为cpu" ) -parser.add_argument("--query_instruction", type=str, default="") +parser.add_argument("--query_instruction", type=str, default="",help="召回时增加的前缀") parser.add_argument( - "--file_path", type=str, default="/root/ld/pull_request/rag/红楼梦.pdf" + "--file_path", type=str, default="/root/ld/pull_request/rag/红楼梦.pdf", + help="需要检索的文本文件路径,gradio运行时无效" ) # 生成参数 @@ -60,9 +68,9 @@ parser.add_argument("--max_new_tokens", type=int, default=4096) parser.add_argument("--repetition_penalty", type=float, default=1.02) # retriever参数设置 -parser.add_argument("--embed_top_k", type=int, default=5) -parser.add_argument("--chunk_size", type=int, default=256) -parser.add_argument("--chunk_overlap", type=int, default=50) +parser.add_argument("--embed_top_k", type=int, default=5,help="召回几个最相似的文本") +parser.add_argument("--chunk_size", type=int, default=256,help="文本切分时切分的长度") +parser.add_argument("--chunk_overlap", type=int, default=50,help="文本切分的重叠长度") args = parser.parse_args()