mirror of https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00

mlx inference

This commit is contained in:
parent a1013b1ad2
commit 9e1438682e

.gitignore (vendored): 4 additions

@@ -2,3 +2,7 @@
 *.pyc
 finetune/output/*
 wip.*
+.idea
+venv
+.venv
+.env

@@ -488,6 +488,12 @@ python demo/vllm_based_demo.py --model_path <vllmcpm_repo_path>
 python demo/hf_based_demo.py --model_path <hf_repo_path>
 ```
+#### Use the following command to launch inference with the Mac MLX acceleration framework
+
+You need to install the `mlx_lm` library and download the corresponding converted model weights, [MiniCPM-2B-sft-bf16-llama-format-mlx](https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx), then run the following command:
+```shell
+python -m mlx_lm.generate --model mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx --prompt "hello, tell me a joke." --trust-remote-code
+```

 <p id="6"></p>
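
For reference, the same converted weights can also be driven from Python instead of the CLI, using the same `load`/`generate` calls that the `demo/mlx_based_demo.py` file added below relies on. A minimal sketch (single prompt, no chat template):

```python
# Minimal sketch: Python counterpart of the mlx_lm.generate CLI call above,
# using the load/generate API that demo/mlx_based_demo.py (added in this commit) also uses.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx")
response = generate(model, tokenizer, prompt="hello, tell me a joke.", verbose=True)
print(response)
```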

demo/mlx_based_demo.py (new file): 42 additions

@@ -0,0 +1,42 @@
"""
Fast MiniCPM inference with MLX

If you are running inference on a Mac device, you can use MLX directly.
Since MiniCPM does not yet support conversion to the MLX format, you can download the model
already converted by the MLX community: [MiniCPM-2B-sft-bf16-llama-format-mlx](https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx).

Install the corresponding dependency:

```bash
pip install mlx-lm
```

A simple command-line way to run MiniCPM-2 inference on a Mac device:

```shell
python -m mlx_lm.generate --model mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx --prompt "hello, tell me a joke." --trust-remote-code
```
"""

from mlx_lm import load, generate
from jinja2 import Template


def chat_with_model():
    # Load the MLX-community conversion of MiniCPM-2B-sft from the Hugging Face Hub.
    model, tokenizer = load("mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx")
    print("Model loaded. Start chatting! (Type 'quit' to stop)")

    messages = []
    # MiniCPM chat format: each user turn is wrapped as <用户>...<AI>; model turns are appended verbatim.
    chat_template = Template(
        "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}")

    while True:
        user_input = input("You: ")
        if user_input.lower() == 'quit':
            break
        messages.append({"role": "user", "content": user_input})
        # Render the whole conversation into a single prompt string and generate the next reply.
        response = generate(model, tokenizer, prompt=chat_template.render(messages=messages), verbose=True)
        print("Model:", response)
        messages.append({"role": "ai", "content": response})


chat_with_model()
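
Since the README addition above only documents the `mlx_lm.generate` CLI, a quick note on the new script itself: it takes no arguments (the model repo id is hard-coded in `load(...)`), so it should start an interactive chat loop when launched like the other demos:

```shell
python demo/mlx_based_demo.py
```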

requirements.txt (new file): 11 additions

@@ -0,0 +1,11 @@
transformers>=4.38.2
torch>=2.0.0
triton>=2.2.0
httpx>=0.27.0
gradio>=4.21.0
flash_attn>=2.4.1
accelerate>=0.28.0
sentence_transformers>=2.6.0
sse_starlette>=2.0.0
tiktoken>=0.6.0
mlx_lm>=0.5.0
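
The pinned dependencies can be installed in one step. Note that some entries target specific platforms (for example `flash_attn` and `triton` expect a CUDA machine, while `mlx_lm` targets Apple silicon), so not every package will install everywhere:

```shell
pip install -r requirements.txt
```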