Fix model selection bug

zR 2023-12-06 19:53:29 +08:00 committed by liunux4odoo
parent 253168a187
commit 5714358403
11 changed files with 166 additions and 68 deletions

1.png: new binary file, 1.1 MiB (content not shown)

bus.png: new binary file, 551 KiB (content not shown)


@ -12,7 +12,10 @@ RERANKER_MAX_LENGTH = 1024
# Configure this when you need to add custom keywords to EMBEDDING_MODEL
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
SUPPORT_AGENT_MODELS = [
"chatglm3-6b",
"openai-api"
]
LLM_MODEL_CONFIG = {
# Intent recognition needs no visible output; it is enough for the model backend to know the result
"preprocess_model": {
@ -44,7 +47,7 @@ LLM_MODEL_CONFIG = {
"max_tokens": 2048,
"history_len": 100,
"prompt_name": "default",
"callbacks": False
"callbacks": True
},
},
"action_model": {
@ -54,11 +57,10 @@ LLM_MODEL_CONFIG = {
"prompt_name": "ChatGLM3",
"callbacks": True
},
"zhipu-api": {
"openai-api": {
"temperature": 0.01,
"max_tokens": 2096,
"history_len": 5,
"prompt_name": "ChatGLM3",
"max_tokens": 4096,
"prompt_name": "GPT-4",
"callbacks": True
},
},
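
The SUPPORT_AGENT_MODELS whitelist added above is not consulted in this hunk; a check of the following shape is the usual way such a list gates the agent path. This is a hedged sketch only: the function name supports_agent and the substring-matching rule are assumptions for illustration, not code from this repository.

# Illustrative only: gate the agent/tool path on the SUPPORT_AGENT_MODELS whitelist.
SUPPORT_AGENT_MODELS = ["chatglm3-6b", "openai-api"]

def supports_agent(model_name: str) -> bool:
    # Substring match, so variants such as "chatglm3-6b-32k" would also qualify.
    return any(name in model_name.lower() for name in SUPPORT_AGENT_MODELS)

print(supports_agent("chatglm3-6b"))    # True  -> agent with tools
print(supports_agent("qwen-14b-chat"))  # False -> plain LLM chat
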


@ -1,19 +1,19 @@
PROMPT_TEMPLATES = {
"preprocess_model": {
"default":
'请你根据我的描述和我们对话的历史,判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 01代表需要使用工具0代表不需要使用工具。\n'
'以下几种情况要使用工具,请返回1\n'
'1. 实时性的问题,例如天气,日期,地点等信息\n'
'根据我们对话的历史,判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。'
'以下几种情况要使用工具:\n'
'1. 实时性的问题,例如查询天气,日期,地点等信息\n'
'2. 需要数学计算的问题\n'
'3. 需要查询数据,地点等精确数据\n'
'4. 需要行业知识的问题\n'
'<question>'
'{input}'
'</question>'
'5. 需要联网的内容\n'
'你只要回答一个数字:1代表需要使用工具你无法为我直接提供服务。0代表不需要使用工具。你应该尽量使用工具\n'
'你只能回答0或者1'
},
"llm_model": {
"default":
'{{ input }}',
'{{input}}',
"with_history":
'The following is a friendly conversation between a human and an AI. '
'The AI is talkative and provides lots of specific details from its context. '

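Taken together, the tightened preprocess prompt above and the classifier chain in the chat-service diff below implement a 0/1 tool-routing classifier. The following is a self-contained sketch of that pattern; FakeListLLM stands in for the real preprocess model, the prompt text is abbreviated, and the wiring is assumed rather than copied from the repository.

# Hypothetical, minimal sketch of the 0/1 tool-routing classifier.
from langchain_community.llms.fake import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

preprocess_prompt = (
    "判断本次交流是否需要使用工具。需要使用工具回答1,不需要回答0。\n"
    "问题:{input}\n"
    "你只能回答0或者1"
)
classifier_chain = (
    PromptTemplate.from_template(preprocess_prompt)
    | FakeListLLM(responses=["1"])   # stand-in for models["preprocess_model"]
    | StrOutputParser()
)
print(classifier_chain.invoke({"input": "今天北京天气怎么样?"}))  # "1" -> route to the agent branch
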

@ -21,7 +21,9 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
def create_models_from_config(configs: dict = {}, callbacks: list = []):
def create_models_from_config(configs, callbacks):
if configs is None:
configs = {}
models = {}
prompts = {}
for model_type, model_configs in configs.items():
@ -40,7 +42,57 @@ def create_models_from_config(configs: dict = {}, callbacks: list = []):
return models, prompts
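
The signature change above drops the mutable default arguments and lets callers pass None instead. As a hedged aside (generic Python behaviour, hypothetical function names, not project code): a default dict is created once when the function is defined and is then shared across calls, which is the pitfall the new None check avoids.

# Generic demonstration of the shared-mutable-default pitfall the new signature avoids.
from typing import Optional

def bad(configs: dict = {}):
    configs.setdefault("calls", 0)
    configs["calls"] += 1
    return configs["calls"]

def good(configs: Optional[dict] = None):
    if configs is None:
        configs = {}          # a fresh dict on every call
    configs.setdefault("calls", 0)
    configs["calls"] += 1
    return configs["calls"]

print(bad(), bad())    # 1 2  -- the same default dict is reused between calls
print(good(), good())  # 1 1
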
# Chain construction logic goes here
def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
memory = None
chat_prompt = None
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
else:
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
classifier_chain = (
PromptTemplate.from_template(prompts["preprocess_model"])
| models["preprocess_model"]
| StrOutputParser()
)
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
branch = RunnableBranch(
(lambda x: "1" in x["topic"].lower(), agent_executor),
chain
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
return full_chain
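
For readers unfamiliar with the routing pattern in create_models_chains, here is a minimal, self-contained sketch of the same RunnableBranch dispatch. Everything below is illustrative: RunnableLambda stand-ins replace the real classifier chain, agent executor, and LLM chain.

# Illustrative stand-ins only: shows how {"topic": ..., "input": ...} | RunnableBranch routes requests.
from langchain_core.runnables import RunnableBranch, RunnableLambda

classifier = RunnableLambda(lambda x: "1" if "天气" in x["input"] else "0")  # stand-in for classifier_chain
agent = RunnableLambda(lambda x: f"[agent/tools] {x['input']}")             # stand-in for agent_executor
llm_chain = RunnableLambda(lambda x: f"[plain chat] {x['input']}")          # stand-in for chain

branch = RunnableBranch(
    (lambda x: "1" in x["topic"].lower(), agent),   # classifier said "use a tool"
    llm_chain,                                      # default: answer directly
)
full_chain = {"topic": classifier, "input": lambda x: x["input"]} | branch

print(full_chain.invoke({"input": "今天北京天气怎么样?"}))  # -> "[agent/tools] ..."
print(full_chain.invoke({"input": "给我讲个笑话"}))          # -> "[plain chat] ..."
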
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body([],
@ -55,61 +107,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
nonlocal history
memory = None
message_id = None
chat_prompt = None
message_id = add_message_to_db(chat_type="llm_chat", query=query,
conversation_id=conversation_id) if conversation_id else None
callback = CustomAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
if conversation_id:
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
tools = [tool for tool in all_tools if tool.name in tool_config]
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
else:
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
# Build the full chain
full_chain = create_models_chains(prompts=prompts,
models=models,
conversation_id=conversation_id,
tools=tools,
callbacks=callbacks,
history=history,
history_len=history_len,
metadata=metadata)
chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
classifier_chain = (
PromptTemplate.from_template(prompts["preprocess_model"])
| models["preprocess_model"]
| StrOutputParser()
)
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
branch = RunnableBranch(
(lambda x: "1" in x["topic"].lower(), agent_executor),
chain
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
# Run the full chain
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
@ -132,7 +149,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
agent_finish = ""
tool_info = None
async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.agent_action:
tool_info = {

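The streaming part of chat_iterator above follows a produce/consume pattern: the chain runs in a background task, the callback handler turns LLM and agent events into JSON chunks, and the generator yields them as server-sent events. Below is a schematic, self-contained version of that pattern; fake_chain and the asyncio.Queue are assumptions standing in for the project's wrap_done() and CustomAsyncIteratorCallbackHandler, not the actual implementation.

# Schematic stand-in for the wrap_done / callback.aiter() streaming pattern used above.
import asyncio
import json
from typing import AsyncIterable

async def fake_chain(queue: asyncio.Queue) -> None:
    # Plays the role of full_chain.ainvoke(...) emitting events through the callback handler.
    for token in ["模", "型", "选", "择"]:
        await queue.put(json.dumps({"status": "llm_token", "text": token}))
    await queue.put(None)  # end-of-stream marker, like callback.done being set by wrap_done()

async def chat_iterator() -> AsyncIterable[str]:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(fake_chain(queue))    # like asyncio.create_task(wrap_done(...))
    while (chunk := await queue.get()) is not None:  # like `async for chunk in callback.aiter()`
        yield chunk
    await task

async def main() -> None:
    async for chunk in chat_iterator():
        print(json.loads(chunk))

asyncio.run(main())
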

@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from typing import List
@ -18,7 +18,14 @@ def list_running_models(
r = client.post(controller_address + "/list_models")
models = r.json()["models"]
data = {m: get_model_config(m).data for m in models}
return BaseResponse(data=data)
## Only LLM models are returned
result = {}
for model, config in data.items():
if model in LLM_MODEL_CONFIG['llm_model']:
result[model] = config
return BaseResponse(data=result)
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
@ -36,10 +43,17 @@ def list_config_models(
Get the list of models configured in the local configs
'''
data = {}
result = {}
for type, models in list_config_llm_models().items():
if type in types:
data[type] = {m: get_model_config(m).data for m in models}
return BaseResponse(data=data)
for type, type_models in data.items():
result[type] = {m: c for m, c in type_models.items() if m in LLM_MODEL_CONFIG['llm_model']}
return BaseResponse(data=result)
def get_model_config(

teaser.png: new binary file, 1.0 MiB (content not shown)

test.py: new file, 32 lines

@ -0,0 +1,32 @@
from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image
import requests
import torch
# Load the model and processor
model_id = "/data/models/fuyu-8b"
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)
# Convert the model to bf16
model = model.to(dtype=torch.bfloat16)
# Prepare the model inputs
# text_prompt = "According to this chart, which model performs best?\n"
text_prompt = "Generate a coco-style caption.\n"
image = Image.open("1.png").convert("RGB")
while True:
# Read the text prompt from the user
text_prompt = input("请输入文本提示: ")
if text_prompt.lower() == 'exit':
break
inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
# Generate the output
generation_output = model.generate(**inputs, max_new_tokens=7)
generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
# Print the generated text
print(generation_text)


@ -1,4 +1,6 @@
import streamlit as st
from webui_pages.dialogue.utils import process_files
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
@ -187,7 +189,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
key="llm_model",
)
# Contents passed to the backend
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
@ -201,10 +202,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
print(model_config)
files = st.file_uploader("上传附件",
type=[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True)
files_upload = process_files(files=files) if files else None
print(len(files_upload)) if files_upload else None
# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
@ -218,7 +220,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "
@ -227,6 +228,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
message_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
api.chat_feedback(message_id=message_id,
@ -252,6 +254,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
message_id = ""
element_index = 0
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
model_config=model_config,
conversation_id=conversation_id,


@ -0,0 +1,29 @@
import base64
import os
from io import BytesIO
def encode_file_to_base64(file):
# Convert the file contents to a Base64-encoded string
buffer = BytesIO()
buffer.write(file.read())
return base64.b64encode(buffer.getvalue()).decode()
def process_files(files):
result = {"videos": [], "images": []}
for file in files:
file_extension = os.path.splitext(file.name)[1].lower()
# Detect the file type and handle it accordingly
if file_extension in ['.mp4', '.avi']:
# Video file handling
video_base64 = encode_file_to_base64(file)
result["videos"].append(video_base64)
elif file_extension in ['.jpg', '.png', '.jpeg']:
# Image file handling
image_base64 = encode_file_to_base64(file)
result["images"].append(image_base64)
return result
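
A hypothetical usage of the helpers above. _Upload is an assumed stand-in for Streamlit's UploadedFile (process_files only relies on .name and .read()), and the snippet is meant to run alongside the functions defined in this new module.

# Example usage of process_files() defined above; _Upload is a hypothetical stand-in.
from io import BytesIO

class _Upload:
    def __init__(self, name: str, data: bytes):
        self.name = name
        self._buffer = BytesIO(data)

    def read(self) -> bytes:
        return self._buffer.read()

uploads = [_Upload("photo.png", b"\x89PNG fake image bytes"),
           _Upload("clip.mp4", b"\x00\x00 fake video bytes")]
result = process_files(uploads)
print(len(result["images"]), len(result["videos"]))  # -> 1 1
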


@ -260,6 +260,7 @@ class ApiRequest:
def chat_chat(
self,
query: str,
metadata: dict,
conversation_id: str = None,
history_len: int = -1,
history: List[Dict] = [],
@ -273,6 +274,7 @@ class ApiRequest:
'''
data = {
"query": query,
"metadata": metadata,
"conversation_id": conversation_id,
"history_len": history_len,
"history": history,