Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Update some content
This commit is contained in:
parent 5714358403
commit bc225bf9f5
@@ -1,6 +1,6 @@
 # API requirements
 
-langchain>=0.0.346
+langchain>=0.0.348
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
@@ -127,7 +127,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
 
 
 def mount_knowledge_routes(app: FastAPI):
-    from server.chat.knowledge_base_chat import knowledge_base_chat
     from server.chat.file_chat import upload_temp_docs, file_chat
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -135,10 +134,6 @@ def mount_knowledge_routes(app: FastAPI):
                                                   search_docs, DocumentWithVSId, update_info,
                                                   update_docs_by_id,)
 
-    app.post("/chat/knowledge_base_chat",
-             tags=["Chat"],
-             summary="与知识库对话")(knowledge_base_chat)
-
     app.post("/chat/file_chat",
              tags=["Knowledge Base Management"],
              summary="文件对话"
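For readers unfamiliar with this registration style: app.post(...) returns a decorator, and applying it immediately to an imported coroutine registers the route, which is how mount_knowledge_routes wires its handlers. A minimal self-contained sketch of the same pattern (the handler body and its echo response are illustrative, not the project's real file_chat):

from fastapi import Body, FastAPI

app = FastAPI()

async def file_chat(query: str = Body(..., description="user query")):
    # Illustrative stand-in; the real file_chat answers from temporarily
    # uploaded documents and streams its response.
    return {"answer": f"echo: {query}"}

# app.post(...) returns a decorator; calling it with the coroutine registers
# the route, the same call style used in mount_knowledge_routes above.
app.post("/chat/file_chat",
         tags=["Knowledge Base Management"],
         summary="文件对话")(file_chat)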
@@ -83,11 +83,15 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
            memory=memory,
            verbose=True,
        )
-    branch = RunnableBranch(
-        (lambda x: "1" in x["topic"].lower(), agent_executor),
-        chain
-    )
-    full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    agent_use = False
+    if agent_use:
+        branch = RunnableBranch(
+            (lambda x: "1" in x["topic"].lower(), agent_executor),
+            chain
+        )
+        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    else:
+        full_chain = ({"input": lambda x: x["input"]} | chain)
     return full_chain
 
 
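The routing idiom in this hunk is LangChain's RunnableBranch: a classifier chain fills the "topic" key, the branch dispatches on it, and the last positional argument is the default. A runnable sketch with lambdas standing in for classifier_chain, agent_executor and chain (the stand-ins and inputs are illustrative only):

from langchain_core.runnables import RunnableBranch, RunnableLambda

# Stand-ins for the real classifier_chain, agent_executor and chain used in
# create_models_chains; each is just a runnable lambda for illustration.
classifier_chain = RunnableLambda(lambda x: "1" if "search" in x["input"] else "0")
agent_executor = RunnableLambda(lambda x: "[agent] " + x["input"])
chain = RunnableLambda(lambda x: "[llm] " + x["input"])

# Run the agent when the classifier emitted "1"; the trailing positional
# argument is the default branch.
branch = RunnableBranch(
    (lambda x: "1" in x["topic"].lower(), agent_executor),
    chain,
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)

print(full_chain.invoke({"input": "search today's weather"}))  # routed to the agent
print(full_chain.invoke({"input": "hello"}))                   # falls through to the chain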
@@ -107,15 +111,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               tool_config: Dict = Body({}, description="工具配置"),
               ):
     async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query,
-                                       conversation_id=conversation_id) if conversation_id else None
-
+        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
-
         tools = [tool for tool in all_tools if tool.name in tool_config]
-
         # Build the full chain
         full_chain = create_models_chains(prompts=prompts,
                                           models=models,
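CustomAsyncIteratorCallbackHandler follows LangChain's AsyncIteratorCallbackHandler pattern: the streaming model pushes tokens into the handler while the async generator drains them. A generic sketch of that pattern, assuming an OpenAI-backed chat model (and an OPENAI_API_KEY) rather than the project's own model factory:

import asyncio
from typing import AsyncIterable

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI

async def stream_answer(query: str) -> AsyncIterable[str]:
    # One handler per request: the model writes tokens into it and the
    # generator yields them as they arrive, which is what chat_iterator does
    # with the project's CustomAsyncIteratorCallbackHandler.
    callback = AsyncIteratorCallbackHandler()
    model = ChatOpenAI(streaming=True, callbacks=[callback])

    task = asyncio.create_task(model.ainvoke(query))
    async for token in callback.aiter():
        yield token
    await task

async def main():
    async for token in stream_answer("Say hi in one short sentence."):
        print(token, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())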
@@ -363,7 +363,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
-
+    print(config, "*******")
+    breakpoint()
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
@@ -380,6 +381,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
     if path and os.path.isdir(path):
         config["model_path_exists"] = True
     config["device"] = llm_device(config.get("device"))
+    breakpoint()
     return config
 
 
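get_model_worker_config builds its result by layering three dicts: defaults first, then the online-model entry, then the per-model worker entry, so the most specific value wins. A self-contained sketch of that layering (the dict contents and model names here are invented; the real ones come from the project's configs):

# Illustrative config dicts standing in for the project's configs.
FSCHAT_MODEL_WORKERS = {
    "default": {"host": "127.0.0.1", "port": 20002, "device": "auto"},
    "chatglm3-6b": {"port": 21001},
}
ONLINE_LLM_MODEL = {
    "zhipu-api": {"api_key": "", "provider": "ChatGLMWorker"},
}

def get_model_worker_config(model_name: str = None) -> dict:
    # Layering order: defaults, then the online-model entry, then the
    # per-model worker entry; later updates override earlier keys.
    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
    return config

print(get_model_worker_config("chatglm3-6b"))  # port overridden to 21001, rest from defaults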
@@ -523,6 +525,8 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
 
 def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
+    # if device.isdigit():
+    #     return "cuda:" + device
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
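llm_device defers to detect_device() whenever the configured device is not one of "cuda", "mps" or "cpu". The real helper lives elsewhere in the same module; a plausible sketch of it, assuming the usual torch availability checks:

from typing import Literal

import torch

def detect_device() -> Literal["cuda", "mps", "cpu"]:
    # Prefer CUDA, then Apple's MPS backend, then plain CPU; this is the
    # fallback order llm_device relies on.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(detect_device())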
teaser.png  (BIN)  Binary file not shown. Before: Size 1.0 MiB
test.py  (32 lines changed)
@@ -1,32 +0,0 @@
-from transformers import FuyuProcessor, FuyuForCausalLM
-from PIL import Image
-import requests
-import torch
-
-# Load the model and processor
-model_id = "/data/models/fuyu-8b"
-processor = FuyuProcessor.from_pretrained(model_id)
-model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)
-
-# Convert the model to bf16
-model = model.to(dtype=torch.bfloat16)
-
-# Prepare the model inputs
-# text_prompt = "According to this chart, which model performs best?\n"
-
-text_prompt = "Generate a coco-style caption.\n"
-image = Image.open("1.png").convert("RGB")
-
-while True:
-    # Get a text prompt from the user
-    text_prompt = input("请输入文本提示: ")
-    if text_prompt.lower() == 'exit':
-        break
-    inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
-
-    # Generate the output
-    generation_output = model.generate(**inputs, max_new_tokens=7)
-    generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
-
-    # Print the generated text
-    print(generation_text)
@@ -175,7 +175,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        for k, v in config_models.get("online", {}).items():
            if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                available_models.append(k)
-        llm_models = running_models + available_models
+        llm_models = running_models + available_models + ["openai-api"]
        cur_llm_model = st.session_state.get("cur_llm_model", default_model)
        if cur_llm_model in llm_models:
            index = llm_models.index(cur_llm_model)
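The surrounding code feeds llm_models into a Streamlit selectbox and pre-selects whatever model is stored in session state. A small sketch of that selection logic; the model names and widget label are made up, and the real lists come from the running fschat workers and the configured online models:

import streamlit as st

# Illustrative model lists; dialogue_page builds them from running workers
# and the configured online models.
running_models = ["chatglm3-6b"]
available_models = ["qwen-api"]
llm_models = running_models + available_models + ["openai-api"]

# Keep the previously chosen model selected across reruns, defaulting to the
# first entry when nothing is stored yet.
cur_llm_model = st.session_state.get("cur_llm_model", llm_models[0])
index = llm_models.index(cur_llm_model) if cur_llm_model in llm_models else 0

llm_model = st.selectbox("LLM model", llm_models, index=index)
st.session_state["cur_llm_model"] = llm_model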