fix: model_config is a reserved system keyword

glide-the 2024-01-25 01:58:23 +08:00 committed by liunux4odoo
parent 4e358db525
commit 217bb61448
4 changed files with 17 additions and 20 deletions
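Background for the rename: Pydantic v2 reserves `model_config` as the class attribute that carries a model's configuration, and FastAPI generates a Pydantic model under the hood for `Body(...)` parameters, so an endpoint parameter named `model_config` collides with that reserved name. A minimal sketch of the collision, assuming Pydantic v2 (the class name is hypothetical; only the field name comes from this commit):

```python
from pydantic import BaseModel

# In Pydantic v2, `model_config` is reserved for model configuration
# (normally a ConfigDict), so declaring it as an ordinary annotated
# field makes Pydantic reject the class at definition time.
class ChatRequest(BaseModel):  # hypothetical class, for illustration only
    query: str
    model_config: dict = {}  # error: "model_config" is reserved as a config attribute
```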

View File

@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import AsyncIterable, List, Union, Dict
+from typing import AsyncIterable, List, Union, Dict, Annotated
 from fastapi import Body
 from fastapi.responses import StreamingResponse
@@ -103,7 +103,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
-               history: Union[int, List[History]] = Body(
+               history: List[History] = Body(
                    [],
                    description="历史对话,设为一个整数可以从数据库中读取历史消息",
                    examples=[
@@ -115,9 +115,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                        ]
                    ),
                stream: bool = Body(True, description="流式输出"),
-               model_config: Dict = Body({}, description="LLM 模型配置"),
-               openai_config: Dict = Body({}, description="openaiEndpoint配置"),
-               tool_config: Dict = Body({}, description="工具配置"),
+               chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
+               openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]),
+               tool_config: dict = Body({}, description="工具配置", examples=[]),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
         message_id = add_message_to_db(
@@ -128,7 +128,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
-        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
+        models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
                                                     openai_config=openai_config, stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]

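Callers of this endpoint therefore send `chat_model_config` instead of `model_config` in the JSON body. A minimal request sketch (the host, port, and route are assumptions; only the field names come from the diff above):

```python
import requests

# Hypothetical call against the renamed endpoint; the URL is an assumption.
resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "hello",
        "stream": False,
        "chat_model_config": {"llm_model": {}},  # was "model_config" before this commit
        "openai_config": {},
        "tool_config": {},
    },
)
print(resp.text)
```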
View File

@@ -145,7 +145,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 st.info("没有可用的插件")
         # 传入后端的内容
-        model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
+        chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
         tool_use = True
         for key in LLM_MODEL_CONFIG:
             if key == 'llm_model':
@ -158,7 +158,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
# 选择工具
selected_tool_configs = {}
@@ -179,7 +179,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             llm_model = st.session_state["select_model_worker"]['label']
             if llm_model is not None:
-                model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
+                chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
         uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
         files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@@ -223,7 +223,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 st.rerun()
         else:
             history = get_messages_history(
-                model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1)
+                chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
             )
             chat_box.user_say(prompt)
             if files_upload:
@@ -253,7 +253,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
                                    history=history,
-                                   model_config=model_config,
+                                   chat_model_config=chat_model_config,
                                    openai_config=openai_config,
                                    conversation_id=conversation_id,
                                    tool_config=selected_tool_configs,

View File

@@ -7,9 +7,9 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
-                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.utils import list_embed_models, list_online_embed_models
+                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
 import os
 import time
@@ -106,10 +106,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )
-        if is_lite:
-            embed_models = list_online_embed_models()
-        else:
-            embed_models = list_embed_models() + list_online_embed_models()
+        embed_models = list_embed_models()
         embed_model = cols[1].selectbox(
             "Embedding 模型",

View File

@@ -265,7 +265,7 @@ class ApiRequest:
             history_len: int = -1,
             history: List[Dict] = [],
             stream: bool = True,
-            model_config: Dict = None,
+            chat_model_config: Dict = None,
             openai_config: Dict = None,
             tool_config: Dict = None,
             **kwargs,
@@ -280,7 +280,7 @@ class ApiRequest:
             "history_len": history_len,
             "history": history,
             "stream": stream,
-            "model_config": model_config,
+            "chat_model_config": chat_model_config,
             "openai_config": openai_config,
             "tool_config": tool_config,
         }
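The client-side keyword changes in step with the payload key. A minimal usage sketch (the base URL and the shape of the yielded chunks are assumptions; only the renamed keyword comes from the diff above):

```python
# Hypothetical client usage; only the kwarg rename is taken from the diff.
api = ApiRequest(base_url="http://127.0.0.1:7861")
for d in api.chat_chat(
    query="hello",
    chat_model_config={"llm_model": {}},  # formerly model_config=
    tool_config={},
):
    print(d)
```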