JoeJoe1313
commited on
Commit
·
0361bfb
1
Parent(s):
5e0609b
posts
Browse files- src/posts/2025-02-12-fine-tuning-lora-mlx/images/lora.jpg +3 -0
- src/posts/2025-02-12-fine-tuning-lora-mlx/index.qmd +234 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/input.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_1.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_2.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_3.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_4.png +3 -0
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/index.qmd +341 -0
src/posts/2025-02-12-fine-tuning-lora-mlx/images/lora.jpg
ADDED
|
Git LFS Details
|
src/posts/2025-02-12-fine-tuning-lora-mlx/index.qmd
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: "Fine-Tuning LLMs with LoRA and MLX-LM"
|
| 3 |
+
author: "Joana Levtcheva"
|
| 4 |
+
date: "2025-02-12"
|
| 5 |
+
categories: [Machine Learning, mlx, llm]
|
| 6 |
+
draft: false
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
This blog post is going to be a tutorial on how to fine-tune a LLM with LoRA and the `mlx-lm` package. Medium post can be found [here](https://medium.com/@levchevajoana/fine-tuning-llms-with-lora-and-mlx-lm-c0b143642deb) and Substack [here](https://substack.com/home/post/p-157008884).
|
| 10 |
+
|
| 11 |
+
## Introduction
|
| 12 |
+
|
| 13 |
+
[MLX](https://opensource.apple.com/projects/mlx/) is an array framework tailored for efficient machine learning research on Apple silicon. Its biggest strength is that it leverages the unified memory architecture of Apple devices and offers a familiar, NumPy-like API. Apple has also developed a package for LLM text generation, fine-tuning, etc. called [MLX LM](https://github.com/ml-explore/mlx-examples/blob/main/llms/README.md).
|
| 14 |
+
|
| 15 |
+
Overall, `mlx-lm` supports many of Hugging Face format LLMs. With `mlx-lm` it is also very easy to directly load models from the Hugging Face [MLX Community](https://huggingface.co/mlx-community). This is a place for mlx model pre-converted weights that run on Apple Silicon, hosting many ready-to-use models with the framework. The framework also supports parameter-efficient fine-tuning ([PEFT](https://huggingface.co/blog/peft)) with [LoRA and QLoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora). You can find more information about LoRA in the following [paper](https://arxiv.org/abs/2106.09685).
|
| 16 |
+
|
| 17 |
+
In this tutorial, with the help of the `mlx-lm` package, we are going to load the [Mistral-7B-Instruct-v0.3–4bit](https://medium.com/r/?url=https%3A%2F%2Fhuggingface.co%2Fmlx-community%2FMistral-7B-Instruct-v0.3-4bit) model from the MLX Community space, and attempt to fine-tune it with LoRA and the dataset [win-wang/Machine_Learning_QA_Collection](https://medium.com/r/?url=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fwin-wang%2FMachine_Learning_QA_Collection). Let's begin.
|
| 18 |
+
|
| 19 |
+
## Packages and Model Loading
|
| 20 |
+
First, we have to load the needed packages.
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import matplotlib.pyplot as plt
|
| 28 |
+
import mlx.optimizers as optim
|
| 29 |
+
from mlx.utils import tree_flatten
|
| 30 |
+
from mlx_lm import generate, load
|
| 31 |
+
from mlx_lm.tuner import TrainingArgs, datasets, linear_to_lora_layers, train
|
| 32 |
+
from transformers import PreTrainedTokenizer
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Then, we should load the model and tokenizer.
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
model_path = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
| 39 |
+
model, tokenizer = load(model_path)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Let's see what would our model output when given a simple pormpt such as *"What is fine-tuning in machine learning?"*.
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
prompt = "What is fine-tuning in machine learning?"
|
| 46 |
+
messages = [{"role": "user", "content": prompt}]
|
| 47 |
+
prompt = tokenizer.apply_chat_template(
|
| 48 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 49 |
+
)
|
| 50 |
+
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
The generated output of the model is:
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
Fine-tuning in machine learning refers to the process of taking a pre-trained model, which has already been trained on a large dataset for a specific task, and adapting it to a new, related task or a different aspect of the same task.
|
| 57 |
+
|
| 58 |
+
For example, imagine you have a pre-trained model that can recognize different types of animals. You can fine-tune this model to recognize specific breeds of dogs, or even to recognize different types of flowers. The idea is that the pre-trained model has already learned some general features that are useful for the new task, and fine-tuning helps the model to learn the specific details that are important for the new task.
|
| 59 |
+
|
| 60 |
+
Fine-tuning is often used when you have a small dataset for the new task, as it allows you to leverage the knowledge the model has already gained from the large pre-training dataset. It's a common technique in deep learning, particularly for tasks like image classification, natural language processing, and speech recognition.
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Preparation for Fine-Tuning
|
| 64 |
+
|
| 65 |
+
Let's create an `adapters` directory, and the paths to the adapter configuration (in our case the LoRA configuration) and adapter files.
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
adapter_path = "adapters"
|
| 69 |
+
os.makedirs(adapter_path, exist_ok=True)
|
| 70 |
+
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
|
| 71 |
+
adapter_file_path = os.path.join(adapter_path, "adapters.safetensors")
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
We have to set our LoRA parameter configurations. This can be done in a separate `.yml` file, as shown [here](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/lora_config.yaml), but for code simplicity and the sake of just showing the process of fine-tuning with LoRA and mlx-lm, we are going to stick to this simple in-code configuration
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
lora_config = {
|
| 78 |
+
"num_layers": 8,
|
| 79 |
+
"lora_parameters": {
|
| 80 |
+
"rank": 8,
|
| 81 |
+
"scale": 20.0,
|
| 82 |
+
"dropout": 0.0,
|
| 83 |
+
},
|
| 84 |
+
}
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
which we save into the adapters directory we already created.
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
with open(adapter_config_path, "w") as f:
|
| 91 |
+
json.dump(lora_config, f, indent=4)
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
We can also set our training arguments, pointing to our adapter file, how many iterations we want to perform, and how many steps per evaluation should be done.
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
training_args = TrainingArgs(
|
| 98 |
+
adapter_file=adapter_file_path,
|
| 99 |
+
iters=200,
|
| 100 |
+
steps_per_eval=50,
|
| 101 |
+
)
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
In the LoRA framework, most of the model's original parameters remain unchanged during fine-tuning. The `model.freeze()` command is used to set these parameters to a non-trainable state so that their weights aren't updated during backpropagation. This way, only the newly introduced low-rank adaptation matrices (LoRA parameters) are optimized, reducing computational overhead and memory usage while preserving the original model's knowledge.
|
| 105 |
+
|
| 106 |
+
The `linear_to_lora_layers` function converts or wraps some of the model's linear layers into LoRA layers. Essentially, it replaces (or augments) selected linear layers with their LoRA counterparts, which include the additional low-rank matrices that will be trained. The configuration parameters (like the number of layers and specific LoRA parameters) determine which layers are modified and how the LoRA adapters are set up.
|
| 107 |
+
|
| 108 |
+
We should also verify that only a small subset of parameters are set for training, and activate training mode while preserving the frozen state of the main model parameters.
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
model.freeze()
|
| 112 |
+
linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])
|
| 113 |
+
num_train_params = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
|
| 114 |
+
print(f"Number of trainable parameters: {num_train_params}")
|
| 115 |
+
model.train()
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
We can also create a class to follow the train and validation loss metrics during the training process
|
| 119 |
+
|
| 120 |
+
```python
|
| 121 |
+
class Metrics:
|
| 122 |
+
def __init__(self) -> None:
|
| 123 |
+
self.train_losses: List[Tuple[int, float]] = []
|
| 124 |
+
self.val_losses: List[Tuple[int, float]] = []
|
| 125 |
+
|
| 126 |
+
def on_train_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
|
| 127 |
+
self.train_losses.append((info["iteration"], info["train_loss"]))
|
| 128 |
+
|
| 129 |
+
def on_val_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
|
| 130 |
+
self.val_losses.append((info["iteration"], info["val_loss"]))
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
and create an instance of this class.
|
| 134 |
+
|
| 135 |
+
```python
|
| 136 |
+
metrics = Metrics()
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## Data Loading
|
| 140 |
+
|
| 141 |
+
Here, we are creating a simplified variant of the following [function](https://github.com/ml-explore/mlx-examples/blob/ec30dc35382d87614f51fe7590f015f93a491bfd/llms/mlx_lm/tuner/datasets.py#L163-L187) for loading a Hugging Face dataset.
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
def custom_load_hf_dataset(
|
| 145 |
+
data_id: str,
|
| 146 |
+
tokenizer: PreTrainedTokenizer,
|
| 147 |
+
names: Tuple[str, str, str] = ("train", "valid", "test"),
|
| 148 |
+
):
|
| 149 |
+
from datasets import exceptions, load_dataset
|
| 150 |
+
|
| 151 |
+
try:
|
| 152 |
+
dataset = load_dataset(data_id)
|
| 153 |
+
|
| 154 |
+
train, valid, test = [
|
| 155 |
+
(
|
| 156 |
+
datasets.create_dataset(dataset[n], tokenizer)
|
| 157 |
+
if n in dataset.keys()
|
| 158 |
+
else []
|
| 159 |
+
)
|
| 160 |
+
for n in names
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
except exceptions.DatasetNotFoundError:
|
| 164 |
+
raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
|
| 165 |
+
|
| 166 |
+
return train, valid, test
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
Then, let's load the `win-wang/Machine_Learning_QA_Collection` dataset from Hugging Face.
|
| 170 |
+
|
| 171 |
+
```python
|
| 172 |
+
train_set, val_set, test_set = custom_load_hf_dataset(
|
| 173 |
+
data_id="win-wang/Machine_Learning_QA_Collection",
|
| 174 |
+
tokenizer=tokenizer,
|
| 175 |
+
names=("train", "validation", "test"),
|
| 176 |
+
)
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
## Fine-Tuning
|
| 180 |
+
|
| 181 |
+
Finally, we can begin the LoRA fine-tuning process by calling the `train()` function.
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
train(
|
| 185 |
+
model=model,
|
| 186 |
+
tokenizer=tokenizer,
|
| 187 |
+
args=training_args,
|
| 188 |
+
optimizer=optim.Adam(learning_rate=1e-5),
|
| 189 |
+
train_dataset=train_set,
|
| 190 |
+
val_dataset=val_set,
|
| 191 |
+
training_callback=metrics,
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
After the training is completed, we can also plot the train and validation loss.
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
train_its, train_losses = zip(*metrics.train_losses)
|
| 199 |
+
validation_its, validation_losses = zip(*metrics.val_losses)
|
| 200 |
+
plt.plot(train_its, train_losses, "-o", label="Train")
|
| 201 |
+
plt.plot(validation_its, validation_losses, "-o", label="Validation")
|
| 202 |
+
plt.xlabel("Iteration")
|
| 203 |
+
plt.ylabel("Loss")
|
| 204 |
+
plt.legend()
|
| 205 |
+
plt.show()
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
For example, one of the trainings performed resulted in the following losses.
|
| 209 |
+
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
## Test the model_lora
|
| 213 |
+
|
| 214 |
+
Now, we can load the fine-tuned model, specifying the `adapter_path`,
|
| 215 |
+
|
| 216 |
+
```python
|
| 217 |
+
model_lora, _ = load(model_path, adapter_path=adapter_path)
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
and we can generate an output for the same prompt as earlier.
|
| 221 |
+
|
| 222 |
+
```python
|
| 223 |
+
response = generate(model_lora, tokenizer, prompt=prompt, verbose=True)
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
The generated response is:
|
| 227 |
+
|
| 228 |
+
```
|
| 229 |
+
Fine-tuning in machine learning refers to the process of adjusting the parameters of a pre-trained model to adapt it to a specific task or dataset. This approach is often used when the available data is limited, as it allows the model to leverage the knowledge it has already gained from previous training. Fine-tuning can improve the performance of a model on a new task, making it a valuable technique in many machine learning applications.
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
## Conclusion
|
| 233 |
+
|
| 234 |
+
In this tutorial, we explored how to leverage MLX LM and LoRA for fine-tuning large language models on Apple silicon. We started by setting up the necessary environment, loading a pre-trained model from the MLX Community, and preparing our dataset from Hugging Face. By converting selected linear layers into LoRA adapters and freezing the majority of the model's weights, we efficiently fine-tuned the model using a modest computational footprint. This approach not only optimizes resource usage but also opens the door to experimenting with different fine-tuning strategies and datasets. Further modifications can be explored, such as experimenting with other adapter configurations like QLoRA (extends the LoRA approach by integrating quantization techniques), fusing adapters, integrating additional evaluation metrics to better understand a model's performance, etc. Happy fine-tuning!
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/input.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_1.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_2.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_3.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/images/output_4.png
ADDED
|
Git LFS Details
|
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/index.qmd
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: "Qwen2.5-vl with MLX-VLM"
|
| 3 |
+
date: "2025-02-13"
|
| 4 |
+
categories: [Machine Learning, mlx, vlm]
|
| 5 |
+
draft: false
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
In this post, we are going to show a tutorial on using the [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) model with [MLX-VLM](https://github.com/Blaizzy/mlx-vlm) for visual understanding tasks. We are going to cover:
|
| 9 |
+
|
| 10 |
+
- Loading the model and image
|
| 11 |
+
- Generating a natural language description of an image
|
| 12 |
+
- Perform object detection in different scenarios with outputing their bounding boxes in JSON format
|
| 13 |
+
- Visualizing the results
|
| 14 |
+
|
| 15 |
+
Medium post can be found [here](https://medium.com/@levchevajoana/qwen2-5-vl-with-mlx-vlm-c4329b40ab87) and Substack [here](https://substack.com/home/post/p-157062287).
|
| 16 |
+
|
| 17 |
+
# Introduction
|
| 18 |
+
|
| 19 |
+
[Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) is the latest flagship vision-language model from the Qwen series, representing a significant advancement over its predecessor, [Qwen2-VL](https://arxiv.org/abs/2409.12191). This model is designed to enhance visual understanding and interaction capabilities across various domains. Key features of Qwen2.5-VL include:
|
| 20 |
+
|
| 21 |
+
- **Enhanced Visual Recognition:** The model excels at identifying a wide range of objects, including plants, animals, landmarks, and products. It also proficiently analyzes texts, charts, icons, graphics, and layouts within images.
|
| 22 |
+
- **Agentic Abilities:** Qwen2.5-VL functions as a visual agent capable of reasoning and dynamically directing tools, enabling operations on devices like computers and mobile phones.
|
| 23 |
+
- **Advanced Video Comprehension:** The model can understand lengthy videos exceeding one hour and can pinpoint specific events by identifying relevant video segments.
|
| 24 |
+
- **Accurate Visual Localization:** It can precisely locate objects within images by generating bounding boxes or points and provides structured JSON outputs detailing absolute coordinates and attributes.
|
| 25 |
+
- **Structured Data Output:** Qwen2.5-VL supports the generation of structured outputs from data such as scanned invoices, forms, and tables, benefiting applications in finance and commerce.
|
| 26 |
+
|
| 27 |
+
Performance evaluations indicate that the flagship model, Qwen2.5-VL-72B-Instruct, delivers competitive results across various benchmarks, including college-level problem-solving, mathematics, document comprehension, general question answering, and video understanding. Notably, it demonstrates significant strengths in interpreting documents and diagrams and operates effectively as a visual agent without the need for task-specific fine-tuning.
|
| 28 |
+
|
| 29 |
+
For developers and users interested in exploring Qwen2.5-VL, both base and instruct models are available in 3B, 7B, and 72B parameter sizes on platforms like Hugging Face. Additionally, the model can be used through [Qwen Chat](https://chat.qwenlm.ai).
|
| 30 |
+
|
| 31 |
+
# Tutorial
|
| 32 |
+
|
| 33 |
+
## Loading Packages
|
| 34 |
+
|
| 35 |
+
We begin by importing the necessary libraries. We are going to use the `mlx_vlm` package to load and operate with our Qwen2.5-VL model. We also use libraries such as matplotlib for plotting and PIL for image processing.
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
import json
|
| 39 |
+
|
| 40 |
+
import matplotlib.patches as patches
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
import numpy as np
|
| 43 |
+
from mlx_vlm import apply_chat_template, generate, load
|
| 44 |
+
from mlx_vlm.utils import load_image
|
| 45 |
+
from PIL import Image
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Loading the Qwen2.5-VL Model and Processor
|
| 49 |
+
|
| 50 |
+
Next, we load the pre-trained [Qwen2.5-VL-3B-Instruct-bf16](https://huggingface.co/mlx-community/Qwen2.5-VL-3B-Instruct-bf16) model from the Hugging Face [MLX Community](https://huggingface.co/mlx-community) along with its processor using the provided model path. The processor formats and preprocesses both text and image inputs to ensure they are compatible with the model’s architecture.
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
model_path = "mlx-community/Qwen2.5-VL-3B-Instruct-bf16"
|
| 54 |
+
model, processor = load(model_path)
|
| 55 |
+
config = model.config
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
You’ll notice the loading process involves fetching several files if the model hasn’t been downloaded previously. Once completed, the model is ready to process our inputs.
|
| 59 |
+
|
| 60 |
+
## Loading and Displaying the Image
|
| 61 |
+
|
| 62 |
+
For this tutorial, we use an image file (`person_dog.jpg`) which contains a person with a dog. We load the image using a helper function and then display its size.
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
image_path = "person_dog.jpg"
|
| 66 |
+
image = load_image(image_path)
|
| 67 |
+
print(image)
|
| 68 |
+
print(image.size) # Example output: (467, 700)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
The input image is shown below.
|
| 72 |
+
|
| 73 |
+
{ style="display: block; margin: 0 auto"}
|
| 74 |
+
|
| 75 |
+
## Generating an Image Description
|
| 76 |
+
|
| 77 |
+
We now prepare a prompt to describe the image. The prompt is wrapped using the `apply_chat_template` function, which converts our query into the chat-based format expected by the model.
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
prompt = "Describe the image."
|
| 81 |
+
formatted_prompt = apply_chat_template(
|
| 82 |
+
processor, config, prompt, num_images=1
|
| 83 |
+
)
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
Next, we generate the output by feeding both the formatted prompt and image into the model:
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
output = generate(model, processor, formatted_prompt, image, verbose=True)
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
**Sample Output:**
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
The image shows a person standing outdoors, holding a small, fluffy, light-colored dog. The person is wearing a dark gray hoodie with the word "ROX" on it and blue jeans. The background features a garden with various plants and a fence, and there are some fallen leaves on the ground. The setting appears to be a residential area with a garden.
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
This demonstrates how the model can effectively generate descriptive captions for images.
|
| 99 |
+
|
| 100 |
+
## Object Detection with Bounding Boxes
|
| 101 |
+
|
| 102 |
+
In addition to descriptions, the Qwen2.5-VL model can help us obtain spatial details such as bounding box coordinates for detected objects. We prepare a prompt asking the model to outline each object’s position in JSON format. We include the system prompt *“You are a helpful assistant"*, the user prompt describing the task *“Outline the position of each object and output all the bbox coordinates in JSON format.”*, and the path to the input image.
|
| 103 |
+
|
| 104 |
+
```python
|
| 105 |
+
system_prompt="You are a helpful assistant"
|
| 106 |
+
prompt="Outline the position of ecah object and output all the bbox coordinates in JSON format."
|
| 107 |
+
messages = [
|
| 108 |
+
{
|
| 109 |
+
"role": "system",
|
| 110 |
+
"content": system_prompt
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"role": "user",
|
| 114 |
+
"content": [
|
| 115 |
+
{
|
| 116 |
+
"type": "text",
|
| 117 |
+
"text": prompt
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"type": "image",
|
| 121 |
+
"image": image_path,
|
| 122 |
+
}
|
| 123 |
+
]
|
| 124 |
+
}
|
| 125 |
+
]
|
| 126 |
+
prompt = apply_chat_template(processor, config, messages, tokenize=False)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
We then generate the spatial output:
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
output = generate(
|
| 133 |
+
model,
|
| 134 |
+
processor,
|
| 135 |
+
prompt,
|
| 136 |
+
image,
|
| 137 |
+
verbose=True
|
| 138 |
+
)
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
**Sample JSON Output:**
|
| 142 |
+
|
| 143 |
+
`````markdown
|
| 144 |
+
```json
|
| 145 |
+
[
|
| 146 |
+
{
|
| 147 |
+
"bbox_2d": [170, 105, 429, 699],
|
| 148 |
+
"label": "person holding dog"
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"bbox_2d": [180, 158, 318, 504],
|
| 152 |
+
"label": "dog"
|
| 153 |
+
}
|
| 154 |
+
]
|
| 155 |
+
```
|
| 156 |
+
`````
|
| 157 |
+
|
| 158 |
+
This output provides the absolute coordinates of the bounding boxes around the detected objects along with the corresponding label. We should note that the absolute coordinates are with respect to:
|
| 159 |
+
|
| 160 |
+
- The beginning of the coordinate system which is top left.
|
| 161 |
+
- The image size corresponding to the possibly resized image after it’s processed via the `processor`. We can determine the new size by checking the `image_grid_thw` value in
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
processor.image_processor(image)
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
and the `patch_size` value from `processor`. Then we simply multiply the height and width values of `image_grid_thw` with the `ptach_size`, which by default is $14$. Thus, the adjusted bounding box coordinates can be determined by scaling with the original image size divided by the image size after it’s processed by the `processor`. The code can be seen in the next section in the function `normalize_bbox(processor, image, x_min, y_min, x_max, y_max)`.
|
| 168 |
+
|
| 169 |
+
**Observations:**
|
| 170 |
+
|
| 171 |
+
1. Most of the time the model produces an identical format of the JSON, with the same key-value pairs. If the user prompt didn’t include the word bbox before the word coordinates, the model sometimes produced slightly different key names and/or structure.
|
| 172 |
+
|
| 173 |
+
2. I achieved accurate and identical JSON outputs when using [mlx-community/Qwen2.5-VL-3B-Instruct-8bit](https://huggingface.co/mlx-community/Qwen2.5-VL-3B-Instruct-8bit) and [mlx-community/Qwen2.5-VL-3B-Instruct-bf16](https://huggingface.co/mlx-community/Qwen2.5-VL-3B-Instruct-bf16). In contrast, when I experimented with [mlx-community/Qwen2.5-VL-7B-Instruct-6bit](https://huggingface.co/mlx-community/Qwen2.5-VL-7B-Instruct-6bit) and [mlx-community/Qwen2.5-VL-7B-Instruct-8bit](https://huggingface.co/mlx-community/Qwen2.5-VL-7B-Instruct-8bit) the generated bounding box coordinates seemed to be shifted along the $y$-axis, but otherwise matched the dimensions of the bounding boxes generated with 3B models.
|
| 174 |
+
|
| 175 |
+
## Visualizing the Bounding Boxes
|
| 176 |
+
|
| 177 |
+
To better understand the spatial outputs, we can visualize these bounding boxes on the image. Below are helper functions that:
|
| 178 |
+
|
| 179 |
+
- Parse the JSON output
|
| 180 |
+
|
| 181 |
+
```python
|
| 182 |
+
def parse_bbox(bbox_str):
|
| 183 |
+
return json.loads(bbox_str.replace("```json", "").replace("```", ""))
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
- Normalize bounding box coordinates to match the image dimensions
|
| 187 |
+
|
| 188 |
+
```python
|
| 189 |
+
def normalize_bbox(processor, image, x_min, y_min, x_max, y_max):
|
| 190 |
+
width, height = image.size
|
| 191 |
+
_, input_height, input_width = (
|
| 192 |
+
processor.image_processor(image)["image_grid_thw"][0] * 14
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
x_min_norm = int(x_min / input_width * width)
|
| 196 |
+
y_min_norm = int(y_min / input_height * height)
|
| 197 |
+
x_max_norm = int(x_max / input_width * width)
|
| 198 |
+
y_max_norm = int(y_max / input_height * height)
|
| 199 |
+
|
| 200 |
+
return x_min_norm, y_min_norm, x_max_norm, y_max_norm
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
- Plot the image with rectangles and labels
|
| 204 |
+
|
| 205 |
+
```python
|
| 206 |
+
def plot_image_with_bboxes(processor, image, bboxes):
|
| 207 |
+
image = Image.open(image) if isinstance(image, str) else image
|
| 208 |
+
_, ax = plt.subplots(1)
|
| 209 |
+
ax.imshow(image)
|
| 210 |
+
|
| 211 |
+
if isinstance(bboxes, list) and all(isinstance(bbox, dict) for bbox in bboxes):
|
| 212 |
+
colors = plt.cm.rainbow(np.linspace(0, 1, len(bboxes)))
|
| 213 |
+
|
| 214 |
+
for i, (bbox, color) in enumerate(zip(bboxes, colors)):
|
| 215 |
+
label = bbox.get("label", None)
|
| 216 |
+
x_min, y_min, x_max, y_max = bbox.get("bbox_2d", None)
|
| 217 |
+
|
| 218 |
+
x_min_norm, y_min_norm, x_max_norm, y_max_norm = normalize_bbox(
|
| 219 |
+
processor, image, x_min, y_min, x_max, y_max
|
| 220 |
+
)
|
| 221 |
+
width = x_max_norm - x_min_norm
|
| 222 |
+
height = y_max_norm - y_min_norm
|
| 223 |
+
|
| 224 |
+
rect = patches.Rectangle(
|
| 225 |
+
(x_min_norm, y_min_norm),
|
| 226 |
+
width,
|
| 227 |
+
height,
|
| 228 |
+
linewidth=2,
|
| 229 |
+
edgecolor=color,
|
| 230 |
+
facecolor="none",
|
| 231 |
+
)
|
| 232 |
+
ax.add_patch(rect)
|
| 233 |
+
ax.text(
|
| 234 |
+
x_min_norm,
|
| 235 |
+
y_min_norm,
|
| 236 |
+
label,
|
| 237 |
+
color=color,
|
| 238 |
+
fontweight="bold",
|
| 239 |
+
bbox=dict(facecolor="white", edgecolor=color, alpha=0.8),
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
plt.axis("off")
|
| 243 |
+
plt.tight_layout()
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
Running the functions below
|
| 247 |
+
|
| 248 |
+
```python
|
| 249 |
+
objects_data = parse_bbox(output)
|
| 250 |
+
plot_image_with_bboxes(processor, image, bboxes=objects_data)
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
display the original image with bounding boxes drawn around the person and the dog, along with their respective labels.
|
| 254 |
+
|
| 255 |
+
{ style="display: block; margin: 0 auto"}
|
| 256 |
+
|
| 257 |
+
This example shows that even the 3B model can accurately detect objects based on a general prompt to detect all objects in the image.
|
| 258 |
+
|
| 259 |
+
## More Spatial Understanding Examples
|
| 260 |
+
|
| 261 |
+
We can demonstrate a few other model outputs, corresponding to different spatial understanding tasks.
|
| 262 |
+
|
| 263 |
+
### Detect a specific object using descriptions
|
| 264 |
+
|
| 265 |
+
**Prompt:** *“Outline the position of the dog and output all the bbox coordinates in JSON format.”*
|
| 266 |
+
|
| 267 |
+
**Output:**
|
| 268 |
+
|
| 269 |
+
{ style="display: block; margin: 0 auto"}
|
| 270 |
+
|
| 271 |
+
**Observation:** The dog was accurately detected.
|
| 272 |
+
|
| 273 |
+
The next examples are taken from the original Qwen2.5-VL [cookbook](https://github.com/QwenLM/Qwen2.5-VL/blob/main/cookbooks/spatial_understanding.ipynb) in which they use the model `Qwen2.5-VL-7B-Instruct`.
|
| 274 |
+
|
| 275 |
+
### Reasoning capability
|
| 276 |
+
|
| 277 |
+
**Prompt:** *“Locate the shadow of the paper fox, report the bbox coordinates in JSON format.”*
|
| 278 |
+
|
| 279 |
+
**Note:** The original image size as in the cookbook example was reduced so it can be better processed by the 3B model.
|
| 280 |
+
|
| 281 |
+
**Output:**
|
| 282 |
+
|
| 283 |
+
{ style="display: block; margin: 0 auto"}
|
| 284 |
+
|
| 285 |
+
**Observation:** The shadow of the paper fox was accurately detected.
|
| 286 |
+
|
| 287 |
+
### Understand relationships across different instances
|
| 288 |
+
|
| 289 |
+
**Prompt:** *“Locate the person who acts bravely, report the bbox coordinates in JSON format.”*
|
| 290 |
+
|
| 291 |
+
**Output:**
|
| 292 |
+
|
| 293 |
+
{ style="display: block; margin: 0 auto"}
|
| 294 |
+
|
| 295 |
+
**Observation:** The person who acts bravely was accurately detected.
|
| 296 |
+
|
| 297 |
+
### Find a special instance with unique characteristic
|
| 298 |
+
|
| 299 |
+
**Prompt:** *“If the sun is very glaring, which item in this image should I use? Please locate it in the image with its bbox coordinates and its name and output in JSON format.”*
|
| 300 |
+
|
| 301 |
+
**Output:**
|
| 302 |
+
|
| 303 |
+
{ style="display: block; margin: 0 auto"}
|
| 304 |
+
|
| 305 |
+
**Observation:** The image input in the cookbook has a transparent background. I tested the model with a present background and the produced results were not very logical. The above result is of the original image without background. Moreover, their output is `glasses`, in contrast to our 3B output `umbrella`, but our output is still logical.
|
| 306 |
+
|
| 307 |
+
---
|
| 308 |
+
|
| 309 |
+
In the end of their cookbook they mention that the above examples were based on the default system prompt. The system prompt can be changed so that we can obtain other output format like plain text. The supported Qwen2.5-VL formats are:
|
| 310 |
+
|
| 311 |
+
- bbox-format: JSON
|
| 312 |
+
|
| 313 |
+
```python
|
| 314 |
+
{"bbox_2d": [x1, y1, x2, y2], "label": "object name/description"}
|
| 315 |
+
```
|
| 316 |
+
|
| 317 |
+
- bbox-format: plain text
|
| 318 |
+
|
| 319 |
+
```
|
| 320 |
+
x1,y1,x2,y2 object_name/description
|
| 321 |
+
```
|
| 322 |
+
|
| 323 |
+
- point-format: XML
|
| 324 |
+
|
| 325 |
+
```xml
|
| 326 |
+
<points x y>object_name/description</points>
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
- point-format: JSON
|
| 330 |
+
|
| 331 |
+
```python
|
| 332 |
+
{"point_2d": [x, y], "label": "object name/description"}
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
They also give an example of how to change the system prompt so it ouputs plain text:
|
| 336 |
+
|
| 337 |
+
*“As an AI assistant, you specialize in accurate image object detection, delivering coordinates in plain text format ‘x1,y1,x2,y2 object’.”*
|
| 338 |
+
|
| 339 |
+
## Conclusion
|
| 340 |
+
|
| 341 |
+
In this tutorial, we explored the capabilities of Qwen2.5-VL by using MLX-VLM for various visual understanding tasks. We demonstrated how to load the model and images, generate natural language descriptions, and perform object detection with bounding boxes in different spatial understanding scenarios. Our experiments show that even the 3B model provides accurate object localization and structured JSON outputs, and suggests to be indeed a very powerful vision-language model.
|