From 49bc5b54a45f02bec1434fce7ae464928d400768 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 6 Mar 2024 12:58:48 +0800
Subject: [PATCH] Add xinference local-model and custom-model configuration UI: streamlit run model_loaders/xinference_manager.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model_loaders/xinference_manager.py | 260 ++++++++++++++++++++++++++++
 1 file changed, 260 insertions(+)
 create mode 100644 model_loaders/xinference_manager.py

diff --git a/model_loaders/xinference_manager.py b/model_loaders/xinference_manager.py
new file mode 100644
index 00000000..427bac52
--- /dev/null
+++ b/model_loaders/xinference_manager.py
@@ -0,0 +1,260 @@
+import json
+import os
+import shutil
+
+import pandas as pd
+import streamlit as st
+from xinference.client import Client
+import xinference.model.llm as xf_llm
+from xinference.model import embedding as xf_embedding
+from xinference.model import image as xf_image
+from xinference.model import rerank as xf_rerank
+from xinference.constants import XINFERENCE_CACHE_DIR
+
+
+model_types = ["LLM", "embedding", "image", "rerank"]  # "audio" is not wired up yet
+model_name_suffix = "-custom"
+# Per-type cache functions from xinference: they download the model files or,
+# when the spec carries a local model_uri, symlink the cache dir to that path.
+cache_methods = {
+    "LLM": xf_llm.llm_family.cache,
+    "embedding": xf_embedding.core.cache,
+    "image": xf_image.core.cache,
+    "rerank": xf_rerank.core.cache,
+}
+
+
+@st.cache_resource
+def get_client(url: str):
+    return Client(url)
+
+
+def get_cache_dir(
+    model_type: str,
+    model_name: str,
+    model_format: str = "",
+    model_size: str = "",
+):
+    """Mirror xinference's cache-directory naming so the dir can be managed here."""
+    if model_type == "LLM":
+        dir_name = f"{model_name}-{model_format}-{model_size}b"
+    else:
+        dir_name = f"{model_name}"
+    return os.path.join(XINFERENCE_CACHE_DIR, dir_name)
+
+
+def get_meta_path(
+    model_type: str,
+    cache_dir: str,
+    model_format: str = "",
+    model_hub: str = "huggingface",
+    model_quant: str = "none",
+):
+    """Path of the meta file that marks a cache dir as a valid download."""
+    if model_type == "LLM":
+        return xf_llm.llm_family._get_meta_path(
+            cache_dir=cache_dir,
+            model_format=model_format,
+            model_hub=model_hub,
+            quantization=model_quant,
+        )
+    else:
+        return os.path.join(cache_dir, "__valid_download")
+
+
+def remove_cache(cache_dir: str):
+    """Remove a model cache: unlink a symlink, or delete a downloaded copy."""
+    if os.path.islink(cache_dir):
+        os.remove(cache_dir)
+    elif os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
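+
+
+# Note on the mechanism (inferred from the helpers above): xinference resolves
+# a model through the directory that get_cache_dir() reproduces under
+# XINFERENCE_CACHE_DIR.  When a spec carries a local model_uri, the per-type
+# cache() functions symlink that directory to the local path instead of
+# downloading anything, and the meta file from get_meta_path() marks the cache
+# as a valid download.  Everything this page does boils down to managing those
+# two artifacts: the symlink and the meta file.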
+
+
+def list_running_models():
+    # Relies on the module-level `client` created after the sidebar below.
+    models = client.list_models()
+    columns = [
+        "UID",
+        "type",
+        "name",
+        "ability",
+        "size",
+        "quant",
+        "max_tokens",
+    ]
+    data = []
+    for k, v in models.items():
+        item = dict(
+            UID=k,
+            type=v["model_type"],
+            name=v["model_name"],
+        )
+        if v["model_type"] == "LLM":
+            item.update(
+                ability=v["model_ability"],
+                size=str(v["model_size_in_billions"]) + "B",
+                quant=v["quantization"],
+                max_tokens=v["context_length"],
+            )
+        elif v["model_type"] == "embedding":
+            item.update(
+                max_tokens=v["max_tokens"],
+            )
+        data.append(item)
+    df = pd.DataFrame(data, columns=columns)
+    df.index += 1
+    return df
+
+
+def get_model_registrations():
+    data = {}
+    for model_type in model_types:
+        data[model_type] = {}
+        for m in client.list_model_registrations(model_type):
+            data[model_type][m["model_name"]] = {"is_builtin": m["is_builtin"]}
+            reg = client.get_model_registration(model_type, m["model_name"])
+            data[model_type][m["model_name"]]["reg"] = reg
+    return data
+
+
+with st.sidebar:
+    st.subheader("Start the XF server with the xinference or xinference-local command first, then configure the server address below.")
+    xf_url = st.text_input("Xinference endpoint", "http://127.0.0.1:9997")
+    st.divider()
+    st.markdown(
+        "### Usage\n\n"
+        "- Scenario 1: I have already downloaded a model and do not want XF to download its builtin copy again.\n\n"
+        "- Steps: select the model, fill in the local model path, and click 'Set model cache'.\n\n"
+        "- Scenario 2: I want to modify a builtin XF model without writing a registration file from scratch.\n\n"
+        "- Steps: select the model, fill in the local model path, and click 'Register as custom model'.\n\n"
+        "- Scenario 3: I accidentally set a wrong model path.\n\n"
+        "- Steps: click 'Delete model cache', or enter the correct path and click 'Set model cache' again.\n\n"
+    )
+client = get_client(xf_url)
+
+
+st.subheader("Running models:")
+st.dataframe(list_running_models())
+
+st.subheader("Configure model path:")
+regs = get_model_registrations()
+cols = st.columns([3, 4, 3, 2, 2])
+
+model_type = cols[0].selectbox("Model type:", model_types)
+
+model_names = list(regs[model_type].keys())
+model_name = cols[1].selectbox("Model name:", model_names)
+
+cur_reg = regs[model_type][model_name]["reg"]
+
+# Defaults so the meta-path lookup below also works for non-LLM types.
+model_format = ""
+model_quant = None
+
+if model_type == "LLM":
+    cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
+    cur_spec = None
+    model_formats = []
+    for spec in cur_reg["model_specs"]:
+        if spec["model_format"] not in model_formats:
+            model_formats.append(spec["model_format"])
+    # Preselect the pytorch format when it is available.
+    index = 0
+    if "pytorch" in model_formats:
+        index = model_formats.index("pytorch")
+    model_format = cols[2].selectbox("Model format:", model_formats, index)
+
+    model_sizes = []
+    for spec in cur_reg["model_specs"]:
+        if (spec["model_format"] == model_format
+                and spec["model_size_in_billions"] not in model_sizes):
+            model_sizes.append(spec["model_size_in_billions"])
+    model_size = cols[3].selectbox("Model size:", model_sizes, format_func=lambda x: f"{x}B")
+
+    model_quants = []
+    for spec in cur_reg["model_specs"]:
+        if (spec["model_format"] == model_format
+                and spec["model_size_in_billions"] == model_size):
+            model_quants = spec["quantizations"]
+    model_quant = cols[4].selectbox("Quantization:", model_quants)
+    if model_quant == "none":
+        model_quant = None
+
+    for i, spec in enumerate(cur_reg["model_specs"]):
+        if (spec["model_format"] == model_format
+                and spec["model_size_in_billions"] == model_size):
+            cur_spec = cur_family.model_specs[i]
+            break
+    cache_dir = get_cache_dir(model_type, model_name, model_format, model_size)
+elif model_type == "embedding":
+    cur_spec = xf_embedding.core.EmbeddingModelSpec.parse_obj(cur_reg)
+    cache_dir = get_cache_dir(model_type, model_name)
+elif model_type == "image":
+    cur_spec = xf_image.core.ImageModelFamilyV1.parse_obj(cur_reg)
+    cache_dir = get_cache_dir(model_type, model_name)
+elif model_type == "rerank":
+    cur_spec = xf_rerank.core.RerankModelSpec.parse_obj(cur_reg)
+    cache_dir = get_cache_dir(model_type, model_name)
+
+meta_file = get_meta_path(
+    model_type=model_type,
+    model_format=model_format,
+    cache_dir=cache_dir,
+    model_quant=model_quant)
+
+if os.path.isdir(cache_dir):
+    try:
+        with open(meta_file, encoding="utf-8") as fp:
+            meta = json.load(fp)
+        revision = meta.get("revision", meta.get("model_revision")) or ""
+    except Exception:
+        revision = ""
+    if cur_spec.model_revision and cur_spec.model_revision != revision:
+        revision += " (differs from the XF builtin revision)"
+    else:
+        revision += " (as expected)"
+
+    if os.path.islink(cache_dir):
+        origin = os.readlink(cache_dir)
+        # os.readlink may return a Windows extended path; drop the "\\?\" prefix.
+        if origin.startswith("\\\\?\\"):
+            origin = origin[4:]
+    else:
+        origin = cache_dir
+    text = (f"The model is cached.\n\n"
+            f"Cache path: {cache_dir}\n\n"
+            f"Origin path: {origin}\n\n"
+            f"Revision: {revision}"
+            )
+else:
+    text = "The model is not cached yet."
+
+st.divider()
+st.markdown(text)
+st.divider()
+
+cols = st.columns([3, 1])
+local_path = cols[0].text_input("Absolute path to the local model:")
+
+if cols[1].button("Set model cache"):
+    if os.path.isabs(local_path) and os.path.isdir(local_path):
+        # Point the spec at the local files, drop any stale cache, then let
+        # xinference's own cache() create the symlink.
+        cur_spec.model_uri = local_path
+        remove_cache(cache_dir)
+        if model_type == "LLM":
+            cache_methods[model_type](cur_family, cur_spec, model_quant)
+        else:
+            cache_methods[model_type](cur_spec)
+        if cur_spec.model_revision:
+            # Write the meta file for both hubs so the cache passes the
+            # validity check whichever hub xinference is configured for.
+            for hub in ["huggingface", "modelscope"]:
+                meta_file = get_meta_path(
+                    model_type=model_type,
+                    model_format=model_format,
+                    model_hub=hub,
+                    cache_dir=cache_dir,
+                    model_quant=model_quant)
+                with open(meta_file, "w", encoding="utf-8") as fp:
+                    json.dump({"revision": cur_spec.model_revision}, fp)
+        st.rerun()
+    else:
+        st.error("Please enter an existing absolute path.")
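+
+# The deletion below removes only the cache entry (symlink or downloaded
+# copy); the original files behind a local path are never touched.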
+if cols[1].button("Delete model cache"):
+    remove_cache(cache_dir)
+
+if cols[1].button("Register as custom model"):
+    if os.path.isabs(local_path) and os.path.isdir(local_path):
+        # Reuse the builtin definition, point it at the local files, and give
+        # it a distinct name so it cannot clash with the builtin model.
+        cur_spec.model_uri = local_path
+        cur_spec.model_revision = None
+        if model_type == "LLM":
+            cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
+            cur_family.model_family = "other"
+            model_definition = cur_family.json(indent=2, ensure_ascii=False)
+        else:
+            cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
+            model_definition = cur_spec.json(indent=2, ensure_ascii=False)
+        client.register_model(
+            model_type=model_type,
+            model=model_definition,
+            persist=True,
+        )
+        st.rerun()
+    else:
+        st.error("Please enter an existing absolute path.")
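
A quick way to sanity-check a registration without the UI is a short client
script. This is a sketch only: it assumes a local server on the default port
9997 and a model registered through the page above.

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    for model_type in ("LLM", "embedding", "image", "rerank"):
        for reg in client.list_model_registrations(model_type):
            if not reg["is_builtin"]:
                # Custom registrations carry the "-custom" suffix added above.
                print(model_type, reg["model_name"])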