集成LOOM在线embedding业务

This commit is contained in:
glide-the 2024-01-25 19:04:43 +08:00 committed by liunux4odoo
parent 802cfe8805
commit 89b0d467ea
12 changed files with 143 additions and 47 deletions

View File

@ -14,9 +14,9 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
stream: bool = Body(False, description="流式输出"),
echo: bool = Body(False, description="除了输出之外,还回显输入"),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),

View File

@ -57,9 +57,9 @@ def _parse_files_in_thread(
def upload_temp_docs(
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
prev_id: str = Form(None, description="前知识库ID"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
@ -110,9 +110,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),

View File

@ -3,16 +3,22 @@ from server.db.session import with_session
@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
                 endpoint_host_key: str = None, endpoint_host_proxy: str = None):
    """Create or update a knowledge-base record in the database.

    Args:
        session: DB session injected by the @with_session decorator.
        kb_name: knowledge-base name (matched case-insensitively).
        kb_info: human-readable description of the knowledge base.
        vs_type: vector-store type (e.g. "faiss").
        embed_model: embedding model name used by this knowledge base.
        endpoint_host: optional online-embedding endpoint address.
        endpoint_host_key: optional API key for the endpoint.
        endpoint_host_proxy: optional proxy address for the endpoint.

    Returns:
        True always (the decorator is expected to commit the session).
    """
    # Case-insensitive lookup so "Samples" and "samples" refer to the same KB.
    kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
    if not kb:
        # No record yet: create a new knowledge-base row.
        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
                                endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
                                endpoint_host_proxy=endpoint_host_proxy)
        session.add(kb)
    else:  # update existing kb with new vs_type, embed_model and endpoint settings
        kb.kb_info = kb_info
        kb.vs_type = vs_type
        kb.embed_model = embed_model
        kb.endpoint_host = endpoint_host
        kb.endpoint_host_key = endpoint_host_key
        kb.endpoint_host_proxy = endpoint_host_proxy
    return True

View File

