diff --git a/1.png b/1.png
new file mode 100644
index 00000000..3979555c
Binary files /dev/null and b/1.png differ
diff --git a/bus.png b/bus.png
new file mode 100644
index 00000000..8994abc4
Binary files /dev/null and b/bus.png differ
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 50314e56..57d4fa09 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -12,7 +12,10 @@ RERANKER_MAX_LENGTH = 1024
 # Configure this if custom keywords need to be added to EMBEDDING_MODEL
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
 EMBEDDING_MODEL_OUTPUT_PATH = "output"
-
+SUPPORT_AGENT_MODELS = [
+    "chatglm3-6b",
+    "openai-api"
+]
 LLM_MODEL_CONFIG = {
     # Intent recognition needs no streamed output; the model only uses it internally
     "preprocess_model": {
@@ -44,7 +47,7 @@ LLM_MODEL_CONFIG = {
             "max_tokens": 2048,
             "history_len": 100,
             "prompt_name": "default",
-            "callbacks": False
+            "callbacks": True
         },
     },
     "action_model": {
@@ -54,11 +57,10 @@ LLM_MODEL_CONFIG = {
             "prompt_name": "ChatGLM3",
             "callbacks": True
         },
-        "zhipu-api": {
+        "openai-api": {
             "temperature": 0.01,
-            "max_tokens": 2096,
-            "history_len": 5,
-            "prompt_name": "ChatGLM3",
+            "max_tokens": 4096,
+            "prompt_name": "GPT-4",
             "callbacks": True
         },
     },
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index e7def5a6..de83d06e 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -1,19 +1,19 @@
 PROMPT_TEMPLATES = {
     "preprocess_model": {
         "default":
-            '请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n'
-            '以下几种情况要使用工具,请返回1\n'
-            '1. 实时性的问题,例如天气,日期,地点等信息\n'
+            '根据我们对话的历史,判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。'
+            '以下几种情况要使用工具:\n'
+            '1. 实时性的问题,例如查询天气,日期,地点等信息\n'
             '2. 需要数学计算的问题\n'
             '3. 需要查询数据,地点等精确数据\n'
             '4. 需要行业知识的问题\n'
-            ''
-            '{input}'
-            ''
+            '5. 需要联网的内容\n'
+            '你只要回答一个数字:1代表需要使用工具,你无法为我直接提供服务。0代表不需要使用工具。你应该尽量使用工具\n'
+            '你只能回答0或者1'
     },
     "llm_model": {
         "default":
-            '{{ input }}',
+            '{{input}}',
         "with_history":
             'The following is a friendly conversation between a human and an AI. '
             'The AI is talkative and provides lots of specific details from its context. '
diff --git a/server/chat/chat.py b/server/chat/chat.py
index ccf0982b..1dd92fab 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -21,7 +21,9 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
 
 
-def create_models_from_config(configs: dict = {}, callbacks: list = []):
+def create_models_from_config(configs, callbacks):
+    if configs is None:
+        configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
@@ -40,7 +42,57 @@ create_models_from_config(configs: dict = {}, callbacks: list = []):
     return models, prompts
 
 
+# Chain-building logic lives here
+def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
+    memory = None
+    chat_prompt = None
+    if history:
+        history = [History.from_data(h) for h in history]
+        input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_template() for i in history] + [input_msg])
+    elif conversation_id and history_len > 0:
+        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
+                                            message_limit=history_len)
+    else:
+        input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
+
+    chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
+    classifier_chain = (
+        PromptTemplate.from_template(prompts["preprocess_model"])
+        | models["preprocess_model"]
+        | StrOutputParser()
+    )
+    if "chatglm3" in models["action_model"].model_name.lower():
+        agent_executor = initialize_glm3_agent(
+            llm=models["action_model"],
+            tools=tools,
+            prompt=prompts["action_model"],
+            input_variables=["input", "intermediate_steps", "history"],
+            memory=memory,
+            callback_manager=BaseCallbackManager(handlers=callbacks),
+            verbose=True,
+        )
+    else:
+        agent_executor = initialize_agent(
+            llm=models["action_model"],
+            tools=tools,
+            callbacks=callbacks,
+            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+            memory=memory,
+            verbose=True,
+        )
+    branch = RunnableBranch(
+        (lambda x: "1" in x["topic"].lower(), agent_executor),
+        chain
+    )
+    full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    return full_chain
+
+
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
+               metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
                conversation_id: str = Body("", description="对话框ID"),
                history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: Union[int, List[History]] = Body([],
@@ -55,61 +107,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
-        nonlocal history
-        memory = None
-        message_id = None
-        chat_prompt = None
+        message_id = add_message_to_db(chat_type="llm_chat", query=query,
+                                       conversation_id=conversation_id) if conversation_id else None
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
-        if conversation_id:
-            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
 
         tools = [tool for tool in all_tools if tool.name in tool_config]
-        if history:
-            history = [History.from_data(h) for h in history]
-            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
-            chat_prompt = ChatPromptTemplate.from_messages(
-                [i.to_msg_template() for i in history] + [input_msg])
-        elif conversation_id and history_len > 0:
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
-                                                message_limit=history_len)
-        else:
-            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
-            chat_prompt = ChatPromptTemplate.from_messages([input_msg])
+        # Build the full chain
+        full_chain = create_models_chains(prompts=prompts,
+                                          models=models,
+                                          conversation_id=conversation_id,
+                                          tools=tools,
+                                          callbacks=callbacks,
+                                          history=history,
+                                          history_len=history_len,
+                                          metadata=metadata)
 
-        chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
-        classifier_chain = (
-            PromptTemplate.from_template(prompts["preprocess_model"])
-            | models["preprocess_model"]
-            | StrOutputParser()
-        )
-        if "chatglm3" in models["action_model"].model_name.lower():
-            agent_executor = initialize_glm3_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                input_variables=["input", "intermediate_steps", "history"],
-                memory=memory,
-                callback_manager=BaseCallbackManager(handlers=callbacks),
-                verbose=True,
-            )
-        else:
-            agent_executor = initialize_agent(
-                llm=models["action_model"],
-                tools=tools,
-                callbacks=callbacks,
-                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                memory=memory,
-                verbose=True,
-            )
-        branch = RunnableBranch(
-            (lambda x: "1" in x["topic"].lower(), agent_executor),
-            chain
-        )
-        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+        # Execute the full chain
         task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks),
                                              callback.done))
         if stream:
             async for chunk in callback.aiter():
