add: vllm

This commit is contained in:
huangyuxiang03 2024-02-01 10:48:37 +08:00
parent 213716ff0a
commit 79243048a6
3 changed files with 143 additions and 4 deletions

View File

@ -57,8 +57,15 @@ Experience models with larger scale at [Luca](https://luca.cn/).
<p id="3"></p>
# Benchmark
# Benchmark
| HuggingFace | ModelScope | WiseModel |
|-------------|------------|-----------|
|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)
|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)
|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)
|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)
|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)
## Multi-modal
@ -149,6 +156,24 @@ Launch gradio-based demo using the following command:
python demo/gradio_based_demo.py
```
#### Inference with vLLM (Recommended!)
* Install vLLM supporting MiniCPM
- vLLM 0.2.2 is adapted to MiniCPM in `inference/vllm`. More vLLM versions will be supported in the future
```shell
pip install inference/vllm
```
* Transfer Huggingface Transformers repo to vLLM-MiniCPM repo, where `<hf_repo_path>`, `<vllmcpm_repo_path>` are local paths.
```shell
python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
```
* Examples
```shell
cd inference/vllm/examples/infer_cpm
python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
##
## LICENSE

View File

@ -48,10 +48,16 @@ XXXXXX
<p id="2"></p>
# 模型下载
| HuggingFace | ModelScope | WiseModel |
|-------------|------------|-----------|
|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)
|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)
|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)
|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)
|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)
[HuggingFace仓库]()
[ModelScope仓库]()
[XX仓库]()
<p id="3"></p>
@ -184,7 +190,24 @@ todo
```shell
python demo/gradio_based_demo.py
```
#### vLLM推理部署
* 安装支持MiniCPM的vLLM
- 我们当前支持版本为0.2.2的vLLM未来将会支持更多版本
```shell
pip install inference/vllm
```
* 将Huggingface Transformers仓库转为vLLM-MiniCPM支持的格式其中`<hf_repo_path>`, `<vllmcpm_repo_path>`均为本地路径
```shell
python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
```
* 测试样例
```shell
cd inference/vllm/examples/infer_cpm
python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
```
<p id="6"></p>

View File

@ -0,0 +1,91 @@
import argparse
import json
import os
import shutil
from tqdm import tqdm
from collections import OrderedDict
import torch
def convert_model(config, ckpt):
# config
config_bmt = OrderedDict(
{
"_dtype": "bf16",
"activate_fn": "silu",
"architectures": [
"CPMDragonflyForCausalLM"
],
"model_type": "cpm_dragonfly",
"base": 10000,
"dim_ff": config['intermediate_size'],
"dim_head": config['hidden_size'] // config['num_attention_heads'],
"dim_model": config['hidden_size'],
"dim_model_base": 256,
"dropout_p": 0.0,
"eps": config['rms_norm_eps'],
"init_std": config['initializer_range'],
"num_heads": config['num_attention_heads'],
"num_kv_heads": config['num_key_value_heads'],
"num_layers": config['num_hidden_layers'],
"orig_max_length": 4096,
"pose_prob": 0.0,
"pose_scaling_factor": 1.0,
"qk_norm": False,
"rope_scaling_factor": 1,
"rope_scaling_type": "",
"scale": True,
"scale_depth": config['scale_depth'],
"scale_emb": config['scale_emb'],
"tie_lm_head": True,
"tp": 0,
"transformers_version": "4.35.0",
"vocab_size": config['vocab_size']
}
)
model_bmt = OrderedDict()
model_bmt["input_embedding.weight"] = ckpt['model.embed_tokens.weight'].contiguous()
model_bmt["encoder.output_layernorm.weight"] = ckpt['model.norm.weight'].contiguous()
for lnum in tqdm(range(config_bmt['num_layers'])):
hf_pfx = f"model.layers.{lnum}"
bmt_pfx = f"encoder.layers.{lnum}"
model_bmt[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = ckpt[f"{hf_pfx}.input_layernorm.weight"].contiguous()
model_bmt[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = ckpt[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = ckpt[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = ckpt[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = ckpt[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = ckpt[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = ckpt[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = ckpt[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
model_bmt[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = ckpt[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
return config_bmt, model_bmt
def load_model_ckpt(args):
with open(os.path.join(args.load, "config.json"), 'r') as fin:
config = json.load(fin)
ckpt = torch.load(os.path.join(args.load, "pytorch_model.bin"))
os.makedirs(f"{args.save}", exist_ok=True)
# model and config
hf_config, hf_ckpt = convert_model(config, ckpt)
with open(os.path.join(args.save, "config.json"), 'w') as fout:
json.dump(hf_config, fout, indent=4)
torch.save(hf_ckpt, f"{args.save}/pytorch_model.pt")
# tokenizer
shutil.copyfile(f"{args.load}/tokenizer.json", f"{args.save}/tokenizer.json")
shutil.copyfile(f"{args.load}/tokenizer.model", f"{args.save}/tokenizer.model")
shutil.copyfile(f"{args.load}/special_tokens_map.json", f"{args.save}/special_tokens_map.json")
shutil.copyfile(f"{args.load}/tokenizer_config.json", f"{args.save}/tokenizer_config.json")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--load", type=str, default="")
parser.add_argument("--save", type=str, default="")
args = parser.parse_args()
load_model_ckpt(args)