diff --git a/1.png b/1.png
deleted file mode 100644
index 3979555c..00000000
Binary files a/1.png and /dev/null differ
diff --git a/bus.png b/bus.png
deleted file mode 100644
index 8994abc4..00000000
Binary files a/bus.png and /dev/null differ
diff --git a/requirements.txt b/requirements.txt
index bb341dd6..7a99abc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # API requirements
-langchain>=0.0.346
+langchain>=0.0.348
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
diff --git a/server/api.py b/server/api.py
index 3c4d04df..33954e85 100644
--- a/server/api.py
+++ b/server/api.py
@@ -127,7 +127,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
 
 
 def mount_knowledge_routes(app: FastAPI):
-    from server.chat.knowledge_base_chat import knowledge_base_chat
     from server.chat.file_chat import upload_temp_docs, file_chat
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -135,10 +134,6 @@ def mount_knowledge_routes(app: FastAPI):
                                                   search_docs, DocumentWithVSId, update_info,
                                                   update_docs_by_id,)
 
-    app.post("/chat/knowledge_base_chat",
-             tags=["Chat"],
-             summary="与知识库对话")(knowledge_base_chat)
-
     app.post("/chat/file_chat",
              tags=["Knowledge Base Management"],
              summary="文件对话"
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 1dd92fab..bd92210d 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -83,11 +83,15 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             memory=memory,
             verbose=True,
         )
-    branch = RunnableBranch(
-        (lambda x: "1" in x["topic"].lower(), agent_executor),
-        chain
-    )
-    full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    agent_use = False
+    if agent_use:
+        branch = RunnableBranch(
+            (lambda x: "1" in x["topic"].lower(), agent_executor),
+            chain
+        )
+        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    else:
+        full_chain = ({"input": lambda x: x["input"]} | chain)
     return full_chain
 
 
@@ -107,15 +111,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query,
-                                       conversation_id=conversation_id) if conversation_id else None
-
+        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
 
-        tools = [tool for tool in all_tools if tool.name in tool_config]
-
         # Build the complete chain
         full_chain = create_models_chains(prompts=prompts,
                                           models=models,
diff --git a/server/utils.py b/server/utils.py
index dac08a04..5cf62417 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -363,7 +363,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
-
+    print(config, "*******")
+    breakpoint()
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
@@ -380,6 +381,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
         if path and os.path.isdir(path):
             config["model_path_exists"] = True
     config["device"] = llm_device(config.get("device"))
+    breakpoint()
     return config
 
 
@@ -523,6 +525,8 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
 
 def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
+    # if device.isdigit():
+    #     return "cuda:" + device
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
diff --git a/teaser.png b/teaser.png
deleted file mode 100644
index c01532b6..00000000
Binary files a/teaser.png and /dev/null differ
diff --git a/test.py b/test.py
deleted file mode 100644
index b4b73e23..00000000
--- a/test.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from transformers import FuyuProcessor, FuyuForCausalLM
-from PIL import Image
-import requests
-import torch
-
-# Load the model and processor
-model_id = "/data/models/fuyu-8b"
-processor = FuyuProcessor.from_pretrained(model_id)
-model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)
-
-# Convert the model to bf16
-model = model.to(dtype=torch.bfloat16)
-
-# Prepare the model inputs
-# text_prompt = "According to this chart, which model performs best?\n"
-
-text_prompt = "Generate a coco-style caption.\n"
-image = Image.open("1.png").convert("RGB")
-
-while True:
-    # Get the user's text prompt
-    text_prompt = input("请输入文本提示: ")
-    if text_prompt.lower() == 'exit':
-        break
-    inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
-
-    # Generate the output
-    generation_output = model.generate(**inputs, max_new_tokens=7)
-    generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
-
-    # Print the generated text
-    print(generation_text)
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 2e82ec23..77e90068 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -175,7 +175,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         for k, v in config_models.get("online", {}).items():
             if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                 available_models.append(k)
-        llm_models = running_models + available_models
+        llm_models = running_models + available_models + ["openai-api"]
         cur_llm_model = st.session_state.get("cur_llm_model", default_model)
         if cur_llm_model in llm_models:
             index = llm_models.index(cur_llm_model)