diff --git a/server/chat/chat.py b/server/chat/chat.py
index 44d42ebd..d80a1c98 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterable, List, Union, Dict
+from typing import AsyncIterable, List, Union, Dict, Annotated
 
 from fastapi import Body
 from fastapi.responses import StreamingResponse
@@ -103,7 +103,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
-               history: Union[int, List[History]] = Body(
+               history: List[History] = Body(
                    [],
                    description="历史对话,设为一个整数可以从数据库中读取历史消息",
                    examples=[
@@ -115,9 +115,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                    ]
                ),
                stream: bool = Body(True, description="流式输出"),
-               model_config: Dict = Body({}, description="LLM 模型配置"),
-               openai_config: Dict = Body({}, description="openaiEndpoint配置"),
-               tool_config: Dict = Body({}, description="工具配置"),
+               chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
+               openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
+               tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
         message_id = add_message_to_db(
@@ -128,7 +128,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
-        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
+        models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
                                                     openai_config=openai_config, stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 24f4b693..9ba7e0c3 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -145,7 +145,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 st.info("没有可用的插件")
 
         # 传入后端的内容
-        model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
+        chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
         tool_use = True
         for key in LLM_MODEL_CONFIG:
             if key == 'llm_model':
@@ -158,7 +158,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 continue
             if LLM_MODEL_CONFIG[key]:
                 first_key = next(iter(LLM_MODEL_CONFIG[key]))
-                model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
+                chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
 
         # 选择工具
         selected_tool_configs = {}
@@ -179,7 +179,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         llm_model = st.session_state["select_model_worker"]['label']
 
         if llm_model is not None:
-            model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
+            chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
 
         uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
         files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@@ -223,7 +223,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.rerun()
     else:
         history = get_messages_history(
-            model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1)
+            chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
         )
         chat_box.user_say(prompt)
         if files_upload:
@@ -253,7 +253,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
                                    history=history,
-                                   model_config=model_config,
+                                   chat_model_config=chat_model_config,
                                    openai_config=openai_config,
                                    conversation_id=conversation_id,
                                    tool_config=selected_tool_configs,
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 76a97377..1443dfb7 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -7,9 +7,9 @@
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
-                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.utils import list_embed_models, list_online_embed_models
+                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
 import os
 import time
@@ -106,10 +106,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
         key="vs_type",
     )
 
-    if is_lite:
-        embed_models = list_online_embed_models()
-    else:
-        embed_models = list_embed_models() + list_online_embed_models()
+    embed_models = list_embed_models()
 
     embed_model = cols[1].selectbox(
         "Embedding 模型",
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 7c3acf1d..09dd44d9 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -265,7 +265,7 @@ class ApiRequest:
            history_len: int = -1,
            history: List[Dict] = [],
            stream: bool = True,
-           model_config: Dict = None,
+           chat_model_config: Dict = None,
            openai_config: Dict = None,
            tool_config: Dict = None,
            **kwargs,
@@ -280,7 +280,7 @@ class ApiRequest:
            "history_len": history_len,
            "history": history,
            "stream": stream,
-           "model_config": model_config,
+           "chat_model_config": chat_model_config,
            "openai_config": openai_config,
            "tool_config": tool_config,
        }