Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-06 23:15:53 +08:00
Support multimodal grounding
1. Cleaned up the chat code. 2. Added support for vision tools producing grounding output. 3. Improved the tool-calling flow.
This commit is contained in:
parent 03891cc27a
commit d44ce6ce21
@@ -5,4 +5,4 @@ from .server_config import *
 from .prompt_config import *
 
 
-VERSION = "v0.2.10"
+VERSION = "v0.3.0-preview"
@@ -1,4 +1,6 @@
-# 用于批量将configs下的.example文件复制并命名为.py文件
+"""
+用于批量将configs下的.example文件复制并命名为.py文件
+"""
 import os
 import shutil
 
@@ -1,6 +1,6 @@
 # API requirements
 
-langchain>=0.0.348
+langchain>=0.0.350
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
@@ -1,6 +1,6 @@
 # API requirements
 
-langchain>=0.0.346
+langchain>=0.0.350
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
@@ -8,8 +8,6 @@ class ModelContainer:
         self.model = None
         self.metadata = None
 
-        self.metadata_response = None
-
         self.vision_model = None
         self.vision_tokenizer = None
         self.audio_tokenizer = None
@@ -29,12 +27,15 @@ class ModelContainer:
         if TOOL_CONFIG["aqa_processor"]["use"]:
             self.audio_tokenizer = AutoTokenizer.from_pretrained(
                 TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
-                trust_remote_code=True)
+                trust_remote_code=True
+            )
             self.audio_model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                 torch_dtype=torch.bfloat16,
                 low_cpu_mem_usage=True,
-                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()
+                trust_remote_code=True).to(
+                TOOL_CONFIG["aqa_processor"]["device"]
+            ).eval()
 
 
 container = ModelContainer()
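The loader above only relies on the keys it reads from TOOL_CONFIG["aqa_processor"]. A minimal sketch of a matching entry is shown below; the key names come from the code, while the paths and device value are placeholder assumptions, not values from this commit:

# Hypothetical TOOL_CONFIG entry; only the key names match what the loader reads above.
TOOL_CONFIG = {
    "aqa_processor": {
        "use": True,                                  # gates the audio model load
        "model_path": "/path/to/audio-qa-model",      # placeholder path
        "tokenizer_path": "/path/to/audio-qa-model",  # placeholder path
        "device": "cuda:0",                           # target of .to(...) above
    },
}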
@@ -4,9 +4,59 @@ Method Use cogagent to generate response for a given image and query.
 import base64
 from io import BytesIO
 import torch
-from PIL import Image
+from PIL import Image, ImageDraw
 from pydantic import BaseModel, Field
 from configs import TOOL_CONFIG
+import re
+from server.agent.container import container
+
+
+def extract_between_markers(text, start_marker, end_marker):
+    """
+    Extracts and returns the portion of the text that is between 'start_marker' and 'end_marker'.
+    """
+    start = text.find(start_marker)
+    end = text.find(end_marker, start)
+
+    if start != -1 and end != -1:
+        # Extract and return the text between the markers, without including the markers themselves
+        return text[start + len(start_marker):end].strip()
+    else:
+        return "Text not found between the specified markers"
+
+
+def draw_box_on_existing_image(base64_image, text):
+    """
+    在已有的Base64编码的图片上根据“Grounded Operation”中的坐标信息绘制矩形框。
+    假设坐标是经过缩放的比例坐标。
+    """
+    # 解码并打开Base64编码的图片
+    img = Image.open(BytesIO(base64.b64decode(base64_image)))
+    draw = ImageDraw.Draw(img)
+
+    # 提取“Grounded Operation”后的坐标
+    pattern = r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]"
+    match = re.search(pattern, text)
+    if not match:
+        return None
+
+    coords = tuple(map(int, match.groups()))
+    scaled_coords = (
+        int(coords[0] * 0.001 * img.width),
+        int(coords[1] * 0.001 * img.height),
+        int(coords[2] * 0.001 * img.width),
+        int(coords[3] * 0.001 * img.height)
+    )
+    draw.rectangle(scaled_coords, outline="red", width=3)
+
+    buffered = BytesIO()
+    img.save(buffered, format="JPEG")
+    img.save("tmp/image.jpg")
+    img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+    return img_base64
+
+
 def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
             temperature=1.0):
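For orientation, a standalone sketch of exercising the new grounding helper, assuming draw_box_on_existing_image is importable from the module patched above. The [[x1,y1,x2,y2]] pattern and the 0.001 scaling come from the code; the test image and the sample answer string are purely illustrative, not actual model output:

import base64
import os
from io import BytesIO
from PIL import Image

# Build a throwaway image and encode it as base64, the format the helper expects.
buf = BytesIO()
Image.new("RGB", (640, 480), "white").save(buf, format="JPEG")
image_b64 = base64.b64encode(buf.getvalue()).decode()

# Illustrative grounding text; coordinates are 0-1000 ratios matched by the regex above.
sample_answer = "Grounded Operation: tap the search box at [[120,40,560,120]]"

# The helper also writes tmp/image.jpg, so the directory must exist beforehand.
os.makedirs("tmp", exist_ok=True)
boxed_b64 = draw_box_on_existing_image(image_b64, sample_answer)
# boxed_b64 is a base64 JPEG with a red rectangle, or None if no coordinates were found.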
@@ -33,8 +83,9 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
         'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
             'cross_images'] else None,
     }
+
     gen_kwargs = {"max_length": max_length,
-                  "temperature": temperature,
+                  # "temperature": temperature,
                   "top_p": top_p,
                   "do_sample": False}
     with torch.no_grad():
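Since "do_sample": False selects greedy decoding, sampling parameters such as temperature (and the top_p value that is still passed) have no effect on generation; commenting the temperature entry out rather than deleting it presumably just keeps the knob visible for when sampling is re-enabled.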
@@ -47,15 +98,23 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
 
 
 def vqa_processor(query: str):
-    from server.agent.container import container
     tool_config = TOOL_CONFIG["vqa_processor"]
-    # model, tokenizer = load_model(model_path=tool_config["model_path"],
-    #                               tokenizer_path=tool_config["tokenizer_path"],
-    #                               device=tool_config["device"])
     if container.metadata["images"]:
         image_base64 = container.metadata["images"][0]
-        return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
-                       device=tool_config["device"])
+        ans = vqa_run(model=container.vision_model,
+                      tokenizer=container.vision_tokenizer,
+                      query=query + "(with grounding)",
+                      image_base_64=image_base64,
+                      device=tool_config["device"])
+        print(ans)
+        image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans)
+
+        # Markers
+        # start_marker = "Next Action:draw_box_on_existing_image
+        # end_marker = "Grounded Operation:"
+        # ans = extract_between_markers(ans, start_marker, end_marker)
+
+        return ans
     else:
         return "No Image, Please Try Again"
 
@@ -23,7 +23,7 @@ class ConversationCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        # 如果想存更多信息,则prompts 也需要持久化
+        # TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持
         pass
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
@@ -1,25 +1,28 @@
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from langchain.agents import initialize_agent, AgentType
-from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableBranch
-from server.agent.agent_factory import initialize_glm3_agent
-from server.agent.tools_factory.tools_registry import all_tools
-from server.utils import wrap_done, get_ChatOpenAI
-from langchain.chains import LLMChain
-from typing import AsyncIterable, Dict
 import asyncio
 import json
+from typing import List, Union, AsyncIterable, Dict
+
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+
+from langchain.agents import initialize_agent, AgentType
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableBranch
+from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Union
-from server.chat.utils import History
 from langchain.prompts import PromptTemplate
-from server.utils import get_prompt_template
+
+from server.agent.agent_factory import initialize_glm3_agent
+from server.agent.tools_factory.tools_registry import all_tools
+from server.agent.container import container
+
+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
+from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
-from server.agent.container import container
 
+
 
 def create_models_from_config(configs, callbacks):
     if configs is None:
@@ -47,6 +50,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     memory = None
     chat_prompt = None
     container.metadata = metadata
+
     if history:
         history = [History.from_data(h) for h in history]
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -102,23 +106,35 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
               conversation_id: str = Body("", description="对话框ID"),
               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
-              history: Union[int, List[History]] = Body([],
-                                                        description="历史对话,设为一个整数可以从数据库中读取历史消息",
-                                                        examples=[[
-                                                            {"role": "user",
-                                                             "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                                            {"role": "assistant", "content": "虎头虎脑"}]]
-                                                        ),
+              history: Union[int, List[History]] = Body(
+                  [],
+                  description="历史对话,设为一个整数可以从数据库中读取历史消息",
+                  examples=[
+                      [
+                          {"role": "user",
+                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                          {"role": "assistant", "content": "虎头虎脑"}
+                      ]
+                  ]
+              ),
               stream: bool = Body(False, description="流式输出"),
-              model_config: Dict = Body({}, description="LLM 模型配置。"),
+              model_config: Dict = Body({}, description="LLM 模型配置"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
     async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query,
-                                       conversation_id=conversation_id) if conversation_id else None
+        message_id = add_message_to_db(
+            chat_type="llm_chat",
+            query=query,
+            conversation_id=conversation_id
+        ) if conversation_id else None
+
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
+
+        # 从配置中选择模型
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+
+        # 从配置中选择工具
         tools = [tool for tool in all_tools if tool.name in tool_config]
 
         # 构建完整的Chain
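To make the reorganized parameter layout concrete, a sketch of a request this endpoint would accept; the field names are taken from the signature above, while the URL, port, and the placeholder base64 string are assumptions:

import requests

payload = {
    "query": "Where is the search button in this screenshot?",
    "metadata": {"images": ["<base64-encoded JPEG>"]},  # placeholder attachment
    "conversation_id": "",
    "history_len": -1,
    "history": [],
    "stream": True,
    "model_config": {},  # assumed to mirror LLM_MODEL_CONFIG entries
    "tool_config": {},   # assumed to map tool names to their settings
}
# Endpoint URL is a placeholder assumption about the local API server.
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True)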
@@ -131,7 +147,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                                            history_len=history_len,
                                            metadata=metadata)
 
-        # 执行完整的Chain
+        # Execute Chain
+
         task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
         if stream:
             async for chunk in callback.aiter():
@@ -166,12 +183,22 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                     text += data["llm_token"]
             if tool_info:
                 yield json.dumps(
-                    {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
-                    ensure_ascii=False)
+                    {
+                        "text": text,
+                        "agent_action": tool_info,
+                        "agent_finish": agent_finish,
+                        "message_id": message_id
+                    },
+                    ensure_ascii=False
+                )
             else:
                 yield json.dumps(
-                    {"text": text, "message_id": message_id},
-                    ensure_ascii=False)
+                    {
+                        "text": text,
+                        "message_id": message_id
+                    },
+                    ensure_ascii=False
+                )
         await task
 
     return EventSourceResponse(chat_iterator())
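A minimal sketch of consuming the stream on the client side, continuing the request example above. It assumes the "data: <json>" framing that an SSE response produces and uses only the chunk keys emitted by the code above:

import json

# 'resp' is the streaming response from the request sketched earlier.
for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data:"):
        continue  # skip keep-alives and non-data SSE lines
    chunk = json.loads(raw[len("data:"):].strip())
    print(chunk.get("text", ""), end="", flush=True)
    if chunk.get("agent_action"):
        # Tool-call information emitted while the agent is running.
        print(f"\n[tool] {chunk['agent_action']}")
    if chunk.get("agent_finish"):
        print(f"\n[final] {chunk['agent_finish']}")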
@@ -22,7 +22,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
                      description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      ):
 
-    #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
+    #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
                                   model_name: str = None,
                                   prompt_name: str = prompt_name,
@@ -49,10 +49,6 @@ for model_category in LLM_MODEL_CONFIG.values():
 
 all_model_names_list = list(all_model_names)
 
-@deprecated(
-    since="0.3.0",
-    message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
-    removal="0.3.0")
 def create_controller_app(
         dispatch_method: str,
         log_level: str = "INFO",
@@ -111,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             worker_addr=args.worker_address)
         # sys.modules["fastchat.serve.base_model_worker"].worker = worker
         sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
+
     # 本地模型
     else:
         from configs.model_config import VLLM_MODEL_DICT
@@ -876,7 +873,9 @@ async def start_main_server():
 
 
 if __name__ == "__main__":
+
     create_tables()
+
     if sys.version_info < (3, 10):
         loop = asyncio.get_event_loop()
     else:
@@ -887,4 +886,3 @@ if __name__ == "__main__":
 
     asyncio.set_event_loop(loop)
     loop.run_until_complete(start_main_server())
-
@@ -14,7 +14,6 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
 from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
-from PIL import Image
 
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
@@ -206,7 +205,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
 
         # files = st.file_uploader("上传附件",accept_multiple_files=False)
         #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
         uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
         files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
 
@@ -252,20 +251,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             history = get_messages_history(
                 model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
         chat_box.user_say(prompt)
-        if files_upload["images"]:
-            # 显示第一个上传的图像
-            st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
-                        unsafe_allow_html=True)
-        elif files_upload["videos"]:
-            # 显示第一个上传的视频
-            st.markdown(
-                f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
-                unsafe_allow_html=True)
-        elif files_upload["audios"]:
-            # 播放第一个上传的音频
-            st.markdown(
-                f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
-                unsafe_allow_html=True)
+        if files_upload:
+            if files_upload["images"]:
+                st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
+                            unsafe_allow_html=True)
+            elif files_upload["videos"]:
+                st.markdown(
+                    f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
+                    unsafe_allow_html=True)
+            elif files_upload["audios"]:
+                st.markdown(
+                    f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
+                    unsafe_allow_html=True)
 
         chat_box.ai_say("正在思考...")
         text = ""
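The new outer if files_upload: guard matters because files_upload is None when nothing was uploaded (see process_files(...) if uploaded_file else None above), so subscripting it with files_upload["images"] on a plain text turn would otherwise raise a TypeError.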
@@ -307,10 +304,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     text = chunk
                     chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
 
-            chat_box.show_feedback(**feedback_kwargs,
-                                   key=message_id,
-                                   on_submit=on_feedback,
-                                   kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+            if os.path.exists("tmp/image.jpg"):
+                with open("tmp/image.jpg", "rb") as image_file:
+                    encoded_string = base64.b64encode(image_file.read()).decode()
+                img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
+                st.markdown(img_tag, unsafe_allow_html=True)
+                os.remove("tmp/image.jpg")
+            # chat_box.show_feedback(**feedback_kwargs,
+            #                        key=message_id,
+            #                        on_submit=on_feedback,
+            #                        kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
 
         # elif dialogue_mode == "文件对话":
         #     if st.session_state["file_chat_id"] is None:
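Both the tool side (img.save("tmp/image.jpg")) and this UI code assume a tmp/ directory relative to the working directory; PIL raises FileNotFoundError if it is missing. A one-line guard, offered here as a suggestion rather than part of the commit:

import os

os.makedirs("tmp", exist_ok=True)  # ensure the relative tmp/ directory exists before the tool writes tmp/image.jpg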