fix model_config是系统关键词问题

This commit is contained in:
glide-the 2024-01-25 01:58:23 +08:00 committed by liunux4odoo
parent 4e358db525
commit 217bb61448
4 changed files with 17 additions and 20 deletions

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import json import json
from typing import AsyncIterable, List, Union, Dict from typing import AsyncIterable, List, Union, Dict, Annotated
from fastapi import Body from fastapi import Body
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -103,7 +103,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"), conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"), history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body( history: List[History] = Body(
[], [],
description="历史对话,设为一个整数可以从数据库中读取历史消息", description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[ examples=[
@ -115,9 +115,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
] ]
), ),
stream: bool = Body(True, description="流式输出"), stream: bool = Body(True, description="流式输出"),
model_config: Dict = Body({}, description="LLM 模型配置"), chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
openai_config: Dict = Body({}, description="openaiEndpoint配置"), openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
tool_config: Dict = Body({}, description="工具配置"), tool_config: dict = Body({}, description="工具配置", examples=[]),
): ):
async def chat_iterator() -> AsyncIterable[str]: async def chat_iterator() -> AsyncIterable[str]:
message_id = add_message_to_db( message_id = add_message_to_db(
@ -128,7 +128,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callback = AgentExecutorAsyncIteratorCallbackHandler() callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback] callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
openai_config=openai_config, stream=stream) openai_config=openai_config, stream=stream)
tools = [tool for tool in all_tools if tool.name in tool_config] tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools] tools = [t.copy(update={"callbacks": callbacks}) for t in tools]

View File

@ -145,7 +145,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.info("没有可用的插件") st.info("没有可用的插件")
# 传入后端的内容 # 传入后端的内容
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
tool_use = True tool_use = True
for key in LLM_MODEL_CONFIG: for key in LLM_MODEL_CONFIG:
if key == 'llm_model': if key == 'llm_model':
@ -158,7 +158,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
continue continue
if LLM_MODEL_CONFIG[key]: if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key])) first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key] chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
# 选择工具 # 选择工具
selected_tool_configs = {} selected_tool_configs = {}
@ -179,7 +179,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
llm_model = st.session_state["select_model_worker"]['label'] llm_model = st.session_state["select_model_worker"]['label']
if llm_model is not None: if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@ -223,7 +223,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.rerun() st.rerun()
else: else:
history = get_messages_history( history = get_messages_history(
model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1) chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
) )
chat_box.user_say(prompt) chat_box.user_say(prompt)
if files_upload: if files_upload:
@ -253,7 +253,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
for d in api.chat_chat(query=prompt, for d in api.chat_chat(query=prompt,
metadata=files_upload, metadata=files_upload,
history=history, history=history,
model_config=model_config, chat_model_config=chat_model_config,
openai_config=openai_config, openai_config=openai_config,
conversation_id=conversation_id, conversation_id=conversation_id,
tool_config=selected_tool_configs, tool_config=selected_tool_configs,

View File

@ -7,9 +7,9 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple from typing import Literal, Dict, Tuple
from configs import (kbs_config, from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE, EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models, list_online_embed_models from server.utils import list_embed_models
import os import os
import time import time
@ -106,10 +106,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
key="vs_type", key="vs_type",
) )
if is_lite: embed_models = list_embed_models()
embed_models = list_online_embed_models()
else:
embed_models = list_embed_models() + list_online_embed_models()
embed_model = cols[1].selectbox( embed_model = cols[1].selectbox(
"Embedding 模型", "Embedding 模型",

View File

@ -265,7 +265,7 @@ class ApiRequest:
history_len: int = -1, history_len: int = -1,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model_config: Dict = None, chat_model_config: Dict = None,
openai_config: Dict = None, openai_config: Dict = None,
tool_config: Dict = None, tool_config: Dict = None,
**kwargs, **kwargs,
@ -280,7 +280,7 @@ class ApiRequest:
"history_len": history_len, "history_len": history_len,
"history": history, "history": history,
"stream": stream, "stream": stream,
"model_config": model_config, "chat_model_config": chat_model_config,
"openai_config": openai_config, "openai_config": openai_config,
"tool_config": tool_config, "tool_config": tool_config,
} }