mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00

Update some content

This commit is contained in:
parent 5714358403
commit bc225bf9f5
@@ -1,6 +1,6 @@
# API requirements
-langchain>=0.0.346
+langchain>=0.0.348
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
@@ -127,7 +127,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):

def mount_knowledge_routes(app: FastAPI):
    from server.chat.knowledge_base_chat import knowledge_base_chat
    from server.chat.file_chat import upload_temp_docs, file_chat
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -135,10 +134,6 @@ def mount_knowledge_routes(app: FastAPI):
                                                  search_docs, DocumentWithVSId, update_info,
                                                  update_docs_by_id,)

    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledge_base_chat)

    app.post("/chat/file_chat",
             tags=["Knowledge Base Management"],
             summary="文件对话"
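For readers unfamiliar with this registration style, here is a minimal, self-contained sketch of the app.post(...)(handler) pattern used above. The route path, tag, and handler below are illustrative placeholders, not taken from the repo.

# Hypothetical example: FastAPI's app.post(...) returns a decorator, so calling it
# and immediately applying it to a function registers the route, exactly like
# writing @app.post(...) above the function definition.
from fastapi import Body, FastAPI

app = FastAPI()

def file_chat_demo(query: str = Body(..., description="user input")):
    return {"answer": f"echo: {query}"}

# Equivalent to decorating file_chat_demo with @app.post(...).
app.post("/chat/file_chat_demo",
         tags=["Knowledge Base Management"],
         summary="file chat (demo)")(file_chat_demo)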
@@ -83,11 +83,15 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
        memory=memory,
        verbose=True,
    )
-    branch = RunnableBranch(
-        (lambda x: "1" in x["topic"].lower(), agent_executor),
-        chain
-    )
-    full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    agent_use = False
+    if agent_use:
+        branch = RunnableBranch(
+            (lambda x: "1" in x["topic"].lower(), agent_executor),
+            chain
+        )
+        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+    else:
+        full_chain = ({"input": lambda x: x["input"]} | chain)
    return full_chain
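For context, a runnable sketch of the topic-routing pattern shown above. The toy classifier_chain, agent_executor, and chain below are RunnableLambda stand-ins for the real chains built in create_models_chains, assuming langchain>=0.0.348 (which ships langchain-core); the real code routes to an AgentExecutor instead.

# Hedged sketch: stand-in runnables emulate the routing logic above.
from langchain_core.runnables import RunnableBranch, RunnableLambda

classifier_chain = RunnableLambda(lambda x: "1")                   # pretend the classifier chose the agent path
agent_executor = RunnableLambda(lambda x: f"agent: {x['input']}")  # stand-in for the real agent executor
chain = RunnableLambda(lambda x: f"llm: {x['input']}")             # stand-in for the plain LLM chain

branch = RunnableBranch(
    (lambda x: "1" in x["topic"].lower(), agent_executor),  # route to the agent when the topic contains "1"
    chain,                                                   # default branch
)
full_chain = {"topic": classifier_chain, "input": lambda x: x["input"]} | branch
print(full_chain.invoke({"input": "hello"}))                 # -> "agent: hello"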
@@ -107,15 +111,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               tool_config: Dict = Body({}, description="工具配置"),
               ):
    async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query,
-                                       conversation_id=conversation_id) if conversation_id else None
+        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
        callback = CustomAsyncIteratorCallbackHandler()
        callbacks = [callback]
        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)

        tools = [tool for tool in all_tools if tool.name in tool_config]

        # Build the complete chain
        full_chain = create_models_chains(prompts=prompts,
                                          models=models,
@@ -363,7 +363,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())

    print(config, "*******")
    breakpoint()
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
        if provider := config.get("provider"):
@@ -380,6 +381,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
        if path and os.path.isdir(path):
            config["model_path_exists"] = True
    config["device"] = llm_device(config.get("device"))
+    breakpoint()
    return config
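As a rough illustration of the layering above (worker defaults, then ONLINE_LLM_MODEL, then the per-model FSCHAT_MODEL_WORKERS entry, with later layers winning), a minimal sketch with made-up values:

# Hypothetical values; only the update order mirrors get_model_worker_config.
default_cfg = {"device": "auto", "infer_turbo": False}
online_cfg = {"api_key": "EXAMPLE_KEY", "provider": "ExampleWorker"}
worker_cfg = {"device": "cuda"}

config = default_cfg.copy()
config.update(online_cfg)   # online-model settings override the defaults
config.update(worker_cfg)   # per-model worker settings override both
print(config)               # {'device': 'cuda', 'infer_turbo': False, 'api_key': 'EXAMPLE_KEY', 'provider': 'ExampleWorker'}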
@@ -523,6 +525,8 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:

def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or LLM_DEVICE
    # if device.isdigit():
    #     return "cuda:" + device
    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
    return device
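detect_device itself is not shown in this diff; below is a plausible sketch of such a helper, assuming torch-based detection (the actual implementation in the repo may differ).

# Hedged sketch only: falls back to "cpu" when torch or an accelerator is unavailable.
from typing import Literal

def detect_device() -> Literal["cuda", "mps", "cpu"]:
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except Exception:
        pass
    return "cpu"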
BIN teaser.png (binary file not shown; before: 1.0 MiB)
test.py (32 lines removed)
@@ -1,32 +0,0 @@
from transformers import FuyuProcessor, FuyuForCausalLM
from PIL import Image
import requests
import torch

# Load the model and processor
model_id = "/data/models/fuyu-8b"
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.float16)

# Convert the model to bf16
model = model.to(dtype=torch.bfloat16)

# Prepare inputs for the model
# text_prompt = "According to this chart, which model performs best?\n"

text_prompt = "Generate a coco-style caption.\n"
image = Image.open("1.png").convert("RGB")

while True:
    # Read the text prompt from the user
    text_prompt = input("请输入文本提示: ")
    if text_prompt.lower() == 'exit':
        break
    inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")

    # Generate the output
    generation_output = model.generate(**inputs, max_new_tokens=7)
    generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)

    # Print the generated text
    print(generation_text)
@@ -175,7 +175,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    for k, v in config_models.get("online", {}).items():
        if not v.get("provider") and k not in running_models and k in LLM_MODELS:
            available_models.append(k)
-    llm_models = running_models + available_models
+    llm_models = running_models + available_models + ["openai-api"]
    cur_llm_model = st.session_state.get("cur_llm_model", default_model)
    if cur_llm_model in llm_models:
        index = llm_models.index(cur_llm_model)