diff --git a/README-en.md b/README-en.md
index 9b44d7d..d3477e6 100644
--- a/README-en.md
+++ b/README-en.md
@@ -57,8 +57,15 @@ Experience models with larger scale at [Luca](https://luca.cn/).

-# Benchmark
+# Benchmark
+| HuggingFace | ModelScope | WiseModel |
+|-------------|------------|-----------|
+|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
+|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
+|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
+|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
 
 ## Multi-modal
@@ -149,6 +156,24 @@ Launch gradio-based demo using the following command:
 ```shell
 python demo/gradio_based_demo.py
 ```
+#### Inference with vLLM (Recommended!)
+
+* Install vLLM with MiniCPM support
+  - vLLM 0.2.2 is adapted to MiniCPM in `inference/vllm`; more vLLM versions will be supported in the future.
+```shell
+pip install inference/vllm
+```
+
+* Convert the Huggingface Transformers repo to a vLLM-MiniCPM repo, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are local paths.
+```shell
+python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
+```
+
+* Examples
+```shell
+cd inference/vllm/examples/infer_cpm
+python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
+```
+
 ##
 
 ## LICENSE
diff --git a/README.md b/README.md
index 992395a..a43893d 100644
--- a/README.md
+++ b/README.md
@@ -48,10 +48,16 @@ XXXXXX

 # Model Download
+
+| HuggingFace | ModelScope | WiseModel |
+|-------------|------------|-----------|
+|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
+|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
+|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
+|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+
-- [HuggingFace Repo]()
-- [ModelScope Repo]()
-- [XX Repo]()

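Not part of the diff above: a minimal sketch of pulling one of the checkpoints listed in the download table with Hugging Face Transformers. The repo id comes from the table; the prompt, dtype, and generation settings are illustrative, and `trust_remote_code=True` is assumed because MiniCPM ships custom modeling code.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any repo id from the table works here; dpo-bf16 is used as an example.
model_id = "openbmb/MiniCPM-2B-dpo-bf16"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# MiniCPM uses custom modeling code on the Hub, so trust_remote_code=True is needed.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("Which mountain is the highest in Shandong?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```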
@@ -184,7 +190,24 @@ todo
 ```shell
 python demo/gradio_based_demo.py
 ```
+#### Inference Deployment with vLLM
+* Install vLLM with MiniCPM support
+  - vLLM 0.2.2 is currently supported; more versions will be supported in the future
+```shell
+pip install inference/vllm
+```
+
+* Convert the Huggingface Transformers repo into the format supported by vLLM-MiniCPM, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are both local paths
+```shell
+python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
+```
+
+* Test example
+```shell
+cd inference/vllm/examples/infer_cpm
+python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
+```

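The README changes above drive inference through `inference/vllm/examples/infer_cpm/inference.py`. If the bundled vLLM 0.2.2 fork keeps upstream vLLM's Python API, an equivalent call might look like the sketch below; the converted-checkpoint path and sampling settings are assumptions, not part of this PR.

```python
from vllm import LLM, SamplingParams

# Hypothetical placeholder: the directory produced by convert_hf_to_vllmcpm.py (--save).
model_path = "/path/to/vllmcpm_repo"

llm = LLM(model=model_path)
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=128)

outputs = llm.generate(["Which mountain is the highest in Shandong?"], sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```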
diff --git a/inference/convert_hf_to_vllmcpm.py b/inference/convert_hf_to_vllmcpm.py
new file mode 100644
index 0000000..76f541a
--- /dev/null
+++ b/inference/convert_hf_to_vllmcpm.py
@@ -0,0 +1,91 @@
+import argparse
+import json
+import os
+import shutil
+from tqdm import tqdm
+from collections import OrderedDict
+import torch
+
+def convert_model(config, ckpt):
+    # config: map the HF config fields onto the CPM-Dragonfly config expected by vLLM-MiniCPM
+    config_bmt = OrderedDict(
+        {
+            "_dtype": "bf16",
+            "activate_fn": "silu",
+            "architectures": [
+                "CPMDragonflyForCausalLM"
+            ],
+            "model_type": "cpm_dragonfly",
+            "base": 10000,
+            "dim_ff": config['intermediate_size'],
+            "dim_head": config['hidden_size'] // config['num_attention_heads'],
+            "dim_model": config['hidden_size'],
+            "dim_model_base": 256,
+            "dropout_p": 0.0,
+            "eps": config['rms_norm_eps'],
+            "init_std": config['initializer_range'],
+            "num_heads": config['num_attention_heads'],
+            "num_kv_heads": config['num_key_value_heads'],
+            "num_layers": config['num_hidden_layers'],
+            "orig_max_length": 4096,
+            "pose_prob": 0.0,
+            "pose_scaling_factor": 1.0,
+            "qk_norm": False,
+            "rope_scaling_factor": 1,
+            "rope_scaling_type": "",
+            "scale": True,
+            "scale_depth": config['scale_depth'],
+            "scale_emb": config['scale_emb'],
+            "tie_lm_head": True,
+            "tp": 0,
+            "transformers_version": "4.35.0",
+            "vocab_size": config['vocab_size']
+        }
+    )
+
+    # weights: rename the HF tensors to their CPM-Dragonfly counterparts
+    model_bmt = OrderedDict()
+    model_bmt["input_embedding.weight"] = ckpt['model.embed_tokens.weight'].contiguous()
+    model_bmt["encoder.output_layernorm.weight"] = ckpt['model.norm.weight'].contiguous()
+    for lnum in tqdm(range(config_bmt['num_layers'])):
+        hf_pfx = f"model.layers.{lnum}"
+        bmt_pfx = f"encoder.layers.{lnum}"
+        model_bmt[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = ckpt[f"{hf_pfx}.input_layernorm.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = ckpt[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = ckpt[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = ckpt[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = ckpt[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = ckpt[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = ckpt[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = ckpt[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
+        model_bmt[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = ckpt[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
+
+
+    return config_bmt, model_bmt
+
+def load_model_ckpt(args):
+    with open(os.path.join(args.load, "config.json"), 'r') as fin:
+        config = json.load(fin)
+    ckpt = torch.load(os.path.join(args.load, "pytorch_model.bin"))
+
+    os.makedirs(f"{args.save}", exist_ok=True)
+
+    # convert and save model weights and config
+    cpm_config, cpm_ckpt = convert_model(config, ckpt)
+    with open(os.path.join(args.save, "config.json"), 'w') as fout:
+        json.dump(cpm_config, fout, indent=4)
+    torch.save(cpm_ckpt, f"{args.save}/pytorch_model.pt")
+
+    # tokenizer files are copied over unchanged
+    shutil.copyfile(f"{args.load}/tokenizer.json", f"{args.save}/tokenizer.json")
+    shutil.copyfile(f"{args.load}/tokenizer.model", f"{args.save}/tokenizer.model")
+    shutil.copyfile(f"{args.load}/special_tokens_map.json", f"{args.save}/special_tokens_map.json")
+    shutil.copyfile(f"{args.load}/tokenizer_config.json", f"{args.save}/tokenizer_config.json")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--load", type=str, default="")
+    parser.add_argument("--save", type=str, default="")
+    args = parser.parse_args()
+
+    load_model_ckpt(args)
\ No newline at end of file