@@ -132,7 +149,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             agent_finish = ""
             tool_info = None
             async for chunk in callback.aiter():
-                # Use server-sent-events to stream the response
                 data = json.loads(chunk)
                 if data["status"] == Status.agent_action:
                     tool_info = {
diff --git a/server/llm_api.py b/server/llm_api.py
index d642eeac..21410fc7 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
 from typing import List
@@ -18,7 +18,14 @@ def list_running_models(
             r = client.post(controller_address + "/list_models")
             models = r.json()["models"]
             data = {m: get_model_config(m).data for m in models}
-            return BaseResponse(data=data)
+
+            ## Only return LLM models
+            result = {}
+            for model, config in data.items():
+                if model in LLM_MODEL_CONFIG['llm_model']:
+                    result[model] = config
+
+            return BaseResponse(data=result)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -36,10 +43,17 @@ def list_config_models(
     Get the list of models configured in the local configs
     '''
     data = {}
+    result = {}
+
     for type, models in list_config_llm_models().items():
         if type in types:
             data[type] = {m: get_model_config(m).data for m in models}
-    return BaseResponse(data=data)
+
+    for type, model_configs in data.items():
+        result[type] = {m: c for m, c in model_configs.items()
+                        if m in LLM_MODEL_CONFIG['llm_model']}
+
+    return BaseResponse(data=result)
 
 
 def get_model_config(
diff --git a/teaser.png b/teaser.png
new file mode 100644
index 00000000..c01532b6
Binary files /dev/null and b/teaser.png differ
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..b4b73e23
--- /dev/null
+++ b/test.py
@@ -0,0 +1,32 @@
+from transformers import FuyuProcessor, FuyuForCausalLM
+from PIL import Image
+import requests
+import torch
+
+# Load the model and processor
+model_id = "/data/models/fuyu-8b"
+processor = FuyuProcessor.from_pretrained(model_id)
+model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)
+
+# Convert the model to bf16
+model = model.to(dtype=torch.bfloat16)
+
+# Prepare the model inputs
+# text_prompt = "According to this chart, which model performs best?\n"
+
+text_prompt = "Generate a coco-style caption.\n"
+image = Image.open("1.png").convert("RGB")
+
+while True:
+    # Get the text prompt entered by the user
+    text_prompt = input("请输入文本提示: ")
+    if text_prompt.lower() == 'exit':
+        break
+    inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
+
+    # Generate the output
+    generation_output = model.generate(**inputs, max_new_tokens=7)
+    generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
+
+    # Print the generated text
+    print(generation_text)
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index be4934c7..2e82ec23 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,4 +1,6 @@
 import streamlit as st
+
+from webui_pages.dialogue.utils import process_files
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from streamlit_modal import Modal
@@ -187,7 +189,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             key="llm_model",
         )
 
-        # Contents passed to the backend
         model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
 
@@ -201,10 +202,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         if llm_model is not None:
             model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
 
-        print(model_config)
         files = st.file_uploader("上传附件",
                                  type=[i for ls in LOADER_DICT.values() for i in ls],
                                  accept_multiple_files=True)
+        files_upload = process_files(files=files) if files else None
+        print(len(files_upload)) if files_upload else None
 
         # if dialogue_mode == "文件对话":
         #     with st.expander("文件对话配置", True):
@@ -218,7 +220,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         #         st.session_state["file_chat_id"] = upload_temp_docs(files, api)
 
     # Display chat messages from history on app rerun
-
     chat_box.output_messages()
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
 
@@ -227,6 +228,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             message_id: str = "",
             history_index: int = -1,
     ):
+        reason = feedback["text"]
         score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
         api.chat_feedback(message_id=message_id,
@@ -252,6 +254,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             message_id = ""
             element_index = 0
             for d in api.chat_chat(query=prompt,
+                                   metadata=files_upload,
                                    history=history,
                                    model_config=model_config,
                                    conversation_id=conversation_id,
diff --git a/webui_pages/dialogue/utils.py b/webui_pages/dialogue/utils.py
new file mode 100644
index 00000000..07384325
--- /dev/null
+++ b/webui_pages/dialogue/utils.py
@@ -0,0 +1,29 @@
+import base64
+import os
+from io import BytesIO
+
+
+def encode_file_to_base64(file):
+    # Encode the file contents as Base64
+    buffer = BytesIO()
+    buffer.write(file.read())
+    return base64.b64encode(buffer.getvalue()).decode()
+
+
+def process_files(files):
+    result = {"videos": [], "images": []}
+
+    for file in files:
+        file_extension = os.path.splitext(file.name)[1].lower()
+
+        # Detect the file type and handle it accordingly
+        if file_extension in ['.mp4', '.avi']:
+            # Video file
+            video_base64 = encode_file_to_base64(file)
+            result["videos"].append(video_base64)
+        elif file_extension in ['.jpg', '.png', '.jpeg']:
+            # Image file
+            image_base64 = encode_file_to_base64(file)
+            result["images"].append(image_base64)
+
+    return result
\ No newline at end of file
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 834873d4..0c0e4dd8 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -260,6 +260,7 @@ class ApiRequest:
     def chat_chat(
             self,
             query: str,
+            metadata: dict,
             conversation_id: str = None,
             history_len: int = -1,
             history: List[Dict] = [],
@@ -273,6 +274,7 @@ class ApiRequest:
         '''
         data = {
             "query": query,
+            "metadata": metadata,
             "conversation_id": conversation_id,
             "history_len": history_len,
             "history": history,