mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-31 19:33:26 +08:00
添加 xinference 本地模型和自定义模型配置 UI: streamlit run model_loaders/xinference_manager.py
This commit is contained in:
parent
76b796ea58
commit
49bc5b54a4
260
model_loaders/xinference_manager.py
Normal file
260
model_loaders/xinference_manager.py
Normal file
@ -0,0 +1,260 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from xinference.client import Client
|
||||
import xinference.model.llm as xf_llm
|
||||
from xinference.model import embedding as xf_embedding
|
||||
from xinference.model import image as xf_image
|
||||
from xinference.model import rerank as xf_rerank
|
||||
from xinference.model import audio as xf_audio
|
||||
from xinference.constants import XINFERENCE_CACHE_DIR
|
||||
|
||||
|
||||
model_types = ["LLM", "embedding", "image", "rerank"] # "audio"
|
||||
model_name_suffix = "-custom"
|
||||
cache_methods = {
|
||||
"LLM": xf_llm.llm_family.cache,
|
||||
"embedding": xf_embedding.core.cache,
|
||||
"image": xf_image.core.cache,
|
||||
"rerank": xf_rerank.core.cache,
|
||||
}
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_client(url: str):
|
||||
return Client(url)
|
||||
|
||||
|
||||
def get_cache_dir(
|
||||
model_type: str,
|
||||
model_name: str,
|
||||
model_format: str = "",
|
||||
model_size: str = "",
|
||||
):
|
||||
if model_type == "LLM":
|
||||
dir_name = f"{model_name}-{model_format}-{model_size}b"
|
||||
else:
|
||||
dir_name = f"{model_name}"
|
||||
return os.path.join(XINFERENCE_CACHE_DIR, dir_name)
|
||||
|
||||
|
||||
def get_meta_path(
|
||||
model_type: str,
|
||||
cache_dir: str,
|
||||
model_format: str,
|
||||
model_hub: str = "huggingface",
|
||||
model_quant: str = "none",
|
||||
):
|
||||
if model_type == "LLM":
|
||||
return xf_llm.llm_family._get_meta_path(
|
||||
cache_dir=cache_dir,
|
||||
model_format=model_format,
|
||||
model_hub=model_hub,
|
||||
quantization=model_quant,
|
||||
)
|
||||
else:
|
||||
return os.path.join(cache_dir, "__valid_download")
|
||||
|
||||
|
||||
def list_running_models():
|
||||
models = client.list_models()
|
||||
columns = [
|
||||
"UID",
|
||||
"type",
|
||||
"name",
|
||||
"ability",
|
||||
"size",
|
||||
"quant",
|
||||
"max_tokens",
|
||||
]
|
||||
data = []
|
||||
for k, v in models.items():
|
||||
item = dict(
|
||||
UID=k,
|
||||
type=v["model_type"],
|
||||
name=v["model_name"],
|
||||
)
|
||||
if v["model_type"] == "LLM":
|
||||
item.update(
|
||||
ability=v["model_ability"],
|
||||
size=str(v["model_size_in_billions"]) + "B",
|
||||
quant=v["quantization"],
|
||||
max_tokens=v["context_length"],
|
||||
)
|
||||
elif v["model_type"] == "embedding":
|
||||
item.update(
|
||||
max_tokens=v["max_tokens"],
|
||||
)
|
||||
data.append(item)
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
df.index += 1
|
||||
return df
|
||||
|
||||
|
||||
def get_model_registrations():
|
||||
data = {}
|
||||
for model_type in model_types:
|
||||
data[model_type] = {}
|
||||
for m in client.list_model_registrations(model_type):
|
||||
data[model_type][m["model_name"]] = {"is_builtin": m["is_builtin"]}
|
||||
reg = client.get_model_registration(model_type, m["model_name"])
|
||||
data[model_type][m["model_name"]]["reg"] = reg
|
||||
return data
|
||||
|
||||
|
||||
with st.sidebar:
|
||||
st.subheader("请先执行 xinference 或 xinference-local 命令启动 XF 服务。然后将 XF 服务地址配置在下方。")
|
||||
xf_url = st.text_input("Xinference endpoint", "http://127.0.0.1:9997")
|
||||
st.divider()
|
||||
st.markdown(
|
||||
"### 使用方法\n\n"
|
||||
"- 场景1:我已经下载过模型,不想 XF 内置模型重复下载\n\n"
|
||||
"- 操作:选择对应的模型后,填写好本地模型路径,点‘设置模型缓存’即可\n\n"
|
||||
"- 场景2:我想对 XF 内置模型做一些修改,又不想从头写模型注册文件\n\n"
|
||||
"- 操作:选择对应的模型后,填写好本地模型路径,点‘注册为自定义模型’即可\n\n"
|
||||
"- 场景3:我不小心设置了错误的模型路径\n\n"
|
||||
"- 操作:直接‘删除模型缓存’或更换正确的路径后‘设置模型缓存’即可\n\n"
|
||||
)
|
||||
client = get_client(xf_url)
|
||||
|
||||
|
||||
st.subheader("当前运行的模型:")
|
||||
st.dataframe(list_running_models())
|
||||
|
||||
st.subheader("配置模型路径:")
|
||||
regs = get_model_registrations()
|
||||
cols = st.columns([3, 4, 3, 2, 2])
|
||||
|
||||
model_type = cols[0].selectbox("模型类别:", model_types)
|
||||
|
||||
model_names = list(regs[model_type].keys())
|
||||
model_name = cols[1].selectbox("模型名称:", model_names)
|
||||
|
||||
cur_reg = regs[model_type][model_name]["reg"]
|
||||
|
||||
if model_type == "LLM":
|
||||
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
||||
cur_spec = None
|
||||
model_formats = []
|
||||
for spec in cur_reg["model_specs"]:
|
||||
if spec["model_format"] not in model_formats:
|
||||
model_formats.append(spec["model_format"])
|
||||
index = 0
|
||||
if "pytorch" in model_formats:
|
||||
index = model_formats.index("pytorch")
|
||||
model_format = cols[2].selectbox("模型格式:", model_formats, index)
|
||||
|
||||
model_sizes = []
|
||||
for spec in cur_reg["model_specs"]:
|
||||
if (spec["model_format"] == model_format
|
||||
and spec["model_size_in_billions"] not in model_sizes):
|
||||
model_sizes.append(spec["model_size_in_billions"])
|
||||
model_size = cols[3].selectbox("模型大小:", model_sizes, format_func=lambda x: f"{x}B")
|
||||
|
||||
model_quants = []
|
||||
for spec in cur_reg["model_specs"]:
|
||||
if (spec["model_format"] == model_format
|
||||
and spec["model_size_in_billions"] == model_size):
|
||||
model_quants = spec["quantizations"]
|
||||
model_quant = cols[4].selectbox("模型量化:", model_quants)
|
||||
if model_quant == "none":
|
||||
model_quant = None
|
||||
|
||||
for i, spec in enumerate(cur_reg["model_specs"]):
|
||||
if (spec["model_format"] == model_format
|
||||
and spec["model_size_in_billions"] == model_size):
|
||||
cur_spec = cur_family.model_specs[i]
|
||||
break
|
||||
cache_dir = get_cache_dir(model_type, model_name, model_format, model_size)
|
||||
elif model_type == "embedding":
|
||||
cur_spec = xf_embedding.core.EmbeddingModelSpec.parse_obj(cur_reg)
|
||||
cache_dir = get_cache_dir(model_type, model_name)
|
||||
elif model_type == "image":
|
||||
cur_spec = xf_image.core.ImageModelFamilyV1.parse_obj(cur_reg)
|
||||
cache_dir = get_cache_dir(model_type, model_name)
|
||||
elif model_type == "rerank":
|
||||
cur_spec = xf_rerank.core.RerankModelSpec.parse_obj(cur_reg)
|
||||
cache_dir = get_cache_dir(model_type, model_name)
|
||||
|
||||
meta_file = get_meta_path(
|
||||
model_type=model_type,
|
||||
model_format=model_format,
|
||||
cache_dir=cache_dir,
|
||||
model_quant=model_quant)
|
||||
|
||||
if os.path.isdir(cache_dir):
|
||||
try:
|
||||
with open(meta_file, encoding="utf-8") as fp:
|
||||
meta = json.load(fp)
|
||||
revision = meta.get("revision", meta.get("model_revision"))
|
||||
except:
|
||||
revision = ""
|
||||
if cur_spec.model_revision and cur_spec.model_revision != revision:
|
||||
revision += " (与 XF 内置版本号不一致)"
|
||||
else:
|
||||
revision += " (符合要求)"
|
||||
|
||||
text = (f"模型已缓存。\n\n"
|
||||
f"缓存路径:{cache_dir}\n\n"
|
||||
f"原始路径:{os.readlink(cache_dir)[4:]}\n\n"
|
||||
f"版本号 :{revision}"
|
||||
)
|
||||
else:
|
||||
text = "模型尚未缓存"
|
||||
|
||||
st.divider()
|
||||
st.markdown(text)
|
||||
st.divider()
|
||||
|
||||
cols = st.columns([3, 1])
|
||||
local_path = cols[0].text_input("本地模型绝对路径:")
|
||||
|
||||
if cols[1].button("设置模型缓存"):
|
||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||
cur_spec.model_uri = local_path
|
||||
if os.path.isdir(cache_dir):
|
||||
os.rmdir(cache_dir)
|
||||
if model_type == "LLM":
|
||||
cache_methods[model_type](cur_family, cur_spec, model_quant)
|
||||
else:
|
||||
cache_methods[model_type](cur_spec)
|
||||
if cur_spec.model_revision:
|
||||
for hub in ["huggingface", "modelscope"]:
|
||||
meta_file = get_meta_path(
|
||||
model_type=model_type,
|
||||
model_format=model_format,
|
||||
model_hub=hub,
|
||||
cache_dir=cache_dir,
|
||||
model_quant=model_quant)
|
||||
with open(meta_file, "w", encoding="utf-8") as fp:
|
||||
json.dump({"revision": cur_spec.model_revision}, fp)
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("必须输入存在的绝对路径")
|
||||
|
||||
if cols[1].button("删除模型缓存"):
|
||||
if os.path.isdir(cache_dir):
|
||||
os.rmdir(cache_dir)
|
||||
|
||||
if cols[1].button("注册为自定义模型"):
|
||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||
cur_spec.model_uri = local_path
|
||||
cur_spec.model_revision = None
|
||||
if model_type == "LLM":
|
||||
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
||||
cur_family.model_family = "other"
|
||||
model_definition = cur_family.json(indent=2, ensure_ascii=False)
|
||||
else:
|
||||
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
||||
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
|
||||
client.register_model(
|
||||
model_type=model_type,
|
||||
model=model_definition,
|
||||
persist=True,
|
||||
)
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("必须输入存在的绝对路径")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user