Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-07 15:38:27 +08:00

Support XPU; modify parts of the glm3 agent.

This commit is contained in:
  parent d44ce6ce21
  commit 6d3d99639e
@@ -3,14 +3,6 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
 """
 from __future__ import annotations
 
-import yaml
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from typing import Any, List, Sequence, Tuple, Optional, Union
-import os
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 import json
 import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
@@ -30,7 +22,8 @@ from langchain.agents.agent import AgentExecutor
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
-from pydantic.schema import model_schema
+from langchain_core.callbacks import Callbacks
 
+
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
@@ -195,6 +188,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm=llm,
             prompt=prompt,
             callback_manager=callback_manager,
+            verbose=True
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
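For orientation, the agent factory in this module is wired up as in the server chat handler further down in this commit. A minimal sketch with placeholder values (llm, prompt_template and memory are stand-ins, not values from this repo):

    # Sketch only: mirrors the initialize_glm3_agent call site changed later in this commit.
    agent_executor = initialize_glm3_agent(
        llm=llm,
        tools=all_tools,
        prompt=prompt_template,
        input_variables=["input", "intermediate_steps", "history"],
        memory=memory,
        verbose=True,
    )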
@@ -1,5 +1,4 @@
 from urllib.parse import urlencode
-
 from pydantic import BaseModel, Field
 
 from server.knowledge_base.kb_doc_api import search_docs
@@ -9,7 +8,7 @@ from configs import TOOL_CONFIG
 def search_knowledgebase(query: str, database: str, config: dict):
     docs = search_docs(
         query=query,
-        knowledge_base_name=database,
+        knowledge_base_name="samples",
         top_k=config["top_k"],
         score_threshold=config["score_threshold"])
     context = ""
@@ -22,7 +21,7 @@ def search_knowledgebase(query: str, database: str, config: dict):
         source_documents.append(text)
 
     if len(source_documents) == 0:
-        context= "没有找到相关文档,请更换关键词重试"
+        context = "没有找到相关文档,请更换关键词重试"
     else:
         for doc in source_documents:
             context += doc + "\n"
@@ -37,4 +36,4 @@ class SearchKnowledgeInput(BaseModel):
 
 def search_local_knowledgebase(database: str, query: str):
     tool_config = TOOL_CONFIG["search_local_knowledgebase"]
     return search_knowledgebase(query=query, database=database, config=tool_config)
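A quick usage sketch of the tool function above (the query string is made up; the TOOL_CONFIG keys follow this diff). Note that after this change the database argument is effectively ignored, since knowledge_base_name is hardcoded to "samples":

    # Hypothetical call, for illustration only.
    tool_config = TOOL_CONFIG["search_local_knowledgebase"]   # expects "top_k" and "score_threshold"
    print(search_knowledgebase(query="What is an agent?", database="samples", config=tool_config))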
@@ -2,9 +2,9 @@ from langchain_core.tools import StructuredTool
 from server.agent.tools_factory import *
 from configs import KB_INFO
 
-template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool."
+template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
 KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
-template_knowledge = template.format(KB_info=KB_info_str)
+template_knowledge = template.format(KB_info=KB_info_str, key="samples")
 
 all_tools = [
     StructuredTool.from_function(
@@ -24,6 +24,7 @@ all_tools = [
         name="shell",
         description="Use Shell to execute Linux commands",
         args_schema=ShellInput,
+        # return_direct=True,  # 是否直接返回,不做大模型处理
     ),
     StructuredTool.from_function(
         func=wolfram,
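To see what the new {key} placeholder does, here is an illustrative formatting run; the KB_INFO value is hypothetical, not the shipped default:

    KB_INFO = {"samples": "a sample knowledge base about this project"}   # hypothetical value
    KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
    print(template.format(KB_info=KB_info_str, key="samples"))
    # ...Only local data on this knowledge use this tool. The 'database' should be one of the above [samples].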
@@ -1,11 +1,13 @@
 from __future__ import annotations
 from uuid import UUID
 import json
-from langchain.schema import AgentFinish, AgentAction
 import asyncio
-from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
+from typing import Any, Dict, List, Optional
 
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.schema import AgentFinish, AgentAction
 from langchain_core.outputs import LLMResult
-from langchain.callbacks.base import AsyncCallbackHandler
 
 
 def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)
@@ -21,7 +23,7 @@ class Status:
     tool_finish: int = 7
 
 
-class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
+class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
@@ -29,29 +31,31 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
         self.cur_tool = {}
         self.out = True
 
-    async def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        *,
-        run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        **kwargs: Any,
-    ) -> None:
-        print("on_tool_start")
-
-    async def on_tool_end(
-        self,
-        output: str,
-        *,
-        run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        **kwargs: Any,
-    ) -> None:
-        print("on_tool_end")
+    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
+                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
+                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
+        self.cur_tool = {
+            "tool_name": serialized["name"],
+            "input_str": input_str,
+            "output_str": "",
+            "status": Status.agent_action,
+            "run_id": run_id.hex,
+            "llm_token": "",
+            "final_answer": "",
+            "error": "",
+        }
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
+                          tags: List[str] | None = None, **kwargs: Any) -> None:
+        self.out = True
+        self.cur_tool.update(
+            status=Status.tool_finish,
+            output_str=output.replace("Answer:", ""),
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
 
     async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                             parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
         self.cur_tool.update(
@@ -134,29 +138,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
             tool_input=action.tool_input,
         )
         self.queue.put_nowait(dumps(self.cur_tool))
 
     async def on_agent_finish(
             self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
             tags: Optional[List[str]] = None,
             **kwargs: Any,
     ) -> None:
+        if "Thought:" in finish.return_values["output"]:
+            finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
+
         self.cur_tool.update(
             status=Status.agent_finish,
             agent_finish=finish.return_values["output"],
         )
         self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def aiter(self) -> AsyncIterator[str]:
-        while not self.queue.empty() or not self.done.is_set():
-            done, other = await asyncio.wait(
-                [
-                    asyncio.ensure_future(self.queue.get()),
-                    asyncio.ensure_future(self.done.wait()),
-                ],
-                return_when=asyncio.FIRST_COMPLETED,
-            )
-            if other:
-                other.pop().cancel()
-            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
-            if token_or_done is True:
-                break
-            yield token_or_done
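Since the handler now subclasses AsyncIteratorCallbackHandler, the hand-rolled aiter() above can be dropped; callers read the queued JSON payloads through the inherited iterator. A minimal consumption sketch, assuming the wrap_done helper and a chain built as in the chat handler changed below (names follow this diff, not a verified API surface):

    import asyncio, json

    async def stream_agent(full_chain, query: str):
        callback = CustomAsyncIteratorCallbackHandler()
        task = asyncio.create_task(
            wrap_done(full_chain.ainvoke({"input": query}, callbacks=[callback]), callback.done))
        async for chunk in callback.aiter():        # inherited from AsyncIteratorCallbackHandler
            yield json.loads(chunk)                 # each chunk is a dumps(...) payload queued above
        await task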
@@ -6,12 +6,11 @@ from fastapi import Body
 from fastapi.responses import StreamingResponse
 
 from langchain.agents import initialize_agent, AgentType
-from langchain_core.callbacks import BaseCallbackManager
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableBranch
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
+from langchain_core.runnables import RunnableBranch
 
 from server.agent.agent_factory import initialize_glm3_agent
 from server.agent.tools_factory.tools_registry import all_tools
@@ -57,48 +56,55 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
     elif conversation_id and history_len > 0:
-        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
-                                            message_limit=history_len)
+        memory = ConversationBufferDBMemory(
+            conversation_id=conversation_id,
+            llm=models["llm_model"],
+            message_limit=history_len
+        )
     else:
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages([input_msg])
 
-    chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
+    chain = LLMChain(
+        prompt=chat_prompt,
+        llm=models["llm_model"],
+        memory=memory
+    )
     classifier_chain = (
         PromptTemplate.from_template(prompts["preprocess_model"])
         | models["preprocess_model"]
         | StrOutputParser()
     )
 
-    if "chatglm3" in models["action_model"].model_name.lower():
-        agent_executor = initialize_glm3_agent(
-            llm=models["action_model"],
-            tools=tools,
-            prompt=prompts["action_model"],
-            input_variables=["input", "intermediate_steps", "history"],
-            memory=memory,
-            callback_manager=BaseCallbackManager(handlers=callbacks),
-            verbose=True,
-        )
-    else:
-        agent_executor = initialize_agent(
-            llm=models["action_model"],
-            tools=tools,
-            callbacks=callbacks,
-            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-            memory=memory,
-            verbose=True,
-        )
-    agent_use = False
-    if agent_use:
-        branch = RunnableBranch(
-            (lambda x: "1" in x["topic"].lower(), agent_executor),
-            chain
-        )
-        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
-    else:
-        # full_chain = ({"input": lambda x: x["input"]} | chain)
-        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
+    if "action_model" in models and tools:
+        if "chatglm3" in models["action_model"].model_name.lower():
+            agent_executor = initialize_glm3_agent(
+                llm=models["action_model"],
+                tools=tools,
+                prompt=prompts["action_model"],
+                input_variables=["input", "intermediate_steps", "history"],
+                memory=memory,
+                # callback_manager=BaseCallbackManager(handlers=callbacks),
+                verbose=True,
+            )
+        else:
+            agent_executor = initialize_agent(
+                llm=models["action_model"],
+                tools=tools,
+                callbacks=callbacks,
+                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+                memory=memory,
+                verbose=True,
+            )
+
+        # branch = RunnableBranch(
+        #     (lambda x: "1" in x["topic"].lower(), agent_executor),
+        #     chain
+        # )
+        # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
+    else:
+        full_chain = ({"input": lambda x: x["input"]} | chain)
     return full_chain
 
 
@@ -149,7 +155,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
     # Execute Chain
 
-    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
+    task = asyncio.create_task(
+        wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
     if stream:
         async for chunk in callback.aiter():
             data = json.loads(chunk)
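For reference, wrap_done is the small helper (kept in server/utils.py upstream) that awaits the chain and flips the callback's done event so callback.aiter() can terminate. A simplified sketch, not the exact upstream code:

    import asyncio, logging
    from typing import Awaitable

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        try:
            await fn                     # run the chain / agent to completion
        except Exception as e:
            logging.exception(e)         # upstream also logs a formatted error message
        finally:
            event.set()                  # unblocks the aiter() loop in the callback handler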
@@ -508,30 +508,33 @@ def set_httpx_config(
     urllib.request.getproxies = _get_proxies
 
 
-def detect_device() -> Literal["cuda", "mps", "cpu"]:
+def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
     try:
         import torch
         if torch.cuda.is_available():
             return "cuda"
         if torch.backends.mps.is_available():
             return "mps"
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.get_device_properties(0):
+            return "xpu"
     except:
         pass
     return "cpu"
 
 
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
     device = device or LLM_DEVICE
     # if device.isdigit():
     #     return "cuda:" + device
-    if device not in ["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
         device = detect_device()
     return device
 
 
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
     device = device or EMBEDDING_DEVICE
-    if device not in ["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
         device = detect_device()
     return device
 
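A quick smoke-test sketch for the widened device detection, assuming these helpers live in server/utils.py as in upstream Langchain-Chatchat and that intel_extension_for_pytorch is installed on the XPU host (printed values are examples):

    from server.utils import detect_device, llm_device, embedding_device

    print(detect_device())        # "xpu" on an Intel GPU host, otherwise "cuda" / "mps" / "cpu"
    print(llm_device("xpu"))      # "xpu" now passes the allow-list instead of falling back
    print(embedding_device())     # uses EMBEDDING_DEVICE, then detect_device() as before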
@@ -219,7 +219,6 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             wbits=args.awq_wbits,
             groupsize=args.awq_groupsize,
         )
-
     worker = ModelWorker(
         controller_addr=args.controller_address,
         worker_addr=args.worker_address,
@@ -391,7 +390,6 @@ def run_model_worker(
     kwargs["worker_address"] = fschat_model_worker_address(model_name)
     model_path = kwargs.get("model_path", "")
     kwargs["model_path"] = model_path
-
     app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_event(app, started_event)
     if log_level == "ERROR":
@@ -10,8 +10,7 @@ from datetime import datetime
 import os
 import re
 import time
-from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
-from server.knowledge_base.utils import LOADER_DICT
+from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS)
 import uuid
 from typing import List, Dict
 
@@ -124,15 +123,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         conv_names = list(st.session_state["conversation_ids"].keys())
         index = 0
 
-        tools = list(TOOL_CONFIG.keys())
-        selected_tool_configs = {}
-
-        with st.expander("工具栏"):
-            for tool in tools:
-                is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
-                if is_selected:
-                    selected_tool_configs[tool] = TOOL_CONFIG[tool]
-
         if st.session_state.get("cur_conv_name") in conv_names:
             index = conv_names.index(st.session_state.get("cur_conv_name"))
         conversation_name = st.selectbox("当前会话", conv_names, index=index)
@@ -177,7 +167,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         for k, v in config_models.get("online", {}).items():
             if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                 available_models.append(k)
-        llm_models = running_models + available_models + ["openai-api"]
+        llm_models = running_models + available_models  # + ["openai-api"]
         cur_llm_model = st.session_state.get("cur_llm_model", default_model)
         if cur_llm_model in llm_models:
             index = llm_models.index(cur_llm_model)
@@ -193,22 +183,39 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
         # 传入后端的内容
         model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
+        tool_use = True
         for key in LLM_MODEL_CONFIG:
             if key == 'llm_model':
                 continue
+            if key == 'action_model':
+                first_key = next(iter(LLM_MODEL_CONFIG[key]))
+                if first_key not in SUPPORT_AGENT_MODELS:
+                    st.warning("不支持Agent的模型,无法执行任何工具调用")
+                    tool_use = False
+                    continue
             if LLM_MODEL_CONFIG[key]:
                 first_key = next(iter(LLM_MODEL_CONFIG[key]))
                 model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
 
+        # 选择工具
+        selected_tool_configs = {}
+        if tool_use:
+            from configs import prompt_config
+            import importlib
+            importlib.reload(prompt_config)
+
+            tools = list(prompt_config.TOOL_CONFIG.keys())
+            with st.expander("工具栏"):
+                for tool in tools:
+                    is_selected = st.checkbox(tool, value=prompt_config.TOOL_CONFIG[tool]["use"], key=tool)
+                    if is_selected:
+                        selected_tool_configs[tool] = prompt_config.TOOL_CONFIG[tool]
+
         if llm_model is not None:
             model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
 
-        # files = st.file_uploader("上传附件",accept_multiple_files=False)
-        #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
         uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
         files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
 
         # print(len(files_upload["audios"])) if files_upload else None
 
         # if dialogue_mode == "文件对话":
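The new gating reads two config values whose shape the UI code implies: a list of agent-capable model names and a dict of tool settings that the page reloads from configs.prompt_config. A hypothetical configs sketch (module locations and values are examples, not the shipped defaults):

    # configs/model_config.py (hypothetical excerpt)
    SUPPORT_AGENT_MODELS = ["chatglm3-6b", "openai-api", "Qwen-14B-Chat"]

    # configs/prompt_config.py (hypothetical excerpt) — the UI reloads this module to pick up edits
    TOOL_CONFIG = {
        "search_local_knowledgebase": {"use": False, "top_k": 3, "score_threshold": 1.0},
        "shell": {"use": False},
    }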
@@ -355,6 +362,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             chat_box.reset_history()
             st.rerun()
 
+    warning_placeholder = st.empty()
+    with warning_placeholder.container():
+        st.warning('Running in 8 x A100')
+
     export_btn.download_button(
         "导出记录",
         "".join(chat_box.export2md()),