@ -78,9 +78,9 @@ async def aembed_texts(
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:

View File

@ -15,6 +15,9 @@ def list_kbs():
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
@ -28,7 +31,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
kb.create_kb()
kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
except Exception as e:
msg = f"创建知识库出错: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',

View File

@ -346,6 +346,9 @@ def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
embed_model: str = Body(EMBEDDING_MODEL),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
@ -366,7 +369,9 @@ def recreate_vector_store(
else:
if kb.exists():
kb.clear_vs()
kb.create_kb()
kb.create_kb(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy)
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0

View File

@ -58,14 +58,23 @@ class KBService(ABC):
'''
pass
def create_kb(self,
              endpoint_host: str = None,
              endpoint_host_key: str = None,
              endpoint_host_proxy: str = None):
    """Create the knowledge base: ensure the doc folder exists, register the
    KB in the database, then build the vector store.

    Args:
        endpoint_host: optional online-embedding endpoint address.
        endpoint_host_key: optional API key for the endpoint.
        endpoint_host_proxy: optional proxy address for the endpoint.

    Returns:
        The status returned by add_kb_to_db; do_create_kb() is only invoked
        when the DB registration succeeded.
    """
    # exist_ok=True avoids the TOCTOU race of a separate exists() check.
    os.makedirs(self.doc_path, exist_ok=True)
    status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model,
                          endpoint_host=endpoint_host,
                          endpoint_host_key=endpoint_host_key,
                          endpoint_host_proxy=endpoint_host_proxy)
    if status:
        # Build the vector store only after the KB row is persisted.
        self.do_create_kb()
    return status
def clear_vs(self):

View File

@ -19,9 +19,9 @@ def recreate_summary_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
@ -112,9 +112,9 @@ def summary_file_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
@ -196,9 +196,9 @@ def summary_doc_ids_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),

View File

@ -4,7 +4,7 @@ import streamlit as st
from streamlit_antd_components.utils import ParseItems
from webui_pages.dialogue.utils import process_files
from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
get_select_model_endpoint
from webui_pages.utils import *
from streamlit_chatbox import *
@ -132,7 +132,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
conversation_id = st.session_state["conversation_ids"][conversation_name]
with st.expander("模型选择"):
plugins_menu = build_plugins_name()
plugins_menu = build_providers_model_plugins_name()
items, _ = ParseItems(plugins_menu).multi()

View File

@ -1,4 +1,8 @@
import streamlit as st
from streamlit_antd_components.utils import ParseItems
from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
set_llm_select, set_embed_select, get_select_embed_endpoint
from webui_pages.utils import *
from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
@ -7,12 +11,16 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY)
from server.utils import list_embed_models
import streamlit_antd_components as sac
import os
import time
# SENTENCE_SIZE = 100
cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""")
@ -96,24 +104,37 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
key="kb_info",
)
cols = st.columns(2)
col0, _ = st.columns([3, 1])
vs_types = list(kbs_config.keys())
vs_type = cols[0].selectbox(
vs_type = col0.selectbox(
"向量库类型",
vs_types,
index=vs_types.index(DEFAULT_VS_TYPE),
key="vs_type",
)
embed_models = list_embed_models()
col1, _ = st.columns([3, 1])
with col1:
col1.text("Embedding 模型")
plugins_menu = build_providers_embedding_plugins_name()
embed_model = cols[1].selectbox(
"Embedding 模型",
embed_models,
index=embed_models.index(EMBEDDING_MODEL),
key="embed_model",
)
embed_models = list_embed_models()
menu_item_children = []
for model in embed_models:
menu_item_children.append(sac.MenuItem(model, description=model))
plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
items, _ = ParseItems(plugins_menu).multi()
if len(plugins_menu) > 0:
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
set_embed_select(plugins_info, llm_model_worker)
else:
st.info("没有可用的插件")
submit_create_kb = st.form_submit_button(
"新建",
@ -122,15 +143,23 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
)
if submit_create_kb:
endpoint_host, select_embed_model_name = get_select_embed_endpoint()
if not kb_name or not kb_name.strip():
st.error(f"知识库名称不能为空!")
elif kb_name in kb_list:
st.error(f"名为 {kb_name} 的知识库已经存在!")
elif select_embed_model_name is None:
st.error(f"请选择Embedding模型")
else:
ret = api.create_knowledge_base(
knowledge_base_name=kb_name,
vector_store_type=vs_type,
embed_model=embed_model,
embed_model=select_embed_model_name,
endpoint_host=endpoint_host,
endpoint_host_key=OPENAI_KEY,
endpoint_host_proxy=OPENAI_PROXY,
)
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
@ -249,7 +278,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
# 将文件分词并加载到向量库中
if cols[1].button(
"重新添加至向量库" if selected_rows and (
pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True,
):

View File

@ -3,6 +3,7 @@ from typing import Tuple, Any
import streamlit as st
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
import logging
logger = logging.getLogger(__name__)
client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
@ -45,11 +46,11 @@ def start_plugin():
st.toast("start_plugin " + start_plugins_name + ",starting.")
result = client.launch_subscribe(start_plugins_name)
st.toast("start_plugin "+start_plugins_name + " ." + result.get("detail", ""))
st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
time.sleep(3)
result1 = client.launch_subscribe_start(start_plugins_name)
st.toast("start_plugin "+start_plugins_name + " ." + result1.get("detail", ""))
st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
time.sleep(2)
update_store()
@ -103,7 +104,7 @@ def stop_worker():
update_store()
def build_providers_model_plugins_name():
    """Build a sac.menu item tree of running plugins whose providers include
    "model" (LLM providers only; embedding-only plugins are excluded).

    Returns:
        list[sac.MenuItem]: one top-level item per plugin, with its LLM
        models as children; empty list when no plugins are running.
    """
    import streamlit_antd_components as sac
    if "run_plugins_list" not in st.session_state:
        return []
    # Build the sac.menu structure from the running-model registry.
    menu_items = []
    for key, value in st.session_state.list_running_models.items():
        menu_item_children = []
        for model in value:
            # Only expose models whose provider list includes "model" (LLM).
            if "model" in model["providers"]:
                menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
        menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
    return menu_items
def build_providers_embedding_plugins_name():
import streamlit_antd_components as sac
if "run_plugins_list" not in st.session_state:
return []
# 按照模型构建sac.menu(菜单
menu_items = []
for key, value in st.session_state.list_running_models.items():
menu_item_children = []
for model in value:
if "embedding" in model["providers"]:
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
@ -144,3 +163,22 @@ def get_select_model_endpoint() -> Tuple[str, str]:
adapter_description = st.session_state.launch_subscribe_info[plugins_name]
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
return endpoint_host, select_model_name
def set_embed_select(plugins_info, embed_model_worker):
    """Remember the currently selected embedding plugin and model worker."""
    state = st.session_state
    state["select_embed_plugins_info"] = plugins_info
    state["select_embed_model_worker"] = embed_model_worker
def get_select_embed_endpoint() -> Tuple[str, str]:
    """Resolve the endpoint host and model name of the selected embedding plugin.

    Returns:
        (endpoint_host, select_embed_model_name): endpoint_host is None when
        the plugin has no launch_subscribe_info entry.

    Raises:
        ValueError: when no embedding plugin/worker has been selected yet.
    """
    # .get() avoids a KeyError when the selection keys were never written;
    # the explicit None check below then reports the real problem.
    select_embed_plugins_info = st.session_state.get("select_embed_plugins_info")
    select_embed_model_worker = st.session_state.get("select_embed_model_worker")
    if select_embed_plugins_info is None or select_embed_model_worker is None:
        raise ValueError("select_embed_plugins_info or select_embed_model_worker is None")
    # Reuse the locals instead of re-reading session state.
    embed_plugins_name = select_embed_plugins_info['label']
    select_embed_model_name = select_embed_model_worker['label']
    endpoint_host = None
    if embed_plugins_name in st.session_state.launch_subscribe_info:
        adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name]
        endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
    return endpoint_host, select_embed_model_name

View File

@ -382,6 +382,9 @@ class ApiRequest:
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
endpoint_host: str = None,
endpoint_host_key: str = None,
endpoint_host_proxy: str = None
):
'''
对应api.py/knowledge_base/create_knowledge_base接口
@ -390,6 +393,9 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"vector_store_type": vector_store_type,
"embed_model": embed_model,
"endpoint_host": endpoint_host,
"endpoint_host_key": endpoint_host_key,
"endpoint_host_proxy": endpoint_host_proxy,
}
response = self.post(