diff --git a/README-en.md b/README-en.md
index 9b44d7d..d3477e6 100644
--- a/README-en.md
+++ b/README-en.md
@@ -57,8 +57,15 @@ Experience models with larger scale at [Luca](https://luca.cn/).
-# Benchmark
+# Download
+
+ | HuggingFace | ModelScope | WiseModel |
+ |-------------|------------|-----------|
+ |[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
+ |[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+ |[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
+ |[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
+ |[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+
+# Benchmark
## Multi-modal
@@ -149,6 +156,24 @@ Launch gradio-based demo using the following command:
python demo/gradio_based_demo.py
```
+#### Inference with vLLM (Recommended!)
+
+* Install the vLLM build that supports MiniCPM
+  - vLLM 0.2.2 adapted to MiniCPM is provided in `inference/vllm`; more vLLM versions will be supported in the future.
+```shell
+pip install inference/vllm
+```
+
+* Convert the Huggingface Transformers checkpoint into the vLLM-MiniCPM format, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are local paths.
+```shell
+python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
+```
+
+* Examples
+```shell
+cd inference/vllm/examples/infer_cpm
+python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
+```
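+
+For programmatic use, the same engine can be driven from Python. Below is a minimal sketch, assuming the bundled vLLM 0.2.2 fork keeps vLLM's standard `LLM`/`SamplingParams` API and that `<vllmcpm_repo_path>` is the converted checkpoint from the previous step:
+
+```python
+from vllm import LLM, SamplingParams
+
+# point vLLM at the converted vLLM-MiniCPM checkpoint directory
+llm = LLM(model="<vllmcpm_repo_path>")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)
+
+outputs = llm.generate(["Introduce MiniCPM in one sentence."], sampling_params)
+for output in outputs:
+    print(output.outputs[0].text)
+```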
##
## LICENSE
diff --git a/README.md b/README.md
index 992395a..a43893d 100644
--- a/README.md
+++ b/README.md
@@ -48,10 +48,16 @@ XXXXXX
# 模型下载
+
+ | HuggingFace | ModelScope | WiseModel |
+ |-------------|------------|-----------|
+ |[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
+ |[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+ |[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
+ |[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
+ |[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
+
- [HuggingFace仓库]()
- [ModelScope仓库]()
- [XX仓库]()
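+
+For a quick smoke test after downloading, the checkpoints can be loaded with Hugging Face Transformers. A minimal sketch is shown below, assuming the `openbmb/MiniCPM-2B-sft-bf16` repo from the table and that the model's custom code is trusted (`trust_remote_code=True`):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+path = "openbmb/MiniCPM-2B-sft-bf16"  # any repo id from the table above
+tokenizer = AutoTokenizer.from_pretrained(path)
+model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, trust_remote_code=True)
+
+inputs = tokenizer("Introduce yourself.", return_tensors="pt")
+outputs = model.generate(**inputs, max_new_tokens=64)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```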
@@ -184,7 +190,24 @@ todo
```shell
python demo/gradio_based_demo.py
```
+#### Inference with vLLM
+* Install the vLLM build that supports MiniCPM
+  - vLLM 0.2.2 is currently supported; more versions will be supported in the future.
+```shell
+pip install inference/vllm
+```
+
+* Convert the Huggingface Transformers repo into the format supported by vLLM-MiniCPM, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are local paths.
+```shell
+python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
+```
+
+* Examples
+```shell
+cd inference/vllm/examples/infer_cpm
+python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
+```
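+
+As a quick sanity check on the conversion step, the artifacts written by `convert_hf_to_vllmcpm.py` can be inspected directly. A minimal sketch, assuming `<vllmcpm_repo_path>` is the `--save` directory used above:
+
+```python
+import json
+import os
+import torch
+
+save_dir = "<vllmcpm_repo_path>"  # the --save directory from the conversion step
+
+# the converter writes config.json plus a single pytorch_model.pt
+with open(os.path.join(save_dir, "config.json")) as fin:
+    cfg = json.load(fin)
+print(cfg["num_layers"], cfg["dim_model"], cfg["vocab_size"])
+
+state = torch.load(os.path.join(save_dir, "pytorch_model.pt"), map_location="cpu")
+print(len(state), "tensors loaded;", "input_embedding.weight" in state)
+```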
diff --git a/inference/convert_hf_to_vllmcpm.py b/inference/convert_hf_to_vllmcpm.py
new file mode 100644
index 0000000..76f541a
--- /dev/null
+++ b/inference/convert_hf_to_vllmcpm.py
@@ -0,0 +1,91 @@
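+"""Convert a Hugging Face format MiniCPM checkpoint (config, weights, tokenizer) into the
+cpm_dragonfly layout consumed by the vLLM-MiniCPM fork under inference/vllm."""
+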
+import argparse
+import json
+import os
+import shutil
+from tqdm import tqdm
+from collections import OrderedDict
+import torch
+
+def convert_model(config, ckpt):
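+    """Map an HF LLaMA-style config/state_dict onto the cpm_dragonfly config and weight names."""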
+ # config
+ config_bmt = OrderedDict(
+ {
+ "_dtype": "bf16",
+ "activate_fn": "silu",
+ "architectures": [
+ "CPMDragonflyForCausalLM"
+ ],
+ "model_type": "cpm_dragonfly",
+ "base": 10000,
+ "dim_ff": config['intermediate_size'],
+ "dim_head": config['hidden_size'] // config['num_attention_heads'],
+ "dim_model": config['hidden_size'],
+ "dim_model_base": 256,
+ "dropout_p": 0.0,
+ "eps": config['rms_norm_eps'],
+ "init_std": config['initializer_range'],
+ "num_heads": config['num_attention_heads'],
+ "num_kv_heads": config['num_key_value_heads'],
+ "num_layers": config['num_hidden_layers'],
+ "orig_max_length": 4096,
+ "pose_prob": 0.0,
+ "pose_scaling_factor": 1.0,
+ "qk_norm": False,
+ "rope_scaling_factor": 1,
+ "rope_scaling_type": "",
+ "scale": True,
+ "scale_depth": config['scale_depth'],
+ "scale_emb": config['scale_emb'],
+ "tie_lm_head": True,
+ "tp": 0,
+ "transformers_version": "4.35.0",
+ "vocab_size": config['vocab_size']
+ }
+ )
+
+
+ model_bmt = OrderedDict()
+ model_bmt["input_embedding.weight"] = ckpt['model.embed_tokens.weight'].contiguous()
+ model_bmt["encoder.output_layernorm.weight"] = ckpt['model.norm.weight'].contiguous()
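+    # per-layer mapping: HF "model.layers.N.*" weights -> "encoder.layers.N.*" weights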
+ for lnum in tqdm(range(config_bmt['num_layers'])):
+ hf_pfx = f"model.layers.{lnum}"
+ bmt_pfx = f"encoder.layers.{lnum}"
+ model_bmt[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = ckpt[f"{hf_pfx}.input_layernorm.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = ckpt[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = ckpt[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = ckpt[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = ckpt[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = ckpt[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = ckpt[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = ckpt[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
+ model_bmt[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = ckpt[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
+
+
+ return config_bmt, model_bmt
+
+def load_model_ckpt(args):
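+    """Read the HF checkpoint from --load, convert it, and write config/weights/tokenizer to --save."""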
+ with open(os.path.join(args.load, "config.json"), 'r') as fin:
+ config = json.load(fin)
+    ckpt = torch.load(os.path.join(args.load, "pytorch_model.bin"), map_location="cpu")  # keep tensors on CPU during conversion
+
+ os.makedirs(f"{args.save}", exist_ok=True)
+
+ # model and config
+    cpm_config, cpm_ckpt = convert_model(config, ckpt)
+    with open(os.path.join(args.save, "config.json"), 'w') as fout:
+        json.dump(cpm_config, fout, indent=4)
+    torch.save(cpm_ckpt, f"{args.save}/pytorch_model.pt")
+
+ # tokenizer
+ shutil.copyfile(f"{args.load}/tokenizer.json", f"{args.save}/tokenizer.json")
+ shutil.copyfile(f"{args.load}/tokenizer.model", f"{args.save}/tokenizer.model")
+ shutil.copyfile(f"{args.load}/special_tokens_map.json", f"{args.save}/special_tokens_map.json")
+ shutil.copyfile(f"{args.load}/tokenizer_config.json", f"{args.save}/tokenizer_config.json")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+    parser.add_argument("--load", type=str, default="", help="path to the Hugging Face format MiniCPM checkpoint directory")
+    parser.add_argument("--save", type=str, default="", help="output directory for the converted vLLM-MiniCPM checkpoint")
+ args = parser.parse_args()
+
+ load_model_ckpt(args)
\ No newline at end of file