Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Fix the model selection bug
This commit is contained in:
parent 253168a187
commit 5714358403
@@ -12,7 +12,10 @@ RERANKER_MAX_LENGTH = 1024
# Configure this when you need to add custom keywords to EMBEDDING_MODEL
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"

SUPPORT_AGENT_MODELS = [
    "chatglm3-6b",
    "openai-api"
]
LLM_MODEL_CONFIG = {
    # Intent recognition needs no user-facing output; the backend model just needs the result
    "preprocess_model": {
@@ -44,7 +47,7 @@ LLM_MODEL_CONFIG = {
             "max_tokens": 2048,
             "history_len": 100,
             "prompt_name": "default",
-            "callbacks": False
+            "callbacks": True
         },
     },
     "action_model": {
@@ -54,11 +57,10 @@ LLM_MODEL_CONFIG = {
             "prompt_name": "ChatGLM3",
             "callbacks": True
         },
-        "zhipu-api": {
+        "openai-api": {
             "temperature": 0.01,
-            "max_tokens": 2096,
-            "history_len": 5,
-            "prompt_name": "ChatGLM3",
+            "max_tokens": 4096,
+            "prompt_name": "GPT-4",
             "callbacks": True
         },
     },

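For orientation, here is a minimal sketch of how a nested role → model → parameters config like the LLM_MODEL_CONFIG above can be consumed. The helper function and the trimmed-down dict are illustrative assumptions, not code from this repo.

```python
# Minimal sketch: pick the first model configured for a given role.
# The dict shape (role -> model name -> params) is taken from this diff;
# first_configured_model is a hypothetical helper for illustration.
LLM_MODEL_CONFIG = {
    "preprocess_model": {
        "chatglm3-6b": {"temperature": 0.05, "max_tokens": 2048,
                        "prompt_name": "default", "callbacks": True},
    },
    "action_model": {
        "openai-api": {"temperature": 0.01, "max_tokens": 4096,
                       "prompt_name": "GPT-4", "callbacks": True},
    },
}

def first_configured_model(role: str):
    """Return (model_name, params) for the first model configured for a role."""
    models = LLM_MODEL_CONFIG.get(role, {})
    return next(iter(models.items()), (None, {}))

print(first_configured_model("action_model"))
# -> ('openai-api', {'temperature': 0.01, 'max_tokens': 4096, ...})
```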
@@ -1,19 +1,19 @@
 PROMPT_TEMPLATES = {
     "preprocess_model": {
         "default":
-            '请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n'
-            '以下几种情况要使用工具,请返回1\n'
-            '1. 实时性的问题,例如天气,日期,地点等信息\n'
+            '根据我们对话的历史,判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。'
+            '以下几种情况要使用工具:\n'
+            '1. 实时性的问题,例如查询天气,日期,地点等信息\n'
             '2. 需要数学计算的问题\n'
             '3. 需要查询数据,地点等精确数据\n'
             '4. 需要行业知识的问题\n'
-            '<question>'
-            '{input}'
-            '</question>'
+            '5. 需要联网的内容\n'
+            '你只要回答一个数字:1代表需要使用工具,你无法为我直接提供服务。0代表不需要使用工具。你应该尽量使用工具\n'
+            '你只能回答0或者1'
     },
     "llm_model": {
         "default":
-            '{{ input }}',
+            '{{input}}',
         "with_history":
             'The following is a friendly conversation between a human and an AI. '
             'The AI is talkative and provides lots of specific details from its context. '

@@ -21,7 +21,9 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler


-def create_models_from_config(configs: dict = {}, callbacks: list = []):
+def create_models_from_config(configs, callbacks):
+    if configs is None:
+        configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
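The signature change above replaces mutable default arguments with an explicit None check. Python evaluates a default like `configs: dict = {}` once, at definition time, so every call shares the same dict. A minimal demo of the pitfall being avoided (toy functions, not from the repo):

```python
# Demo of the mutable-default-argument pitfall fixed in this hunk.
def broken(configs: dict = {}):
    # The same dict object persists across calls.
    configs["n"] = configs.get("n", 0) + 1
    return configs

print(broken())  # {'n': 1}
print(broken())  # {'n': 2}  <- state leaked between calls

def fixed(configs=None):
    if configs is None:
        configs = {}  # a fresh dict on every call
    configs["n"] = configs.get("n", 0) + 1
    return configs

print(fixed())  # {'n': 1}
print(fixed())  # {'n': 1}
```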
@@ -40,7 +42,57 @@ def create_models_from_config(configs: dict = {}, callbacks: list = []):
     return models, prompts


+# Build the chain construction logic here
+def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
+    memory = None
+    chat_prompt = None
+    if history:
+        history = [History.from_data(h) for h in history]
+        input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_template() for i in history] + [input_msg])
+    elif conversation_id and history_len > 0:
+        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
+                                            message_limit=history_len)
+    else:
+        input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
+
+    chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
+    classifier_chain = (
+        PromptTemplate.from_template(prompts["preprocess_model"])
+        | models["preprocess_model"]
+        | StrOutputParser()
+    )
+    if "chatglm3" in models["action_model"].model_name.lower():
+        agent_executor = initialize_glm3_agent(
+            llm=models["action_model"],
+            tools=tools,
+            prompt=prompts["action_model"],
+            input_variables=["input", "intermediate_steps", "history"],
+            memory=memory,
+            callback_manager=BaseCallbackManager(handlers=callbacks),
+            verbose=True,
+        )
+    else:
+        agent_executor = initialize_agent(
+            llm=models["action_model"],
+            tools=tools,
+            callbacks=callbacks,
+            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+            memory=memory,
+            verbose=True,
+        )
+    branch = RunnableBranch(
+        (lambda x: "1" in x["topic"].lower(), agent_executor),
+        chain
+    )
+    full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    return full_chain
+
+
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: Union[int, List[History]] = Body([],
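The new create_models_chains wires a classifier in front of a RunnableBranch: the preprocess model emits "1" or "0", and the branch routes to the agent executor or the plain LLMChain. Below is a self-contained toy of that wiring, with RunnableLambda stand-ins for the real models; only the routing structure mirrors the committed code, and it assumes langchain_core is installed.

```python
# Toy version of the classifier + RunnableBranch routing in this commit.
from langchain_core.runnables import RunnableBranch, RunnableLambda

# Stand-ins for preprocess_model, the agent executor, and the chat chain.
classifier = RunnableLambda(lambda x: "1" if "weather" in x["input"] else "0")
agent_path = RunnableLambda(lambda x: f"[agent] looked up: {x['input']}")
chat_path = RunnableLambda(lambda x: f"[chat] answered: {x['input']}")

branch = RunnableBranch(
    (lambda x: "1" in x["topic"].lower(), agent_path),  # tool needed
    chat_path,                                          # default: plain chat
)
# The dict is coerced to a RunnableParallel, exactly as in the diff.
full_chain = {"topic": classifier, "input": lambda x: x["input"]} | branch

print(full_chain.invoke({"input": "what's the weather in Beijing"}))  # agent path
print(full_chain.invoke({"input": "tell me a joke"}))                 # chat path
```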
@@ -55,61 +107,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
         nonlocal history
-        memory = None
-        message_id = None
-        chat_prompt = None
+        message_id = add_message_to_db(chat_type="llm_chat", query=query,
+                                       conversation_id=conversation_id) if conversation_id else None

         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)

-        if conversation_id:
-            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
         tools = [tool for tool in all_tools if tool.name in tool_config]

-        if history:
-            history = [History.from_data(h) for h in history]
-            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
-            chat_prompt = ChatPromptTemplate.from_messages(
-                [i.to_msg_template() for i in history] + [input_msg])
-        elif conversation_id and history_len > 0:
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
-                                                message_limit=history_len)
-        else:
-            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
-            chat_prompt = ChatPromptTemplate.from_messages([input_msg])
+        # Build the full Chain
+        full_chain = create_models_chains(prompts=prompts,
+                                          models=models,
+                                          conversation_id=conversation_id,
+                                          tools=tools,
+                                          callbacks=callbacks,
+                                          history=history,
+                                          history_len=history_len,
+                                          metadata=metadata)

-        chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
-        classifier_chain = (
-            PromptTemplate.from_template(prompts["preprocess_model"])
-            | models["preprocess_model"]
-            | StrOutputParser()
-        )
-        if "chatglm3" in models["action_model"].model_name.lower():
-            agent_executor = initialize_glm3_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                input_variables=["input", "intermediate_steps", "history"],
-                memory=memory,
-                callback_manager=BaseCallbackManager(handlers=callbacks),
-                verbose=True,
-            )
-        else:
-            agent_executor = initialize_agent(
-                llm=models["action_model"],
-                tools=tools,
-                callbacks=callbacks,
-                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                memory=memory,
-                verbose=True,
-            )
-        branch = RunnableBranch(
-            (lambda x: "1" in x["topic"].lower(), agent_executor),
-            chain
-        )
-        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+        # Run the full Chain
         task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
         if stream:
             async for chunk in callback.aiter():
@@ -132,7 +149,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
        agent_finish = ""
        tool_info = None
        async for chunk in callback.aiter():
            # Use server-sent-events to stream the response
            data = json.loads(chunk)
            if data["status"] == Status.agent_action:
                tool_info = {

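chat_iterator drives the chain in a background task while the request handler drains chunks from the callback's async iterator. Here is a generic, dependency-free sketch of that produce-while-consuming pattern; TokenStream and fake_chain are illustrative stand-ins for CustomAsyncIteratorCallbackHandler and the real chain, not repo code.

```python
# Generic sketch of the stream-while-running pattern behind chat_iterator.
import asyncio

class TokenStream:
    """Minimal stand-in for an async-iterator callback handler."""
    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.done = asyncio.Event()

    async def aiter(self):
        # Yield chunks until the producer is done and the queue is drained.
        while not (self.done.is_set() and self.queue.empty()):
            try:
                yield await asyncio.wait_for(self.queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue

async def fake_chain(stream: TokenStream):
    # Pretend to be the LLM chain emitting tokens via callbacks.
    for tok in ["routing...", "answer", "[DONE]"]:
        await stream.queue.put(tok)
        await asyncio.sleep(0.05)
    stream.done.set()

async def main():
    stream = TokenStream()
    task = asyncio.create_task(fake_chain(stream))  # run the chain in the background
    async for chunk in stream.aiter():              # stream chunks as SSE would
        print(chunk)
    await task

asyncio.run(main())
```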
@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
 from typing import List
@@ -18,7 +18,14 @@ def list_running_models(
         r = client.post(controller_address + "/list_models")
         models = r.json()["models"]
         data = {m: get_model_config(m).data for m in models}
-        return BaseResponse(data=data)
+
+        ## Only return LLM models
+        result = {}
+        for model, config in data.items():
+            if model in LLM_MODEL_CONFIG['llm_model']:
+                result[model] = config
+
+        return BaseResponse(data=result)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -36,10 +43,17 @@ def list_config_models(
     从本地获取configs中配置的模型列表
     '''
     data = {}
+    result = {}
+
     for type, models in list_config_llm_models().items():
         if type in types:
             data[type] = {m: get_model_config(m).data for m in models}
-    return BaseResponse(data=data)
+
+    for model, config in data.items():
+        if model in LLM_MODEL_CONFIG['llm_model']:
+            result[type][model] = config
+
+    return BaseResponse(data=result)


 def get_model_config(
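Note that as committed, the filtering loop in list_config_models iterates data.items(), whose keys are model types rather than model names, and assigns into result[type] before that key exists. A hypothetical corrected filter, not part of this commit, might look like:

```python
# Hypothetical corrected filter for list_config_models: walk each type's
# model dict and keep only models whitelisted for the 'llm_model' role.
def filter_llm_models(data: dict, llm_model_names: set) -> dict:
    result = {}
    for model_type, models in data.items():
        kept = {m: cfg for m, cfg in models.items() if m in llm_model_names}
        if kept:
            result[model_type] = kept
    return result

# Example:
data = {"online": {"openai-api": {"max_tokens": 4096}, "zhipu-api": {}}}
print(filter_llm_models(data, {"openai-api", "chatglm3-6b"}))
# -> {'online': {'openai-api': {'max_tokens': 4096}}}
```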
BIN teaser.png (new file, 1.0 MiB). Binary file not shown.
test.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from transformers import FuyuProcessor, FuyuForCausalLM
+from PIL import Image
+import requests
+import torch
+
+# Load the model and processor
+model_id = "/data/models/fuyu-8b"
+processor = FuyuProcessor.from_pretrained(model_id)
+model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)
+
+# Convert the model to bf16
+model = model.to(dtype=torch.bfloat16)
+
+# Prepare the model inputs
+# text_prompt = "According to this chart, which model performs best?\n"
+
+text_prompt = "Generate a coco-style caption.\n"
+image = Image.open("1.png").convert("RGB")
+
+while True:
+    # Read the user's text prompt
+    text_prompt = input("请输入文本提示: ")
+    if text_prompt.lower() == 'exit':
+        break
+    inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
+
+    # Generate the output
+    generation_output = model.generate(**inputs, max_new_tokens=7)
+    generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
+
+    # Print the generated text
+    print(generation_text)
@@ -1,4 +1,6 @@
 import streamlit as st
+
+from webui_pages.dialogue.utils import process_files
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from streamlit_modal import Modal
@@ -187,7 +189,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            key="llm_model",
        )

        # Content passed to the backend
        model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
@@ -201,10 +202,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        if llm_model is not None:
            model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

        print(model_config)
        files = st.file_uploader("上传附件",
                                 type=[i for ls in LOADER_DICT.values() for i in ls],
                                 accept_multiple_files=True)
        files_upload = process_files(files=files) if files else None
        print(len(files_upload)) if files_upload else None

        # if dialogue_mode == "文件对话":
        #     with st.expander("文件对话配置", True):
@@ -218,7 +220,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        #     st.session_state["file_chat_id"] = upload_temp_docs(files, api)
    # Display chat messages from history on app rerun

    chat_box.output_messages()
    chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
@@ -227,6 +228,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            message_id: str = "",
            history_index: int = -1,
    ):

        reason = feedback["text"]
        score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
        api.chat_feedback(message_id=message_id,
@@ -252,6 +254,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             message_id = ""
             element_index = 0
             for d in api.chat_chat(query=prompt,
+                                    metadata=files_upload,
                                    history=history,
                                    model_config=model_config,
                                    conversation_id=conversation_id,

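For reference, here is the shape of the model_config dict this page sends to the backend, reconstructed from the assignments above; the trimmed LLM_MODEL_CONFIG and the chosen model name are illustrative.

```python
# Reconstruction of the request-side model_config built in dialogue_page.
LLM_MODEL_CONFIG = {
    "preprocess_model": {"chatglm3-6b": {"temperature": 0.05}},
    "llm_model": {"openai-api": {"temperature": 0.7, "max_tokens": 4096}},
}

model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
llm_model = "openai-api"  # the model picked in the selectbox
if llm_model is not None:
    model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

print(model_config)
# {'preprocess_model': {}, 'llm_model': {'openai-api': {'temperature': 0.7, 'max_tokens': 4096}}}
```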
webui_pages/dialogue/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import base64
+import os
+from io import BytesIO
+
+
+def encode_file_to_base64(file):
+    # Encode the file contents as Base64
+    buffer = BytesIO()
+    buffer.write(file.read())
+    return base64.b64encode(buffer.getvalue()).decode()
+
+
+def process_files(files):
+    result = {"videos": [], "images": []}
+
+    for file in files:
+        file_extension = os.path.splitext(file.name)[1].lower()
+
+        # Detect the file type and handle it accordingly
+        if file_extension in ['.mp4', '.avi']:
+            # Video file handling
+            video_base64 = encode_file_to_base64(file)
+            result["videos"].append(video_base64)
+        elif file_extension in ['.jpg', '.png', '.jpeg']:
+            # Image file handling
+            image_base64 = encode_file_to_base64(file)
+            result["images"].append(image_base64)
+
+    return result
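A quick usage sketch for the new process_files helper. Streamlit's UploadedFile exposes .name and .read(), so a BytesIO subclass with a name attribute is enough to exercise it outside the app; FakeUpload and the byte payloads are illustrative, and the import assumes the repo is on the Python path.

```python
# Exercise process_files without running the Streamlit app.
import base64
from io import BytesIO

from webui_pages.dialogue.utils import process_files  # path from this diff

class FakeUpload(BytesIO):
    """Stand-in for Streamlit's UploadedFile: bytes plus a .name attribute."""
    def __init__(self, data: bytes, name: str):
        super().__init__(data)
        self.name = name

files = [FakeUpload(b"\x89PNG fake bytes", "chart.png"),
         FakeUpload(b"\x00\x00 fake mp4", "clip.mp4")]
result = process_files(files)
print(len(result["images"]), len(result["videos"]))  # 1 1
print(base64.b64decode(result["images"][0])[:4])     # b'\x89PNG'
```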
@@ -260,6 +260,7 @@ class ApiRequest:
     def chat_chat(
             self,
             query: str,
+            metadata: dict,
             conversation_id: str = None,
             history_len: int = -1,
             history: List[Dict] = [],
@@ -273,6 +274,7 @@ class ApiRequest:
         '''
         data = {
             "query": query,
+            "metadata": metadata,
             "conversation_id": conversation_id,
             "history_len": history_len,
             "history": history,
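Finally, here is how a caller would pass the new metadata field through ApiRequest.chat_chat, based only on the parameters visible in this diff; the constructor arguments and the remaining defaults are assumptions.

```python
# Hypothetical call site; ApiRequest lives in webui_pages.utils in this repo.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed constructor signature
for d in api.chat_chat(
        query="describe this image",
        metadata={"images": ["<base64 payload>"], "videos": []},
        conversation_id="",
        history=[],
):
    print(d)  # streamed response chunks
```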