mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-27 09:13:25 +08:00
集成LOOM在线embedding业务
This commit is contained in:
parent
802cfe8805
commit
89b0d467ea
@ -14,9 +14,9 @@ from server.utils import get_prompt_template
|
||||
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
echo: bool = Body(False, description="除了输出之外,还回显输入"),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
|
||||
@ -57,9 +57,9 @@ def _parse_files_in_thread(
|
||||
|
||||
|
||||
def upload_temp_docs(
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
prev_id: str = Form(None, description="前知识库ID"),
|
||||
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||
@ -110,9 +110,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
|
||||
@ -3,16 +3,22 @@ from server.db.session import with_session
|
||||
|
||||
|
||||
@with_session
|
||||
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
|
||||
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
|
||||
endpoint_host_key: str = None, endpoint_host_proxy: str = None):
|
||||
# 创建知识库实例
|
||||
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
|
||||
if not kb:
|
||||
kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
|
||||
kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
|
||||
endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy)
|
||||
session.add(kb)
|
||||
else: # update kb with new vs_type and embed_model
|
||||
kb.kb_info = kb_info
|
||||
kb.vs_type = vs_type
|
||||
kb.embed_model = embed_model
|
||||
kb.endpoint_host = endpoint_host
|
||||
kb.endpoint_host_key = endpoint_host_key
|
||||
kb.endpoint_host_proxy = endpoint_host_proxy
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@ -78,9 +78,9 @@ async def aembed_texts(
|
||||
|
||||
def embed_texts_endpoint(
|
||||
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
|
||||
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
|
||||
) -> BaseResponse:
|
||||
|
||||
@ -15,6 +15,9 @@ def list_kbs():
|
||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
@ -28,7 +31,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
try:
|
||||
kb.create_kb()
|
||||
kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
|
||||
except Exception as e:
|
||||
msg = f"创建知识库出错: {e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
|
||||
@ -346,6 +346,9 @@ def recreate_vector_store(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
allow_empty_kb: bool = Body(True),
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||
@ -366,7 +369,9 @@ def recreate_vector_store(
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
kb.create_kb(endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
|
||||
@ -58,14 +58,23 @@ class KBService(ABC):
|
||||
'''
|
||||
pass
|
||||
|
||||
def create_kb(self):
|
||||
def create_kb(self,
|
||||
endpoint_host: str = None,
|
||||
endpoint_host_key: str = None,
|
||||
endpoint_host_proxy: str = None):
|
||||
"""
|
||||
创建知识库
|
||||
"""
|
||||
if not os.path.exists(self.doc_path):
|
||||
os.makedirs(self.doc_path)
|
||||
self.do_create_kb()
|
||||
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
|
||||
|
||||
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model,
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy)
|
||||
|
||||
if status:
|
||||
self.do_create_kb()
|
||||
return status
|
||||
|
||||
def clear_vs(self):
|
||||
|
||||
@ -19,9 +19,9 @@ def recreate_summary_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
@ -112,9 +112,9 @@ def summary_file_to_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
@ -196,9 +196,9 @@ def summary_doc_ids_to_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
endpoint_host: str = Body(None, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(None, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
|
||||
@ -4,7 +4,7 @@ import streamlit as st
|
||||
from streamlit_antd_components.utils import ParseItems
|
||||
|
||||
from webui_pages.dialogue.utils import process_files
|
||||
from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
|
||||
from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
|
||||
get_select_model_endpoint
|
||||
from webui_pages.utils import *
|
||||
from streamlit_chatbox import *
|
||||
@ -132,7 +132,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
conversation_id = st.session_state["conversation_ids"][conversation_name]
|
||||
|
||||
with st.expander("模型选择"):
|
||||
plugins_menu = build_plugins_name()
|
||||
plugins_menu = build_providers_model_plugins_name()
|
||||
|
||||
items, _ = ParseItems(plugins_menu).multi()
|
||||
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
import streamlit as st
|
||||
from streamlit_antd_components.utils import ParseItems
|
||||
|
||||
from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
|
||||
set_llm_select, set_embed_select, get_select_embed_endpoint
|
||||
from webui_pages.utils import *
|
||||
from st_aggrid import AgGrid, JsCode
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
@ -7,12 +11,16 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
|
||||
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
|
||||
from typing import Literal, Dict, Tuple
|
||||
from configs import (kbs_config,
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY)
|
||||
from server.utils import list_embed_models
|
||||
|
||||
import streamlit_antd_components as sac
|
||||
import os
|
||||
import time
|
||||
|
||||
# SENTENCE_SIZE = 100
|
||||
|
||||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||
|
||||
|
||||
@ -96,24 +104,37 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
key="kb_info",
|
||||
)
|
||||
|
||||
cols = st.columns(2)
|
||||
col0, _ = st.columns([3, 1])
|
||||
|
||||
vs_types = list(kbs_config.keys())
|
||||
vs_type = cols[0].selectbox(
|
||||
vs_type = col0.selectbox(
|
||||
"向量库类型",
|
||||
vs_types,
|
||||
index=vs_types.index(DEFAULT_VS_TYPE),
|
||||
key="vs_type",
|
||||
)
|
||||
|
||||
embed_models = list_embed_models()
|
||||
col1, _ = st.columns([3, 1])
|
||||
with col1:
|
||||
col1.text("Embedding 模型")
|
||||
plugins_menu = build_providers_embedding_plugins_name()
|
||||
|
||||
embed_model = cols[1].selectbox(
|
||||
"Embedding 模型",
|
||||
embed_models,
|
||||
index=embed_models.index(EMBEDDING_MODEL),
|
||||
key="embed_model",
|
||||
)
|
||||
embed_models = list_embed_models()
|
||||
menu_item_children = []
|
||||
for model in embed_models:
|
||||
menu_item_children.append(sac.MenuItem(model, description=model))
|
||||
|
||||
plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
|
||||
|
||||
items, _ = ParseItems(plugins_menu).multi()
|
||||
|
||||
if len(plugins_menu) > 0:
|
||||
|
||||
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
|
||||
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
|
||||
set_embed_select(plugins_info, llm_model_worker)
|
||||
else:
|
||||
st.info("没有可用的插件")
|
||||
|
||||
submit_create_kb = st.form_submit_button(
|
||||
"新建",
|
||||
@ -122,15 +143,23 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
)
|
||||
|
||||
if submit_create_kb:
|
||||
|
||||
endpoint_host, select_embed_model_name = get_select_embed_endpoint()
|
||||
if not kb_name or not kb_name.strip():
|
||||
st.error(f"知识库名称不能为空!")
|
||||
elif kb_name in kb_list:
|
||||
st.error(f"名为 {kb_name} 的知识库已经存在!")
|
||||
elif select_embed_model_name is None:
|
||||
st.error(f"请选择Embedding模型!")
|
||||
else:
|
||||
|
||||
ret = api.create_knowledge_base(
|
||||
knowledge_base_name=kb_name,
|
||||
vector_store_type=vs_type,
|
||||
embed_model=embed_model,
|
||||
embed_model=select_embed_model_name,
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=OPENAI_KEY,
|
||||
endpoint_host_proxy=OPENAI_PROXY,
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
@ -249,7 +278,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
||||
# 将文件分词并加载到向量库中
|
||||
if cols[1].button(
|
||||
"重新添加至向量库" if selected_rows and (
|
||||
pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
|
||||
pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
|
||||
disabled=not file_exists(kb, selected_rows)[0],
|
||||
use_container_width=True,
|
||||
):
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Tuple, Any
|
||||
import streamlit as st
|
||||
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
|
||||
|
||||
@ -45,11 +46,11 @@ def start_plugin():
|
||||
|
||||
st.toast("start_plugin " + start_plugins_name + ",starting.")
|
||||
result = client.launch_subscribe(start_plugins_name)
|
||||
st.toast("start_plugin "+start_plugins_name + " ." + result.get("detail", ""))
|
||||
st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
|
||||
time.sleep(3)
|
||||
result1 = client.launch_subscribe_start(start_plugins_name)
|
||||
|
||||
st.toast("start_plugin "+start_plugins_name + " ." + result1.get("detail", ""))
|
||||
st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
|
||||
time.sleep(2)
|
||||
update_store()
|
||||
|
||||
@ -103,7 +104,7 @@ def stop_worker():
|
||||
update_store()
|
||||
|
||||
|
||||
def build_plugins_name():
|
||||
def build_providers_model_plugins_name():
|
||||
import streamlit_antd_components as sac
|
||||
if "run_plugins_list" not in st.session_state:
|
||||
return []
|
||||
@ -112,7 +113,25 @@ def build_plugins_name():
|
||||
for key, value in st.session_state.list_running_models.items():
|
||||
menu_item_children = []
|
||||
for model in value:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
if "model" in model["providers"]:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
|
||||
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
|
||||
|
||||
return menu_items
|
||||
|
||||
|
||||
def build_providers_embedding_plugins_name():
|
||||
import streamlit_antd_components as sac
|
||||
if "run_plugins_list" not in st.session_state:
|
||||
return []
|
||||
# 按照模型构建sac.menu(菜单
|
||||
menu_items = []
|
||||
for key, value in st.session_state.list_running_models.items():
|
||||
menu_item_children = []
|
||||
for model in value:
|
||||
if "embedding" in model["providers"]:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
|
||||
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
|
||||
|
||||
@ -144,3 +163,22 @@ def get_select_model_endpoint() -> Tuple[str, str]:
|
||||
adapter_description = st.session_state.launch_subscribe_info[plugins_name]
|
||||
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
|
||||
return endpoint_host, select_model_name
|
||||
|
||||
|
||||
def set_embed_select(plugins_info, embed_model_worker):
    """Store the chosen embedding plugin and model in the Streamlit session.

    Args:
        plugins_info: Menu item for the selected plugin/provider group; saved
            under the ``select_embed_plugins_info`` session key.
        embed_model_worker: Menu item for the selected embedding model; saved
            under the ``select_embed_model_worker`` session key.
    """
    session = st.session_state
    session["select_embed_plugins_info"] = plugins_info
    session["select_embed_model_worker"] = embed_model_worker
|
||||
|
||||
|
||||
def get_select_embed_endpoint() -> Tuple[str, str]:
    """Return the endpoint host and name of the selected embedding model.

    The selection is read from the ``select_embed_plugins_info`` and
    ``select_embed_model_worker`` session keys.

    Returns:
        A ``(endpoint_host, model_name)`` pair.

        * ``(None, None)`` when no selection has been stored yet (either key
          is missing or holds ``None``). The caller is expected to check
          ``model_name is None`` and report it to the user, so this case is
          returned instead of raised.
        * ``endpoint_host`` is ``None`` when the selected plugin has no entry
          in the ``launch_subscribe_info`` session value.
    """
    state = st.session_state

    # ``.get`` instead of indexing: the keys do not exist until a selection
    # has been stored, and indexing a missing key raises KeyError.
    plugins_info = state.get("select_embed_plugins_info")
    model_worker = state.get("select_embed_model_worker")
    if plugins_info is None or model_worker is None:
        return None, None

    plugins_name = plugins_info['label']
    model_name = model_worker['label']

    # Tolerate a session in which no plugin subscription info was stored.
    subscribe_info = state.get("launch_subscribe_info") or {}

    endpoint_host = None
    if plugins_name in subscribe_info:
        adapter_description = subscribe_info[plugins_name]
        endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
    return endpoint_host, model_name
|
||||
|
||||
@ -382,6 +382,9 @@ class ApiRequest:
|
||||
knowledge_base_name: str,
|
||||
vector_store_type: str = DEFAULT_VS_TYPE,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
endpoint_host: str = None,
|
||||
endpoint_host_key: str = None,
|
||||
endpoint_host_proxy: str = None
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/create_knowledge_base接口
|
||||
@ -390,6 +393,9 @@ class ApiRequest:
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"vector_store_type": vector_store_type,
|
||||
"embed_model": embed_model,
|
||||
"endpoint_host": endpoint_host,
|
||||
"endpoint_host_key": endpoint_host_key,
|
||||
"endpoint_host_proxy": endpoint_host_proxy,
|
||||
}
|
||||
|
||||
response = self.post(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user