Mirror of https://github.com/RYDE-WORK/MiniCPM.git (synced 2026-01-19 21:03:39 +08:00)

add: vllm

This commit is contained in:
parent 213716ff0a
commit 79243048a6
27 README-en.md
@@ -57,8 +57,15 @@ Experience models with larger scale at [Luca](https://luca.cn/).

<p id="3"></p>

# Benchmark

| HuggingFace | ModelScope | WiseModel |
|-------------|------------|-----------|
|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|

## Multi-modal
||||
@@ -149,6 +156,24 @@ Launch gradio-based demo using the following command:

```shell
python demo/gradio_based_demo.py
```

#### Inference with vLLM (Recommended!)

* Install the vLLM build that supports MiniCPM
  - vLLM 0.2.2 is adapted to MiniCPM in `inference/vllm`; more vLLM versions will be supported in the future.

```shell
pip install inference/vllm
```
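To confirm that the bundled 0.2.2 build is the one in use, a quick check can help; this is a sketch assuming the patched package still installs under the name `vllm`, which this diff does not state explicitly:

```python
# Hypothetical sanity check: verify the locally patched vLLM is
# importable and reports the expected version.
import vllm

print(vllm.__version__)  # expect "0.2.2" per the note above
```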
* Convert a Hugging Face Transformers repo into a vLLM-MiniCPM repo, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are local paths.

```shell
python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
```
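After conversion, `<vllmcpm_repo_path>` should contain the config, weights, and tokenizer files written by `inference/convert_hf_to_vllmcpm.py` (shown in full below). A minimal sketch to verify the layout, based only on those file names:

```python
# Minimal layout check for a converted repo, based on the files that
# convert_hf_to_vllmcpm.py writes (config, weights, tokenizer files).
import os
import sys

repo = sys.argv[1]  # <vllmcpm_repo_path>
expected = [
    "config.json", "pytorch_model.pt",
    "tokenizer.json", "tokenizer.model",
    "special_tokens_map.json", "tokenizer_config.json",
]
missing = [f for f in expected if not os.path.exists(os.path.join(repo, f))]
print("OK" if not missing else f"missing: {missing}")
```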
* Examples

```shell
cd inference/vllm/examples/infer_cpm
python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
```
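This commit does not reproduce `prompts/prompt_final.txt`. If it is a plain-text prompt file (an assumption, not confirmed by this diff), a hypothetical stand-in for trying a custom prompt could be created like so:

```python
# Hypothetical example only: write a plain-text prompt file to try the
# example script with a custom prompt. The real prompt_final.txt format
# is not shown in this commit.
with open("prompts/my_prompt.txt", "w", encoding="utf-8") as f:
    f.write("Write a short introduction to the MiniCPM language model.\n")
```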
## LICENSE
29 README.md
@@ -48,10 +48,16 @@ XXXXXX

<p id="2"></p>

# Model Download

| HuggingFace | ModelScope | WiseModel |
|-------------|------------|-----------|
|[sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)|[sft-bf16](https://modelscope.cn/models/OpenBMB/miniCPM-bf16)|[sft-bf16](https://wisemodel.cn/models/OpenBMB/miniCPM-bf16)|
|[sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|[sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32)|[sft-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|
|[dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16)|[dpo-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16/summary)|[dpo-bf16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-bf16)|
|[dpo-fp16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16)|[dpo-fp16](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16/)|[dpo-fp16](https://wisemodel.cn/models/OpenBMB/MiniCPM-2B-dpo-fp16)|
|[dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32)|[dpo-fp32](https://wisemodel.cn/models/OpenBMB/miniCPM-dpo-fp32)|

[HuggingFace Repo]()
[ModelScope Repo]()
[XX Repo]()

<p id="3"></p>
@@ -184,7 +190,24 @@ todo

```shell
python demo/gradio_based_demo.py
```

#### Inference and Deployment with vLLM

* Install the vLLM build that supports MiniCPM
  - We currently support vLLM 0.2.2; more versions will be supported in the future.

```shell
pip install inference/vllm
```

* Convert a Hugging Face Transformers repo into the format supported by vLLM-MiniCPM, where `<hf_repo_path>` and `<vllmcpm_repo_path>` are both local paths.

```shell
python inference/convert_hf_to_vllmcpm.py --load <hf_repo_path> --save <vllmcpm_repo_path>
```

* Test examples

```shell
cd inference/vllm/examples/infer_cpm
python inference.py --model_path <vllmcpm_repo_path> --prompt_path prompts/prompt_final.txt
```

<p id="6"></p>
91 inference/convert_hf_to_vllmcpm.py Normal file
@@ -0,0 +1,91 @@
import argparse
import json
import os
import shutil
from collections import OrderedDict

import torch
from tqdm import tqdm


def convert_model(config, ckpt):
    # Build the CPM-Dragonfly config expected by vLLM-MiniCPM from the
    # Hugging Face config.json fields.
    config_bmt = OrderedDict(
        {
            "_dtype": "bf16",
            "activate_fn": "silu",
            "architectures": [
                "CPMDragonflyForCausalLM"
            ],
            "model_type": "cpm_dragonfly",
            "base": 10000,
            "dim_ff": config['intermediate_size'],
            "dim_head": config['hidden_size'] // config['num_attention_heads'],
            "dim_model": config['hidden_size'],
            "dim_model_base": 256,
            "dropout_p": 0.0,
            "eps": config['rms_norm_eps'],
            "init_std": config['initializer_range'],
            "num_heads": config['num_attention_heads'],
            "num_kv_heads": config['num_key_value_heads'],
            "num_layers": config['num_hidden_layers'],
            "orig_max_length": 4096,
            "pose_prob": 0.0,
            "pose_scaling_factor": 1.0,
            "qk_norm": False,
            "rope_scaling_factor": 1,
            "rope_scaling_type": "",
            "scale": True,
            "scale_depth": config['scale_depth'],
            "scale_emb": config['scale_emb'],
            "tie_lm_head": True,
            "tp": 0,
            "transformers_version": "4.35.0",
            "vocab_size": config['vocab_size']
        }
    )

    # Rename the Hugging Face weight tensors to the BMT-style layout:
    # two global tensors (token embeddings, final layernorm) plus nine
    # tensors per transformer layer.
    model_bmt = OrderedDict()
    model_bmt["input_embedding.weight"] = ckpt['model.embed_tokens.weight'].contiguous()
    model_bmt["encoder.output_layernorm.weight"] = ckpt['model.norm.weight'].contiguous()
    for lnum in tqdm(range(config_bmt['num_layers'])):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"
        model_bmt[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = ckpt[f"{hf_pfx}.input_layernorm.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = ckpt[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = ckpt[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = ckpt[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = ckpt[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = ckpt[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = ckpt[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = ckpt[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
        model_bmt[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = ckpt[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()

    return config_bmt, model_bmt


def load_model_ckpt(args):
    # Read the Hugging Face config and checkpoint from the source repo.
    with open(os.path.join(args.load, "config.json"), 'r') as fin:
        config = json.load(fin)
    ckpt = torch.load(os.path.join(args.load, "pytorch_model.bin"))

    os.makedirs(f"{args.save}", exist_ok=True)

    # model and config
    config_bmt, model_bmt = convert_model(config, ckpt)
    with open(os.path.join(args.save, "config.json"), 'w') as fout:
        json.dump(config_bmt, fout, indent=4)
    torch.save(model_bmt, f"{args.save}/pytorch_model.pt")

    # tokenizer
    shutil.copyfile(f"{args.load}/tokenizer.json", f"{args.save}/tokenizer.json")
    shutil.copyfile(f"{args.load}/tokenizer.model", f"{args.save}/tokenizer.model")
    shutil.copyfile(f"{args.load}/special_tokens_map.json", f"{args.save}/special_tokens_map.json")
    shutil.copyfile(f"{args.load}/tokenizer_config.json", f"{args.save}/tokenizer_config.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--load", type=str, default="")
    parser.add_argument("--save", type=str, default="")
    args = parser.parse_args()

    load_model_ckpt(args)
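The mapping above implies a simple invariant for a converted checkpoint: 2 global tensors plus 9 per layer. A minimal sanity-check sketch; the file names match what the converter writes, while the script itself is illustrative:

```python
# Sanity-check a converted checkpoint against the mapping in
# convert_hf_to_vllmcpm.py: 2 global tensors + 9 tensors per layer.
import json
import os
import sys

import torch

repo = sys.argv[1]  # <vllmcpm_repo_path>
with open(os.path.join(repo, "config.json")) as f:
    cfg = json.load(f)
state = torch.load(os.path.join(repo, "pytorch_model.pt"), map_location="cpu")

expected = 2 + 9 * cfg["num_layers"]
assert len(state) == expected, f"got {len(state)} tensors, expected {expected}"
print(f"OK: {expected} tensors for {cfg['num_layers']} layers")
```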