JoeJoe1313 committed
Commit e9dcf74
Parent(s): 0361bfb

add posts
- src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/index.qmd +1 -0
- src/posts/2025-04-06-fine-tuning-function-calling/index.qmd +368 -0
- src/posts/2025-04-15-paligemma-2-mix/images/car_in.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/car_out.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/cat_in.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/cat_out.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/cow_in.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/cow_out.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/input_bb.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/map_mask.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/images/paligemma2-architecture.png +3 -0
- src/posts/2025-04-15-paligemma-2-mix/index.qmd +643 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/3_1.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/3_2.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/4_1.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/4_2.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/5_1.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/5_2.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/config_scheme_dest.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/images/developer_team.png +3 -0
- src/posts/2025-05-06-chat-qwen3-ios/index.qmd +162 -0
- src/posts/2025-05-23-app-docker-fastapi/index.qmd +546 -0
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/index.qmd
CHANGED
@@ -1,5 +1,6 @@
 ---
 title: "Qwen2.5-vl with MLX-VLM"
+author: "Joana Levtcheva"
 date: "2025-02-13"
 categories: [Machine Learning, mlx, vlm]
 draft: false
src/posts/2025-04-06-fine-tuning-function-calling/index.qmd
ADDED
@@ -0,0 +1,368 @@
---
title: "Fine-Tuning a Model for Function-Calling with MLX-LM"
author: "Joana Levtcheva"
date: "2025-04-06"
categories: [Machine Learning, mlx, llm]
draft: false
---

In this post, we explore the process of fine-tuning a language model for function-calling using [MLX-LM](https://github.com/ml-explore/mlx-lm). Following the Hugging Face Agents course [notebook](https://huggingface.co/agents-course/notebooks/blob/main/bonus-unit1/bonus-unit1.ipynb), we'll walk through the steps from setting up the environment to training the model with LoRA adapters. The goal is to empower the model with the ability to intelligently plan and generate function calls, making it a versatile tool for interactive applications. The Medium post can be found [here](https://medium.com/@levchevajoana/fine-tuning-a-model-for-function-calling-with-mlx-lm-d00d587e2559).

## Introduction

Modern AI models can do much more than generate plain text — they can now integrate with external tools by "calling" functions based on user intent. In this tutorial, we demonstrate how to adapt a pre-trained model (in our case, the [gemma-2-2b-it-4bit](https://huggingface.co/mlx-community/gemma-2-2b-it-4bit) model from the [MLX Community](https://huggingface.co/mlx-community)) to handle function-calling by using the `mlx-lm` package. This involves creating a specialized chat template, preprocessing a dataset of function call interactions, and applying LoRA for efficient fine-tuning.

## Setting Up the Model and Tokenizer

We start by importing the necessary libraries and modules, including the MLX-LM package, dataset utilities, and LoRA functions.

```python
import json
import os
from enum import Enum
from typing import Dict, List, Tuple, Union

import mlx.optimizers as optim
from datasets import load_dataset
from mlx.utils import tree_flatten
from mlx_lm import generate, load
from mlx_lm.tuner import TrainingArgs, datasets, linear_to_lora_layers, train
```

After loading our model and tokenizer,

```python
model_path = "mlx-community/gemma-2-2b-it-4bit"
model, tokenizer = load(model_path)
```

we customize the tokenizer's chat template to define the structure of our conversational interactions.

```python
tokenizer.chat_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
    "{% for message in messages %}"
    "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] | trim + '<end_of_turn><eos>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
)
```

This template embeds special tokens (like `<bos>`, `<start_of_turn>`, `<think>`, and `<tool_call>`) that mark the different stages of the conversation - from the user's prompt to the model's internal reasoning and eventual function call.
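To sanity-check the template before touching the data, we can render a tiny two-turn conversation with it. This is an illustrative sketch, not part of the original notebook: the roles and messages are made up, and the expected output in the comments is approximate.

```python
# Render a toy conversation with the custom chat template (illustrative sketch)
toy_messages = [
    {"role": "human", "content": "What is the capital of France?"},
    {"role": "model", "content": "The capital of France is Paris."},
]
print(
    tokenizer.apply_chat_template(toy_messages, tokenize=False, add_generation_prompt=True)
)
# Roughly:
# <bos><start_of_turn>human
# What is the capital of France?<end_of_turn><eos>
# <start_of_turn>model
# The capital of France is Paris.<end_of_turn><eos>
# <start_of_turn>model
```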
## Dataset Preparation and Preprocessing

We use the dataset [Jofthomas/hermes-function-calling-thinking-V1](https://huggingface.co/datasets/Jofthomas/hermes-function-calling-thinking-V1), which contains conversations involving function calls.

```python
dataset_path = "Jofthomas/hermes-function-calling-thinking-V1"
```

Let's load the dataset.

```python
dataset = load_dataset(dataset_path)
dataset
```

This outputs

```
DatasetDict({
    train: Dataset({
        features: ['conversations'],
        num_rows: 3570
    })
})
```

showing that the dataset originally includes a "conversations" column and has 3570 rows. We rename this column to "messages" for consistency

```python
dataset = dataset.rename_column("conversations", "messages")
dataset
```

and then apply the following preprocessing function

```python
def preprocess(sample):
    messages = sample["messages"]
    first_message = messages[0]

    # Instead of adding a system message, we merge the content into the first user message
    if first_message["role"] == "system":
        system_message_content = first_message["content"]
        # Merge system content with the first user message
        messages[1]["content"] = (
            system_message_content
            + "Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>\n\n"
            + messages[1]["content"]
        )
        # Remove the system message from the conversation
        messages.pop(0)

    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
```

to the dataset

```python
dataset = dataset.map(preprocess, remove_columns="messages")
dataset = dataset["train"].train_test_split(0.1)
dataset
```

This function merges any system messages into the first user message, ensuring the context is maintained without extra role annotations. This outputs

```text
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3213
    })
    test: Dataset({
        features: ['text'],
        num_rows: 357
    })
})
```

showing that we have successfully separated our original dataset into a train set with 3213 records and a test set with 357 records. Each sample is now a formatted text string ready for fine-tuning. Let's look at one training example:

```text
<bos><start_of_turn>human
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'create_todo', 'description': 'Create a new todo item', 'parameters': {'type': 'object', 'properties': {'task': {'type': 'string', 'description': 'The task description'}, 'due_date': {'type': 'string', 'format': 'date', 'description': 'The due date of the task'}, 'priority': {'type': 'integer', 'description': 'The priority of the task (1-5)'}}, 'required': ['task', 'due_date']}}}, {'type': 'function', 'function': {'name': 'calculate_bmi', 'description': 'Calculate the Body Mass Index (BMI)', 'parameters': {'type': 'object', 'properties': {'weight': {'type': 'number', 'description': 'The weight in kilograms'}, 'height': {'type': 'number', 'description': 'The height in meters'}}, 'required': ['weight', 'height']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{tool_call}
</tool_call>Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>

I need to create a new task.<end_of_turn><eos>
<start_of_turn>model
Sure, I can help with that. Could you please provide me with the task description, the due date, and the priority level?<end_of_turn><eos>
<start_of_turn>human
The task is to prepare a presentation for the annual meeting. The due date is 2022-09-15 and the priority level is 3.<end_of_turn><eos>
<start_of_turn>model
<think>Okay, so I need to figure out the reasoning that goes between the conversation and the next function call. Let me break this down step by step.

First, looking at the conversation, the user starts by saying, "I need to create a new task." The model responds by asking for the task description, due date, and priority level. The user then provides all the necessary details: the task is to prepare a presentation for the annual meeting, the due date is 2022-09-15, and the priority is 3.

Now, the model's next move should be to call the appropriate function. The available functions are 'create_todo' and 'calculate_bmi'. Since the user is talking about creating a new task, 'create_todo' is the relevant function here.

Examining the function's parameters, it requires 'task', 'due_date', and takes 'priority' as optional. The user provided all three, so we can include them in the arguments.

Therefore, the model will execute the 'create_todo' function with the provided task details. This makes sense because the conversation is about setting up a new task, and the function is designed for that exact purpose.
</think><tool_call>
{'name': 'create_todo', 'arguments': {'task': 'Prepare a presentation for the annual meeting', 'due_date': '2022-09-15', 'priority': 3}}
</tool_call><end_of_turn><eos>
<start_of_turn>tool
<tool_response>
{'status': 'success', 'message': 'Todo item successfully created', 'todo_id': '12345'}
</tool_response><end_of_turn><eos>
<start_of_turn>model
Your task has been successfully created. The ID for your new task is 12345.<end_of_turn><eos>
```
## Training Setup with LoRA Adapters

To efficiently fine-tune the model without retraining all of its parameters, we leverage LoRA. First, we create a directory to store adapter configurations and weights.

```python
adapter_path = "adapters_fc"
os.makedirs(adapter_path, exist_ok=True)
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
adapter_file_path = os.path.join(adapter_path, "adapters.safetensors")
```

Then we define our LoRA configuration: 8 layers, a rank of 16, a scale of 64, and a dropout of 0.05,

```python
lora_config = {
    "num_layers": 8,
    "lora_parameters": {
        "rank": 16,
        "scale": 64,
        "dropout": 0.05,
    },
}
```

and save it as a JSON file.

```python
with open(adapter_config_path, "w") as f:
    json.dump(lora_config, f, indent=4)
```

Next, we define the training arguments, specifically setting a single iteration,

```python
training_args = TrainingArgs(
    adapter_file=adapter_file_path,
    iters=1,
    steps_per_eval=50,
)
```

and freeze the original model parameters.

```python
_ = model.freeze()
```

Then, we convert selected linear layers to LoRA layers to make only a small subset of parameters trainable.

```python
linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])
```

In our example, this results in

```python
num_train_params = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Number of trainable parameters: {num_train_params}")
```

983,040 trainable parameters.
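As a rough sanity check of that number, the arithmetic works out if we assume (these dimensions and adapted projections are my assumptions, not stated in the post) that MLX-LM's LoRA defaults adapt only the attention query and value projections, and that Gemma-2-2b uses a hidden size of 2304 with query and key/value output sizes of 2048 and 1024:

```python
# Hypothetical back-of-the-envelope check; the dimensions and adapted
# projections are assumptions, not taken from the post itself.
rank, hidden, q_out, kv_out, num_layers = 16, 2304, 2048, 1024, 8
per_layer = rank * (hidden + q_out) + rank * (hidden + kv_out)  # LoRA A + B matrices
print(per_layer * num_layers)  # 983040
```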
Finally, we should not forget to activate training mode while still preserving the frozen state of the main model parameters.

```python
_ = model.train()
```

## Fine-Tuning Process and Metrics

With our model and dataset ready, we configure a metrics tracker to log both training and validation losses,

```python
class Metrics:
    def __init__(self) -> None:
        self.train_losses: List[Tuple[int, float]] = []
        self.val_losses: List[Tuple[int, float]] = []

    def on_train_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.train_losses.append((info["iteration"], info["train_loss"]))

    def on_val_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.val_losses.append((info["iteration"], info["val_loss"]))
```

and create an instance of this class.

```python
metrics = Metrics()
```

We also create datasets in a form that mlx-lm can consume by first defining the following configuration for our datasets,

```python
configs = {
    "mask_prompt": False,
    "prompt_feature": "prompt",
    "text_feature": "text",
    "completion_feature": "completion",
    "chat_feature": "messages",
}
```

and then create a train set with the help of the mlx-lm function `datasets.create_dataset`, passing the configuration from above.

```python
train_set = datasets.create_dataset(
    dataset["train"],
    tokenizer,
    configs
)
```

Similarly, we create our validation set.

```python
val_set = datasets.create_dataset(
    dataset["test"],
    tokenizer,
    configs
)
```

Finally, we start the fine-tuning process by calling the `train()` function.

```python
train(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    optimizer=optim.Adam(learning_rate=1e-5),
    train_dataset=train_set,
    val_dataset=val_set,
    training_callback=metrics,
)
```

The training logs report both training and validation losses, along with performance metrics like tokens processed per second and memory usage. After training, the adapter weights are saved and can later be reloaded to quickly deploy the fine-tuned model.

```
Starting training..., iters: 1
Iter 1: Val loss 1.821, Val took 128.584s
Iter 1: Train loss 1.861, Learning Rate 1.000e-05, It/sec 0.430, Tokens/sec 160.427, Trained Tokens 3735, Peak mem 20.665 GB
Saved final weights to adapters_fc/adapters.safetensors.
```
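Because the `Metrics` callback stored every reported loss, we can visualise the curves once training finishes. A minimal sketch (with `iters=1` there is only one point per curve, so this becomes more useful for longer runs):

```python
import matplotlib.pyplot as plt

# Plot the losses collected by the Metrics callback
train_its, train_vals = zip(*metrics.train_losses)
val_its, val_vals = zip(*metrics.val_losses)
plt.plot(train_its, train_vals, "-o", label="train loss")
plt.plot(val_its, val_vals, "-o", label="validation loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
plt.show()
```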
## Evaluating the Fine-Tuned Model

After training, we reload the model with the newly learned LoRA weights,

```python
model_lora, _ = load(model_path, adapter_path=adapter_path)
```

set our prompt to

```python
prompt = """<bos><start_of_turn>human
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'convert_currency', 'description': 'Convert from one currency to another', 'parameters': {'type': 'object', 'properties': {'amount': {'type': 'number', 'description': 'The amount to convert'}, 'from_currency': {'type': 'string', 'description': 'The currency to convert from'}, 'to_currency': {'type': 'string', 'description': 'The currency to convert to'}}, 'required': ['amount', 'from_currency', 'to_currency']}}}, {'type': 'function', 'function': {'name': 'calculate_distance', 'description': 'Calculate the distance between two locations', 'parameters': {'type': 'object', 'properties': {'start_location': {'type': 'string', 'description': 'The starting location'}, 'end_location': {'type': 'string', 'description': 'The ending location'}}, 'required': ['start_location', 'end_location']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{tool_call}
</tool_call>Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>

Hi, I need to convert 500 USD to Euros. Can you help me with that?<end_of_turn><eos>
<start_of_turn>model
<think>"""
```

and generate a response

```python
generate(model_lora, tokenizer, prompt=prompt, verbose=True, max_tokens=1000)
```

which returns

```text
==========

To convert USD to Euros, I need to use the 'convert_currency' function from the provided tools. I need to provide the amount to convert, the currency to convert from (USD), and the currency to convert to (Euros). I should also make sure the amount is a number.
</think>

<tool_call>
{
'name': 'convert_currency',
'arguments': {
'amount': 500,
'from_currency': 'USD',
'to_currency': 'EUR'
}
}
</tool_call>

==========
Prompt: 460 tokens, 862.170 tokens-per-sec
Generation: 135 tokens, 68.472 tokens-per-sec
Peak memory: 20.665 GB
```

The model first walks through a thought process before generating a function call to convert USD to Euros. This demonstrates the model's improved ability to generate precise JSON function calls within `<tool_call>` XML tags.
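As an optional follow-up that is not covered in the original notebook, the adapter can be fused back into the base weights so the fine-tuned model ships as a standalone checkpoint. At the time of writing, mlx-lm exposes this as the `mlx_lm.fuse` command; the save path below is just an illustrative name:

```
mlx_lm.fuse --model mlx-community/gemma-2-2b-it-4bit --adapter-path adapters_fc --save-path gemma-2-2b-it-4bit-fc
```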
## Conclusion

Fine-tuning a model for function-calling can significantly enhance its interactivity and real-world utility. By adapting the chat template, preprocessing the dataset, and applying LoRA adapters, we've demonstrated a streamlined approach to training a model that can generate executable function calls with clear reasoning. It is impressive that we achieved this by using only the `mlx-lm` package. Happy fine-tuning!
src/posts/2025-04-15-paligemma-2-mix/images/car_in.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/car_out.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/cat_in.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/cat_out.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/cow_in.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/cow_out.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/input_bb.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/map_mask.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/images/paligemma2-architecture.png
ADDED (binary image, Git LFS)
src/posts/2025-04-15-paligemma-2-mix/index.qmd
ADDED
@@ -0,0 +1,643 @@
---
title: "Image Segmentation with PaliGemma 2 Mix and MLX"
author: "Joana Levtcheva"
date: "2025-04-15"
categories: [Machine Learning, mlx, vlm]
draft: false
---

In this post, we are going to explore Google's [**PaliGemma 2 mix**](https://developers.googleblog.com/en/introducing-paligemma-2-mix/) vision-language model (VLM), and its capabilities to perform image segmentation. What's interesting is that we are going to perform this task by only using Apple's MLX framework, and MLX-VLM. This eliminates the dependency on JAX/Flax used in Google's original segmentation [script](https://github.com/google-research/big_vision/blob/main/big_vision/evaluators/proj/paligemma/transfers/segmentation.py), and allows us to fully and seamlessly utilise Apple's unified memory. The Medium post can be found [here](https://medium.com/@levchevajoana/image-segmentation-with-paligemma-2-mix-and-mlx-7e69e077968b).

# Introduction

## PaliGemma 2

In December 2024 Google introduced the [PaliGemma 2](https://developers.googleblog.com/en/introducing-paligemma-2-powerful-vision-language-models-simple-fine-tuning/) vision-language models (VLMs). These are pre-trained (**pt**) models coming in three different sizes: `3B`, `10B`, and `28B`, as well as three different input resolutions for images: `224x224`, `448x448`, and `896x896` pixels. These models represent the latest evolution of vision-language models developed by Google, building upon the foundation laid by its predecessor, PaliGemma. Below, we can see the architecture of the PaliGemma 2 model.

<figure>
  <img src="images/paligemma2-architecture.png" alt="PaliGemma 2 architecture" style="display: block; margin: 0 auto">
  <figcaption style="text-align: center">Figure 1. PaliGemma 2 Architecture Overview <span style="font-size: 0.8em;">[<a href="https://arxiv.org/pdf/2412.03555">Source</a>]</span></figcaption>
</figure>

PaliGemma 2 processes images at resolutions of `224×224`, `448×448`, or `896×896` pixels using a **[SigLIP-400m](https://arxiv.org/abs/2303.15343) vision encoder** with a patch size of 14×14 pixels. This design yields 256, 1024, or 4096 tokens, respectively. After a linear projection, the resulting image tokens are concatenated with the input text tokens, and [**Gemma 2**](https://blog.google/technology/developers/google-gemma-2/) is used as a **text decoder** to autoregressively complete the combined prefix to generate an answer.

## PaliGemma 2 Mix

As already mentioned, PaliGemma 2 models are pre-trained models, but they are also designed to be easy to fine-tune and adapt to various specific vision-language tasks and domains. Google wanted to demonstrate the performance of a fine-tuned version of the pt PaliGemma 2 models on downstream tasks, and thus a few months later, in February 2025, they introduced [**PaliGemma 2 mix**](https://developers.googleblog.com/en/introducing-paligemma-2-mix/). These models are fine-tuned to a mixture of vision-language tasks and can be used out-of-the-box for common use cases. They are available in three sizes: `3B`, `10B`, and `28B`, and support resolutions of `224×224` and `448×448` pixels.

### Tasks

PaliGemma 2 mix can perform the following types of tasks:

- Short and long captioning
- Optical character recognition (OCR)
- Image question answering
- (Multiple) object detection
- (Multiple) image segmentation

### Prompting

In general, the PaliGemma models are very sensitive to the prompt's syntax and patterns. However, according to the following Hugging Face [article](https://huggingface.co/blog/paligemma2mix), when using PaliGemma 2 mix models open-ended prompts yield better performance than the previously required task-prefixed prompts. Earlier, task-specific prefixes were essential, such as

- `"caption {lang}\n"`: Short captions
- `"describe {lang}\n"`: More descriptive captions
- `"ocr"`: Optical character recognition
- `"answer {lang} {question}\n"`: Question answering about the image contents
- `"question {lang} {answer}\n"`: Question generation for a given answer

However, two specific tasks - **object detection** and **image segmentation** - still exclusively require task prefixes (an example prompt follows the list below):

- `"detect {object description} ; {object description} ; ...\n"`: Locate multiple objects in an image and return the bounding boxes for those objects
- `"segment {object description} ; {object description} ; ...\n"`: Locate the area occupied by multiple objects in an image to create an image segmentation for those objects
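For example, a prompt asking for segmentation masks of two objects at once would look like this (the object names here are just illustrative):

```text
segment cat ; cow
```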
# Image Segmentation

## What is Image Segmentation?

Image segmentation is a key computer vision technique that divides an image into pixel groups, or segments, enabling tasks like object detection, scene understanding, and advanced image processing. Traditional methods use pixel features such as color, brightness, contrast, and intensity to separate objects from the background, often relying on simple heuristics or basic machine learning. Recently, deep learning models with complex neural networks have dramatically improved segmentation accuracy.

Unlike image classification, which labels an entire image, or object detection, which locates objects with bounding boxes, image segmentation provides detailed pixel-level annotations. This approach assigns every pixel to a specific category, with variants including semantic segmentation (classifying pixels), instance segmentation (distinguishing between instances of the same object), and panoptic segmentation (combining both methods).

## Image Segmentation with VLMs

VLMs enhance traditional image segmentation by enabling open-vocabulary segmentation through textual instructions, moving away from closed-set methods that rely on predefined categories. By merging text and image data into a common feature space, these models reduce adaptation costs and excel at tasks like referring expression segmentation. For example, a user might prompt the model to *"segment the cat sitting on the chair"*, and the VLM would identify and segment the pixels corresponding to that specific cat.

To achieve this, VLMs harness visual features from encoders like CNNs or Vision Transformers, using cross-attention to focus on image regions relevant to the text. Some models are fine-tuned to produce bounding boxes or segmentation masks directly, and careful prompting guides them to accurately segment based on the integrated understanding of visual content and language.

## Image Segmentation the PaliGemma 2 Way

Earlier, in **Figure 1** we saw that PaliGemma 2's architecture combines a Transformer decoder based on the Gemma 2 language model with a Vision Transformer image encoder initialised from SigLIP-So400m/14. The SigLIP encoder divides input images into `14x14` pixel patches to generate "soft tokens" that capture spatial relationships. Then, a linear projection layer is used to map the visual tokens into the same dimensional space as the input embeddings of the Gemma 2 language model. This projection ensures that the visual information can be seamlessly combined with textual information for processing by the language model.

The Gemma 2 language model functions as the decoder, processing concatenated image tokens and text tokens to produce autoregressive text output, predicting one token at a time based on the preceding context. To enhance its capabilities for vision-language tasks, PaliGemma extends the vocabulary of the standard Gemma tokenizer (having 256,000 tokens) with additional special tokens. These include 1024 tokens representing coordinates in a normalised image space, denoted as `<loc0000>` through `<loc1023>`, and another 128 tokens, `<seg000>` through `<seg127>`, which are codewords used for a lightweight referring-expression segmentation vector-quantized approach.

### Segmentation Output

When processing a segmentation prompt, PaliGemma 2 mix produces a sequence that begins with four location tokens defining the bounding box for the segmented object. These four tokens specify the bounding box coordinates in the normalized image space. This is followed by 16 segmentation tokens, which can be decoded via a learned codebook into a binary segmentation mask confined within the identified region. Below is an example output:

```text
<loc0336><loc0049><loc0791><loc0941><seg106><seg074><seg114><seg081><seg082><seg028><seg018><seg037><seg120><seg073><seg061><seg125><seg045><seg059><seg052><seg084>
```
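To make the location tokens concrete, here is a small sketch (not part of the original script) that decodes the four `<locXXXX>` values above into pixel coordinates for a `448x448` input, using the same normalisation to 1024 bins described later in the post:

```python
import re

output = "<loc0336><loc0049><loc0791><loc0941>"  # the location tokens from the example above
y_min, x_min, y_max, x_max = (int(v) for v in re.findall(r"<loc(\d{4})>", output))

side = 448  # input resolution of the 448px mix models
print([round(v / 1024 * side) for v in (x_min, y_min, x_max, y_max)])
# -> approximately [21, 147, 412, 346] as (x_min, y_min, x_max, y_max) in pixels
```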
### Segmentation Mask

If we want to further process the 16 segmentation tokens to generate a binary segmentation mask within the identified bounding box, we have to decode the segmentation tokens by using the Decoder from Google's big_vision repository for the PaliGemma models. It is available in the following [script](https://github.com/google-research/big_vision/blob/main/big_vision/evaluators/proj/paligemma/transfers/segmentation.py). As we can see, the script uses JAX and Flax, and it is known that the Metal plug-in for JAX is still not fully supported, as stated in the [Accelerated JAX on Mac](https://developer.apple.com/metal/jax/) article. In the next part of this post, we are going to show not only how to reconstruct the binary mask with the help of the above script, but also how to translate JAX/Flax to [mlx](https://github.com/ml-explore/mlx) so that we can fully utilise the unified memory in Apple's chips.

# Tutorial

In this section, we are going to generate a segmentation mask with the PaliGemma 2 mix model, specifically [mlx-community/paligemma2-10b-mix-448-8bit](https://huggingface.co/mlx-community/paligemma2-10b-mix-448-8bit), by using only the packages `mlx-vlm` and `mlx`. We are also going to overlay the mask on top of the image we are segmenting.

## Overview of the Process

Let's first begin by outlining the steps of the process for generating a segmentation mask. An illustrative diagram can be seen in **Figure 2**.

- We start by passing a **prompt** to the model of the form *"segment cat\n"*, and the image we want to segment. This is our **original image** with dimensions $x_{\text{orig}}$ by $y_{\text{orig}}$.
- Then, the model's image processor (SiglipImageProcessor) yields an **input image** with dimensions $x_{\text{input}}$ by $y_{\text{input}}$. In the PaliGemma 2 mix case this would be either `224x224` or `448x448`, depending on the model we have chosen to use. In our case, it would be `448x448`.
- The model generates an output with 4 location coordinates and 16 segmentation tokens. The `<locXXXX><locXXXX><locXXXX><locXXXX>` sequence corresponds to the $y_{\text{min}}$, $x_{\text{min}}$, $y_{\text{max}}$, $x_{\text{max}}$ coordinates defining the **bounding box**. These coordinates should be normalised to an image size of `1024x1024` to obtain the bounding box coordinates of the object we want to segment with respect to the input image dimensions.

<figure>
  <img src="images/input_bb.png" alt="Model input and bounding box" style="display: block; margin: 0 auto">
  <figcaption style="text-align: center">Figure 2. Model input and bounding box coordinates</figcaption>
</figure>

Now that we've defined the bounding box by its coordinates, let's zoom in on its details as shown in **Figure 3**, and discuss how we would overlay the segmentation mask on top of the image we are segmenting.

- The model has returned the 16 segmentation tokens of the form `<segXXX>`. After decoding them via the codebook we end up reconstructing the **segmentation mask**. This mask has a size of `64x64` pixels.
- Next, we need to map the segmentation mask onto the bounding box that was previously defined. This is accomplished using classical interpolation techniques to scale the mask to the bounding box's dimensions.

<figure>
  <img src="images/map_mask.png" alt="Mapping mask to bounding box" style="display: block; margin: 0 auto">
  <figcaption style="text-align: center">Figure 3. Mapping the 64x64 mask to the bounding box</figcaption>
</figure>

- Once resized, the mask is aligned to fit within the bounding box. To overlay this mask on the original image, we create an empty array matching the dimensions of the input image and then replace the array values corresponding to the bounding box coordinates with those from the interpolated segmentation mask.
# MLX

Finally, it's time to dive into the coding section of this blog and focus specifically on the `mlx` components. The code can be found on [GitHub](https://github.com/JoeJoe1313/LLMs-Journey/blob/main/VLMs/paligemma_segmentation_mlx.py).

We begin by importing the necessary libraries and modules,

```python
import argparse
import functools
import logging
import re
from typing import Callable, List, Tuple

import cv2
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx_vlm import apply_chat_template, generate, load
from mlx_vlm.utils import load_image
from tensorflow.io import gfile
```

then we establish the paths for the models and image resources. The `MODEL_PATH` points to the specific PaliGemma model that we are going to use for segmentation tasks. The `IMAGE_PATH` is the location of the image that we will process, and the `_KNOWN_MODELS` dictionary provides a reference to the VAE checkpoint needed for mask reconstruction.

```python
MODEL_PATH = "mlx-community/paligemma2-10b-mix-448-8bit"
IMAGE_PATH = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
_KNOWN_MODELS = {"oi": "gs://big_vision/paligemma/vae-oid.npz"}
```

Before diving into the core functionality, we set up logging to keep track of the execution flow and for debugging purposes. The following snippet initializes Python's built-in logging system:

```python
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
```
The `ResBlock` class implements a basic residual block typical for convolutional architectures. It comprises three convolution layers:

- Two `3x3` convolutions with ReLU activations, which process the input.
- One `1x1` convolution to adjust dimensions if needed.

The output of the block is computed by summing the result of the convolutions with the original input. This residual connection helps maintain gradient flow during training and preserves information across layers.

```python
class ResBlock(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=features, out_channels=features, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=features, out_channels=features, kernel_size=3, padding=1
        )
        self.conv3 = nn.Conv2d(
            in_channels=features, out_channels=features, kernel_size=1, padding=0
        )

    def __call__(self, x: mx.array) -> mx.array:
        original_x = x
        x = nn.relu(self.conv1(x))
        x = nn.relu(self.conv2(x))
        x = self.conv3(x)
        return x + original_x
```

The `Decoder` class takes quantized vectors (obtained from segmentation tokens) and upscales them to produce a mask:

- An initial convolution reduces the channel dimension.
- A series of configurable residual blocks further process the features.
- Multiple transpose convolution layers (upsample layers) scale the feature maps until the desired resolution is reached.
- A final convolution produces the output mask.

```python
class Decoder(nn.Module):
    """Decoder that upscales quantized vectors to produce a mask.
    The architecture is parameterized to avoid hardcoded layer definitions.
    Takes channels-last input data (B, H, W, C).
    """

    def __init__(
        self,
        in_channels: int = 512,
        res_channels: int = 128,
        out_channels: int = 1,
        num_res_blocks: int = 2,
        upsample_channels: Tuple[int, ...] = (128, 64, 32, 16),
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(
            in_channels=in_channels, out_channels=res_channels, kernel_size=1, padding=0
        )
        self.res_blocks = [
            ResBlock(features=res_channels) for _ in range(num_res_blocks)
        ]
        self.upsample_layers = []
        out_up_ch = res_channels
        for ch in upsample_channels:
            self.upsample_layers.append(
                nn.ConvTranspose2d(
                    in_channels=out_up_ch,
                    out_channels=ch,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                )
            )
            out_up_ch = ch
        self.conv_out = nn.Conv2d(
            in_channels=upsample_channels[-1],
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = nn.relu(self.conv_in(x))
        for block in self.res_blocks:
            x = block(x)
        for layer in self.upsample_layers:
            x = nn.relu(layer(x))

        return self.conv_out(x)
```
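As a quick shape check (an illustrative sketch, not in the original script): the 16 decoded codebook vectors enter the decoder as a `4x4` grid of 512-dimensional features, each transpose convolution doubles the spatial size, and a single-channel `64x64` mask comes out.

```python
# Spatial size doubles per upsample layer: 4 -> 8 -> 16 -> 32 -> 64
decoder = Decoder()  # randomly initialised here, shapes only
dummy = mx.zeros((1, 4, 4, 512))  # (batch, height, width, channels)
print(decoder(dummy).shape)  # (1, 64, 64, 1)
```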
The helper function `_get_params` is designed to convert a PyTorch checkpoint into a format that MLX can work with. It does so by

- Transposing kernel weights to match the expected output format: from PyTorch's format to MLX's (Out, H, W, In) format.
- Organizing the parameters into a structured dictionary that reflects the architecture of the decoder, including the convolutional layers, residual blocks, and upsample layers.

This organized set of parameters is then used to initialize the decoder network.

```python
def _get_params(checkpoint: dict) -> dict:
    """Converts PyTorch checkpoint to MLX params (nested dict).
    Uses transpositions yielding (Out, H, W, In) format weights."""

    def transp(kernel: np.ndarray) -> mx.array:
        return mx.transpose(mx.array(kernel), (0, 2, 3, 1))

    def transp_transpose(kernel: np.ndarray) -> mx.array:
        intermediate = mx.transpose(mx.array(kernel), (1, 0, 2, 3))

        return mx.transpose(intermediate, (0, 2, 3, 1))

    def conv(name: str) -> dict:
        return {
            "bias": mx.array(checkpoint[f"{name}.bias"]),
            "weight": transp(checkpoint[f"{name}.weight"]),
        }

    def conv_transpose(name: str) -> dict:
        return {
            "bias": mx.array(checkpoint[f"{name}.bias"]),
            "weight": transp_transpose(checkpoint[f"{name}.weight"]),
        }

    def resblock(name: str) -> dict:
        return {
            "conv1": conv(f"{name}.0"),
            "conv2": conv(f"{name}.2"),
            "conv3": conv(f"{name}.4"),
        }

    params = {
        "_embeddings": mx.array(checkpoint["_vq_vae._embedding"]),
        "conv_in": conv("decoder.0"),
        "res_blocks": [
            resblock("decoder.2.net"),
            resblock("decoder.3.net"),
        ],
        "upsample_layers": [
            conv_transpose("decoder.4"),
            conv_transpose("decoder.6"),
            conv_transpose("decoder.8"),
            conv_transpose("decoder.10"),
        ],
        "conv_out": conv("decoder.12"),
    }

    return params
```

The function `_quantized_values_from_codebook_indices` takes the segmentation tokens (represented as codebook indices) and uses the embeddings from the codebook to retrieve the corresponding encoded representations. These values are reshaped to fit the expected dimensions (batch, height, width, channels) so that they are ready for processing by the decoder.

```python
def _quantized_values_from_codebook_indices(
    codebook_indices: mx.array, embeddings: mx.array
) -> mx.array:
    batch_size, num_tokens = codebook_indices.shape
    expected_tokens = 16
    if num_tokens != expected_tokens:
        log.error(f"Expected {expected_tokens} tokens, got {codebook_indices.shape}")

    encodings = mx.take(embeddings, codebook_indices.reshape((-1,)), axis=0)

    return encodings.reshape((batch_size, 4, 4, embeddings.shape[1]))
```
The `get_reconstruct_masks` function loads the VAE checkpoint and initializes the decoder with the appropriate parameters. By extracting and setting up the necessary embeddings and decoder weights, this function returns another function (`reconstruct_masks`) that, when given segmentation tokens, decodes them into a binary segmentation mask.

```python
@functools.cache
def get_reconstruct_masks(model: str) -> Callable[[mx.array], mx.array]:
    """Loads the checkpoint and returns a function that reconstructs masks
    from codebook indices using a preloaded MLX decoder.
    """
    checkpoint_path = _KNOWN_MODELS.get(model, model)
    with gfile.GFile(checkpoint_path, "rb") as f:
        checkpoint_data = dict(np.load(f))

    params = _get_params(checkpoint_data)
    embeddings = params.pop("_embeddings")
    log.info(f"VAE embedding dimension: {embeddings.shape[1]}")

    decoder = Decoder()
    decoder.update(params)

    def reconstruct_masks(codebook_indices: mx.array) -> mx.array:
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, embeddings
        )
        return decoder(quantized)

    return reconstruct_masks
```

The function `extract_and_create_arrays` parses a given string pattern for segmentation tokens. It extracts these token numbers, converts them into integers, and then wraps them in MLX arrays for further mask reconstruction processing.

```python
def extract_and_create_arrays(pattern: str) -> List[mx.array]:
    """Extracts segmentation tokens from each object in the pattern and returns a list of MLX arrays."""
    object_strings = [obj.strip() for obj in pattern.split(";") if obj.strip()]

    seg_tokens_arrays = []
    for obj in object_strings:
        seg_tokens = re.findall(r"<seg(\d{3})>", obj)
        if seg_tokens:
            seg_numbers = [int(token) for token in seg_tokens]
            seg_tokens_arrays.append(mx.array(seg_numbers))

    return seg_tokens_arrays
```

The `parse_bbox` function interprets the model's output string to extract bounding box coordinates. Each detected object's location is denoted by a string of the form `<loc1234>`. This function finds four numbers per object, corresponding to the box boundaries, and aggregates them into a list of bounding boxes.

```python
def parse_bbox(model_output: str):
    entries = model_output.split(";")

    results = []
    for entry in entries:
        entry = entry.strip()
        numbers = re.findall(r"<loc(\d+)>", entry)
        if len(numbers) == 4:
            bbox = [int(num) for num in numbers]
            results.append(bbox)

    return results
```
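Applied to the example model output from the Segmentation Output section, the two parsers behave as follows (an illustrative sketch):

```python
sample = (
    "<loc0336><loc0049><loc0791><loc0941>"
    "<seg106><seg074><seg114><seg081><seg082><seg028><seg018><seg037>"
    "<seg120><seg073><seg061><seg125><seg045><seg059><seg052><seg084>"
)
print(parse_bbox(sample))                 # [[336, 49, 791, 941]] -> y_min, x_min, y_max, x_max
print(extract_and_create_arrays(sample))  # one mx.array holding the 16 segmentation codes
```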
The `gather_masks` function combines the reconstruction and bounding box parsing steps. For each object:

- It reconstructs the mask from its codebook indices.
- It obtains the corresponding bounding box coordinates.
- It normalizes these coordinates relative to a target image resolution (448×448 in this example).

Each mask is then paired with its coordinates and stored in a list, making it straightforward to later overlay these onto the original image.

```python
def gather_masks(output, codes_list, reconstruct_fn):
    masks_list = []

    target_width, target_height = 448, 448
    for i, codes in enumerate(codes_list):
        codes_batch = codes[None, :]
        masks = reconstruct_fn(codes_batch)
        mask_np = np.array(masks[0, :, :, 0], copy=False)

        y_min, x_min, y_max, x_max = parse_bbox(output)[i]
        x_min_norm = int(x_min / 1024 * target_width)
        y_min_norm = int(y_min / 1024 * target_height)
        x_max_norm = int(x_max / 1024 * target_width)
        y_max_norm = int(y_max / 1024 * target_height)

        masks_list.append(
            {
                "mask": mask_np,
                "coordinates": (x_min_norm, y_min_norm, x_max_norm, y_max_norm),
            }
        )

    return masks_list
```

The function `plot_masks` handles the visualization of the segmentation outcomes. It loads the original image and processes it for display. Two types of visualizations are provided:

- **Composite Overlay**: All masks are combined and overlaid on the original image.
- **Reconstructed Mask**: Each reconstructed mask is plotted next to the composite overlay.

Using OpenCV for mask resizing and Matplotlib for plotting, the function creates a series of subplots to clearly display both composite and individual mask overlays.

```python
def plot_masks(args, processor, masks_list):

    image = load_image(args.image_path)
    img_array = processor.image_processor(image)["pixel_values"][0].transpose(1, 2, 0)
    img_array = (img_array * 0.5 + 0.5).clip(0, 1)

    full = np.ones((448, 448, 1)) * (-1)
    for mask_info in masks_list:
        mask_np = mask_info["mask"]
        x_min_norm, y_min_norm, x_max_norm, y_max_norm = mask_info["coordinates"]

        width = x_max_norm - x_min_norm
        height = y_max_norm - y_min_norm

        resized_mask = cv2.resize(
            mask_np, (width, height), interpolation=cv2.INTER_NEAREST
        )
        resized_mask = resized_mask.reshape((height, width, 1))

        full[y_min_norm:y_max_norm, x_min_norm:x_max_norm] = resized_mask

    n_masks = len(masks_list)
    _, axs = plt.subplots(1, n_masks + 1, figsize=(5 * (n_masks + 1), 6))

    axs[0].imshow(img_array)
    axs[0].imshow(full, alpha=0.5)
    axs[0].set_title("Mask Overlay")
    axs[0].axis("on")

    for i, mask_info in enumerate(masks_list, start=1):
        mask_np = mask_info["mask"]
        axs[i].imshow(mask_np)
        axs[i].set_title(f"Reconstructed Mask {i}")
        axs[i].axis("on")

    plt.tight_layout()
    plt.show()
```
The `main` function ties all the pieces together. It performs the following steps:

- **Loading**: Reads the specified PaliGemma model and image.
- **Setup**: Initializes the VAE checkpoint and extracts the reconstruction function.
- **Prompting**: Formats the prompt using the processor and generates a segmentation output via the model.
- **Processing**: Extracts segmentation tokens, reconstructs masks, and parses bounding box coordinates.
- **Visualization**: Finally, it calls the plotting function to display the results.

This function serves as the central point where data processing, model inference, mask reconstruction, and visualization are integrated into one complete pipeline.

```python
def main(args) -> None:
    log.info(f"Loading PaliGemma model: {args.model_path}")
    model, processor = load(args.model_path)
    config = model.config

    image = load_image(args.image_path)
    log.info(f"Image size: {image.size}")

    vae_path = _KNOWN_MODELS.get(args.vae_checkpoint_path, args.vae_checkpoint_path)
    reconstruct_fn = get_reconstruct_masks(vae_path)

    prompt = args.prompt.strip() + "\n"
    log.info(f"Using prompt: '{prompt.strip()}'")
    formatted_prompt = apply_chat_template(processor, config, prompt, num_images=1)

    log.info("Generating segmentation output...")
    output = generate(model, processor, formatted_prompt, image, verbose=False)
    log.info(f"Model output: {output}")

    codes_list = extract_and_create_arrays(output)
    log.info(f"Extracted codes: {codes_list}")

    log.info("Reconstructing mask from codes...")
    masks_list = gather_masks(output, codes_list, reconstruct_fn)

    log.info("Plotting masks...")
    plot_masks(args, processor, masks_list)  # pass args so plot_masks can reload the image
```

Finally, the script includes an entry point that parses command-line arguments. Users can specify the image path, the prompt for the segmentation task, the model path, and the VAE checkpoint path. Once these are provided via `argparse`, the `main` function is invoked to start the processing pipeline.

```python
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Vision tasks using PaliGemma 2 mix.")
    parser.add_argument(
        "--image_path", type=str, default=IMAGE_PATH, help="Path to the input image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the model."
    )
    parser.add_argument(
        "--model_path", type=str, default=MODEL_PATH, help="Path to the mlx model."
    )
    parser.add_argument(
        "--vae_checkpoint_path", type=str, default="oi", help="Path to the .npz file."
    )

    cli_args = parser.parse_args()
    main(cli_args)
```
|
519 |
+
|
520 |
+
# Results
|
521 |
+
|
522 |
+
Let’s take a look at some examples and the segmentations we obtained from the model.
|
523 |
+
|
524 |
+
## Single Object Segmentation
|
525 |
+
|
526 |
+
In this section, we are going to show two examples of single object segmentation.
|
527 |
+
|
528 |
+
---
|
529 |
+
|
530 |
+
**Prompt:**
|
531 |
+
|
532 |
+
```text
|
533 |
+
"segment cow"
|
534 |
+
```
|
535 |
+
|
536 |
+
**Image:**
|
537 |
+
|
538 |
+
<figure>
|
539 |
+
<img src="images/cow_in.png" alt="Cow input" style="display: block; margin: 0 auto">
|
540 |
+
<figcaption style="text-align: center">Figure 4. Original image of size 400x400</figcaption>
|
541 |
+
</figure>
|
542 |
+
|
543 |
+
**Model output:**
|
544 |
+
|
545 |
+
```text
|
546 |
+
<loc0410><loc0528><loc0884><loc1023><seg072><seg055><seg062><seg079><seg104><seg009><seg104><seg096><seg068><seg041><seg103><seg019><seg100><seg004><seg091><seg067>
|
547 |
+
```
|
548 |
+
|
549 |
+
**Mask overlay and reconstructed mask:**
|
550 |
+
|
551 |
+
<figure>
|
552 |
+
<img src="images/cow_out.png" alt="Cow output" style="display: block; margin: 0 auto">
|
553 |
+
<figcaption style="text-align: center">Figure 5. Left: mask overlay onto the input image of size 448x448 | Right: reconstructed mask of size 64x64</figcaption>
|
554 |
+
</figure>
|
555 |
+
|
556 |
+
**Observation:**
|
557 |
+
|
558 |
+
Based on the overlay image, the model manages to detect the precise location of the cow but struggles a bit with its detailed outlines. Looking only at the reconstructed mask would not persuade me that this is a cow.
|
559 |
+
|
560 |
+
---
|
561 |
+
|
562 |
+
**Prompt:**
|
563 |
+
|
564 |
+
```text
|
565 |
+
"segment cat"
|
566 |
+
```
|
567 |
+
|
568 |
+
**Image:**
|
569 |
+
|
570 |
+
<figure>
|
571 |
+
<img src="images/cat_in.png" alt="Cat input" style="display: block; margin: 0 auto">
|
572 |
+
<figcaption style="text-align: center">Figure 6. Original image of size 400x400</figcaption>
|
573 |
+
</figure>
|
574 |
+
|
575 |
+
**Model output:**
|
576 |
+
|
577 |
+
```text
|
578 |
+
<loc0060><loc0000><loc0920><loc0879><seg039><seg107><seg018><seg006><seg056><seg120><seg058><seg042><seg079><seg094><seg009><seg099><seg074><seg010><seg078><seg012>
|
579 |
+
```
|
580 |
+
|
581 |
+
**Mask overlay and reconstructed mask:**
|
582 |
+
|
583 |
+
<figure>
|
584 |
+
<img src="images/cat_out.png" alt="Cat output" style="display: block; margin: 0 auto">
|
585 |
+
<figcaption style="text-align: center">Figure 7. Left: mask overlay onto the input image of size 448x448 | Right: reconstructed mask of size 64x64</figcaption>
|
586 |
+
</figure>
|
587 |
+
|
588 |
+
**Observation:**
|
589 |
+
|
590 |
+
Based on the overlay image, the model manages to detect the precise location of the cat and generally does a good job with the cat’s outlines.
|
591 |
+
|
592 |
+
## Multiple Object Segmentation
|
593 |
+
|
594 |
+
It was tricky to find a working example for segmenting multiple objects, so there is only one example in this section. My observation is that the PaliGemma models are indeed very sensitive to the prompt formatting, and the 448–10B-8bit model might just not be powerful enough for the task of segmenting multiple objects.
|
595 |
+
|
596 |
+
---
|
597 |
+
|
598 |
+
**Prompt:**
|
599 |
+
|
600 |
+
```text
|
601 |
+
"segment left wheel ; right wheel"
|
602 |
+
```
|
603 |
+
|
604 |
+
**Image:**
|
605 |
+
|
606 |
+
<figure>
|
607 |
+
<img src="images/car_in.png" alt="Car input" style="display: block; margin: 0 auto">
|
608 |
+
<figcaption style="text-align: center">Figure 8. Original image of size 640x480</figcaption>
|
609 |
+
</figure>
|
610 |
+
|
611 |
+
**Model output:**
|
612 |
+
|
613 |
+
```text
|
614 |
+
<loc0591><loc0157><loc0794><loc0311> <seg092><seg004><seg044><seg092><seg120><seg061><seg029><seg120><seg090><seg023><seg021><seg090><seg015><seg041><seg044><seg073> right wheel ; <loc0586><loc0728><loc0794><loc0882> <seg092><seg004><seg089><seg092><seg120><seg048><seg054><seg038><seg119><seg029><seg021><seg090><seg095><seg041><seg044><seg073> right wheel
|
615 |
+
```
|
616 |
+
|
617 |
+
**Mask overlay and reconstructed mask:**
|
618 |
+
|
619 |
+
<figure>
|
620 |
+
<img src="images/car_out.png" alt="Car output" style="display: block; margin: 0 auto">
|
621 |
+
<figcaption style="text-align: center">Figure 9. Left: masks overlay onto the input image of size 448x448 | Right: reconstructed masks of size 64x64</figcaption>
|
622 |
+
</figure>
|
623 |
+
|
624 |
+
**Observation:**
|
625 |
+
|
626 |
+
Looking at the model output, we can see that both segmentations are labelled as “right wheel”. Despite this, based on the overlay image, the model manages to detect the precise location of the wheels and their outlines.
|
627 |
+
|
628 |
+
# Conclusion
|
629 |
+
|
630 |
+
In summary, we implemented a unified segmentation pipeline by combining Google’s PaliGemma 2 Mix model with Apple’s MLX framework. Our workflow involved formatting segmentation prompts, preprocessing images, extracting segmentation tokens and bounding box coordinates, decoding these tokens into segmentation masks, and finally overlaying the masks on the original images.
|
631 |
+
|
632 |
+
For single object segmentation, the model generally performed well: it accurately localised the object areas, as evidenced by both the “cat” and the “cow” examples. However, the segmentation for the “cow” revealed some challenges with capturing fine details, indicating areas for potential refinement.
|
633 |
+
|
634 |
+
The multiple object segmentation proved challenging, as we struggled to find more than one working example that produced multiple segmentations. In our single example, the model demonstrated the ability to detect the general locations of objects — successfully identifying both wheels — but it also suffered from issues such as prompt sensitivity and duplicate labelling. This difficulty may be attributed to the inherent prompt sensitivity of the model or potential limitations of the specific model variant, particularly the 448–10B-8bit configuration. These observations suggest that either refining prompt structures or exploring more powerful models may be essential for reliably handling segmentation tasks involving multiple objects.
|
635 |
+
|
636 |
+
# References
|
637 |
+
|
638 |
+
- [Introducing PaliGemma 2 mix: A vision-language model for multiple tasks](https://developers.googleblog.com/en/introducing-paligemma-2-mix/)
|
639 |
+
- [PaliGemma prompt and system instructions](https://ai.google.dev/gemma/docs/paligemma/prompt-system-instructions)
|
640 |
+
- [PaliGemma 2 Mix — New Instruction Vision Language Models by Google](https://huggingface.co/blog/paligemma2mix)
|
641 |
+
- [Welcome PaliGemma 2 — New vision language models by Google](https://huggingface.co/blog/paligemma2)
|
642 |
+
- [Introducing PaliGemma 2: Powerful Vision-Language Models, Simple Fine-Tuning](https://developers.googleblog.com/en/introducing-paligemma-2-powerful-vision-language-models-simple-fine-tuning/)
|
643 |
+
- [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://arxiv.org/abs/2412.03555)
|
src/posts/2025-05-06-chat-qwen3-ios/images/3_1.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/3_2.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/4_1.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/4_2.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/5_1.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/5_2.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/config_scheme_dest.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/images/developer_team.png
ADDED
|
src/posts/2025-05-06-chat-qwen3-ios/index.qmd
ADDED
@@ -0,0 +1,162 @@
1 |
+
---
|
2 |
+
title: "Chat with Qwen3 on your iPhone: A Step-by-Step Guide"
|
3 |
+
author: "Joana Levtcheva"
|
4 |
+
date: "2025-05-06"
|
5 |
+
categories: [Machine Learning, mlx, swift, llm, ios]
|
6 |
+
draft: false
|
7 |
+
---
|
8 |
+
|
9 |
+
Have you ever wanted to run a powerful large language model directly on your iPhone without sending your data to the cloud? Thanks to Apple’s [MLX Swift](https://github.com/ml-explore/mlx-swift) framework, you can now run the remarkably capable [Qwen3](https://qwenlm.github.io/blog/qwen3/) models right on your iPhone.
|
10 |
+
|
11 |
+
What makes this exciting is how capable these new models are with as few as 4B parameters. Based on Qwen’s benchmark reports, Qwen3-4B can rival the performance of Qwen2.5-72B-Instruct, delivering impressive results with just a fraction of the parameters. This development means truly powerful AI can now fit in your pocket.
|
12 |
+
|
13 |
+
This blog post will guide you through the entire process, from setting up your development environment to testing the model on your device. You can embrace true privacy and take ownership of your AI experience: your data stays on your device, where it belongs, and, most importantly, you have access to this knowledge powerhouse even without an internet connection.
|
14 |
+
|
15 |
+
Medium post can be found [here](https://medium.com/@levchevajoana/chat-with-qwen3-on-your-iphone-a-step-by-step-guide-515bb957cd02).
|
16 |
+
|
17 |
+
# Main Steps
|
18 |
+
|
19 |
+
The whole process consists of three main steps:
|
20 |
+
|
21 |
+
- Enable Developer Mode on your iPhone
|
22 |
+
- Clone Apple’s MLX Swift Examples [repository](https://github.com/ml-explore/mlx-swift-examples)
|
23 |
+
- Configure and deploy the Xcode project to your device
|
24 |
+
|
25 |
+
# Enable Developer Mode on iPhone
|
26 |
+
|
27 |
+
In order to deploy custom apps to your iPhone, you need to enable **Developer Mode**. On your iPhone:
|
28 |
+
|
29 |
+
- Go to **Settings → Privacy & Security → Developer Mode**
|
30 |
+
- Toggle Developer Mode to **On**
|
31 |
+
- Restart your iPhone when prompted
|
32 |
+
|
33 |
+
> **NOTE:** Keep in mind that your device’s security is reduced while Developer Mode is enabled.
|
34 |
+
|
35 |
+
# Apple’s MLX Swift Examples Repository
|
36 |
+
|
37 |
+
Now that your iPhone is ready for development, let’s get the necessary code. We are going to run the **MLXChatExample**, an example chat app supporting LLMs and VLMs for iOS and macOS, which can be found in the Applications folder in the [mlx-swift-examples](https://github.com/ml-explore/mlx-swift-examples) repository. So, the first step is to clone the repository:
|
38 |
+
|
39 |
+
```bash
|
40 |
+
git clone git@github.com:ml-explore/mlx-swift-examples.git
|
41 |
+
```
|
42 |
+
|
43 |
+
Then, navigate to the cloned directory:
|
44 |
+
|
45 |
+
```bash
|
46 |
+
cd mlx-swift-examples
|
47 |
+
```
|
48 |
+
|
49 |
+
And finally, open the Xcode project:
|
50 |
+
|
51 |
+
```bash
|
52 |
+
open mlx-swift-examples.xcodeproj
|
53 |
+
```
|
54 |
+
|
55 |
+
The above command should open the project directly in Xcode.
|
56 |
+
|
57 |
+
# Configure and Deploy the Xcode Project
|
58 |
+
|
59 |
+
With the project open in Xcode, you should connect your iPhone to the Mac, and then make a few adjustments.
|
60 |
+
|
61 |
+
## Configure the Project
|
62 |
+
|
63 |
+
We are interested in deploying the **MLXChatExample** app on your connected iPhone, so you should:
|
64 |
+
|
65 |
+
- Change your scheme to **MLXChatExample** by going to **Product → Scheme → Choose Scheme** (this opens a dropdown menu in Xcode’s top bar; the scheme can also be selected directly from the bar)
|
66 |
+
- Set the destination to your connected iPhone by going to **Product → Destination → Choose Destination**
|
67 |
+
|
68 |
+
<figure>
|
69 |
+
<img src="images/config_scheme_dest.png" alt="Configure scheme and destination" style="display: block; margin: 0 auto">
|
70 |
+
<figcaption style="text-align: center">Figure 1. Configure scheme and destination</figcaption>
|
71 |
+
</figure>
|
72 |
+
|
73 |
+
## Add Developer Team
|
74 |
+
|
75 |
+
In order to build and run the app we have to assign a development team:
|
76 |
+
|
77 |
+
- Select the **mlx-swift-examples** project in the Project Navigator (on the left)
|
78 |
+
- Go to the **Signing & Capabilities** tab
|
79 |
+
- Choose your **Team** from the dropdown menu
|
80 |
+
|
81 |
+
<figure>
|
82 |
+
<img src="images/developer_team.png" alt="Developer team" style="display: block; margin: 0 auto">
|
83 |
+
<figcaption style="text-align: center">Figure 2. Add developer team</figcaption>
|
84 |
+
</figure>
|
85 |
+
|
86 |
+
# Deploy to Your iPhone
|
87 |
+
|
88 |
+
Now we are ready to deploy **MLXChatExample** to our actual device:
|
89 |
+
|
90 |
+
- Click the Run button (▶) on the left to run the project, or alternatively choose **Product → Run**
|
91 |
+
- You might be asked to authorize Xcode development on your device
|
92 |
+
- If successful, the app will be downloaded to your iPhone
|
93 |
+
|
94 |
+
## Trust Developer Certificate
|
95 |
+
|
96 |
+
If this is the first app from this developer team on your device, you will get a message on your iPhone that the app is from an **Untrusted Developer** and cannot be used. You need to allow applications from this developer, so on your iPhone:
|
97 |
+
|
98 |
+
- Go to **Settings → General → VPN & Device Management**
|
99 |
+
- Under **Developer App**, tap on your app
|
100 |
+
- Select **Trust “[Developer Name]”** and confirm
|
101 |
+
|
102 |
+
## Using Qwen3 on Your iPhone
|
103 |
+
|
104 |
+
With that, we are ready to test the MLXChatExample app. Let’s open it:
|
105 |
+
|
106 |
+
<figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
|
107 |
+
<div>
|
108 |
+
<img src="images/3_1.png" alt="App launch 1" />
|
109 |
+
</div>
|
110 |
+
<div>
|
111 |
+
<img src="images/3_2.png" alt="App launch 2" />
|
112 |
+
</div>
|
113 |
+
</figure>
|
114 |
+
<figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
|
115 |
+
Figure 3. Left: MLXChatExample app default screen | Right: models dropdown menu
|
116 |
+
</figcaption>
|
117 |
+
|
118 |
+
By default, the app loads with the **llama3.2:1b** model. You can select different models from the dropdown menu in the upper right corner, see **Figure 3**. Let’s choose **qwen3:4b**, which runs successfully on an iPhone 14 Pro, and type our first query.
|
119 |
+
|
120 |
+
If this is the first run for a given model, we have to wait for the model to download from Hugging Face to our cache. Below, in **Figure 4**, we can see our first query and how it triggers the model download (the arrow on top), and when clicking on the arrow we can see the download progress.
|
121 |
+
|
122 |
+
<figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
|
123 |
+
<div>
|
124 |
+
<img src="images/4_1.png" alt="Init query" />
|
125 |
+
</div>
|
126 |
+
<div>
|
127 |
+
<img src="images/4_2.png" alt="Download model" />
|
128 |
+
</div>
|
129 |
+
</figure>
|
130 |
+
<figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
|
131 |
+
Figure 4. Left: first query for a model triggering model download | Right: model downloading progress
|
132 |
+
</figcaption>
|
133 |
+
|
134 |
+
The Qwen3 models support both thinking and non-thinking modes. At the moment the app doesn’t provide a toggle to switch between the two modes manually. The default behaviour is thinking mode, with the thought process wrapped between the `<think></think>` tags; to disable thinking, we have to add **/no_think** at the end of our prompt, in which case the think tags appear with nothing between them. Refer to **Figure 5** below.
|
135 |
+
|
136 |
+
<figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
|
137 |
+
<div>
|
138 |
+
<img src="images/5_1.png" alt="Gen 1" />
|
139 |
+
</div>
|
140 |
+
<div>
|
141 |
+
<img src="images/5_2.png" alt="Gen 2" />
|
142 |
+
</div>
|
143 |
+
</figure>
|
144 |
+
<figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
|
145 |
+
Figure 5. Left: query with default thinking behaviour | Right: query with disabled thinking
|
146 |
+
</figcaption>
|
147 |
+
|
148 |
+
# The Benefits of On-Device AI
|
149 |
+
|
150 |
+
Running models like Qwen3 directly on your device offers several significant advantages:
|
151 |
+
|
152 |
+
- **Complete Privacy:** Your conversations never leave your device. No server logs, no data collection, just you and your AI.
|
153 |
+
- **Works Offline:** You have all the world’s knowledge in your pocket wherever you are, without depending on internet connection availability. Need AI assistance while traveling, hiking, or in areas with poor connectivity? Local models work anywhere, anytime.
|
154 |
+
- **No Subscription Fees:** Once you’ve set up the model, there are no ongoing costs.
|
155 |
+
- **No Rate Limits:** Chat as much as you want without hitting quotas.
|
156 |
+
- **Reduced Environmental Impact:** On-device inference eliminates the carbon footprint associated with massive data centers processing your requests.
|
157 |
+
|
158 |
+
# Conclusion
|
159 |
+
|
160 |
+
While cloud-based solutions may offer larger models with excellent performance, the gap is closing rapidly as efficiency improvements make smaller models surprisingly capable. The Qwen3 models are an excellent example of this and living proof of the rapid advancements in the field of AI. The MLX Swift framework from Apple makes powerful machine learning accessible to everyday users on their personal devices. This shift from cloud-dependent to local-first AI is just beginning.
|
161 |
+
|
162 |
+
***Own your AI. Own your data. Go local.***
|
src/posts/2025-05-23-app-docker-fastapi/index.qmd
ADDED
@@ -0,0 +1,546 @@
1 |
+
---
|
2 |
+
title: "Image Segmentation with PaliGemma 2 mix, Transformers, Docker, FastAPI, and GitHub Actions"
|
3 |
+
author: "Joana Levtcheva"
|
4 |
+
date: "2025-05-23"
|
5 |
+
categories: [Machine Learning, FastAPI, Docker, LLM, GitHub-Actions]
|
6 |
+
draft: false
|
7 |
+
---
|
8 |
+
|
9 |
+
In today’s fast-paced machine learning landscape, deploying AI models is just as important as developing them. In this blog post, we are going to walk through an image segmentation application using Google’s **PaliGemma 2 Mix** model and **transformers**, containerized with **Docker**, and served through a **FastAPI** backend. We are also going to discuss the CI/CD pipeline using **GitHub Actions** to automate building the Docker image and pushing it to Docker Hub. Let’s explore this service, why we chose these technologies, and how you can get started and use the service yourself!
|
10 |
+
|
11 |
+
The complete code is available on [GitHub](https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation).
|
12 |
+
|
13 |
+
# What is This Project All About?
|
14 |
+
|
15 |
+
At its core, this project provides a **FastAPI service** that allows you to perform image segmentation using natural language. You simply provide the following inputs to the REST API:
|
16 |
+
|
17 |
+
- A **text prompt** describing what to segment
|
18 |
+
- An **image** via URL or file upload
|
19 |
+
- The specific **PaliGemma 2 model** to perform image segmentation
|
20 |
+
|
21 |
+
The service then returns:
|
22 |
+
|
23 |
+
- The base64 encoded **model input image**
|
24 |
+
- The base64 encoded segmentation **masks** clearly outlining the desired objects
|
25 |
+
- The **bounding box coordinates** for each segmented object
|
26 |
+
|
27 |
+
The **FastAPI** application is also containerized with **Docker** for consistent deployment across environments. A CI/CD pipeline with **GitHub Actions** is created for automated container builds and registry publishing to Docker Hub.
|
28 |
+
|
29 |
+
# Architectural Blueprint: How It All Works Together
|
30 |
+
|
31 |
+
Understanding the flow of data and the interaction of components is key. Let’s first take a look at our project structure:
|
32 |
+
|
33 |
+
```plaintext
|
34 |
+
project_folder/
|
35 |
+
├── app/
|
36 |
+
│ ├── __init__.py
|
37 |
+
│ ├── main.py # FastAPI application and endpoints
|
38 |
+
│ └── segmentation.py # Image segmentation logic
|
39 |
+
├── models/
|
40 |
+
│ ├── huggingface/ # Cache directory for Hugging Face models
|
41 |
+
│ └── vae-oid.npz # VAE model for mask generation
|
42 |
+
├── .dockerignore
|
43 |
+
├── .github/
|
44 |
+
│ └── workflows/ # GitHub Actions for Docker build and push
|
45 |
+
│ └── docker-build.yml # Workflow to build and push Docker images
|
46 |
+
├── .gitignore
|
47 |
+
├── docker-compose.yml
|
48 |
+
├── Dockerfile
|
49 |
+
├── README.md
|
50 |
+
└── requirements.txt
|
51 |
+
```
|
52 |
+
|
53 |
+
## User & Developer Workflow
|
54 |
+
|
55 |
+
Our system is designed with both the end-user and the developer in mind.
|
56 |
+
|
57 |
+
```{mermaid}
|
58 |
+
graph LR
|
59 |
+
User([User]) -->|Provides Image & Prompt| ClientApp[Client Application]
|
60 |
+
ClientApp -->|POST Request| FastAPI_Service[FastAPI Service]
|
61 |
+
FastAPI_Service -->|Process Input| PaliGemma_Model[PaliGemma Model]
|
62 |
+
PaliGemma_Model -->|Generate Segmentation Tokens| VAE_Model[VAE Model]
|
63 |
+
VAE_Model -->|Decode Masks| FastAPI_Service
|
64 |
+
FastAPI_Service -->|JSON Response | ClientApp
|
65 |
+
ClientApp -->|Display Results| User
|
66 |
+
|
67 |
+
Developer([Developer]) -->|Push Code| GitHubRepo[GitHub Repository]
|
68 |
+
GitHubRepo -->|Trigger| GitHubActions[GitHub Actions]
|
69 |
+
GitHubActions -->|Build & Push Image| DockerRegistry[Docker Hub]
|
70 |
+
DockerRegistry -->|Pull Image| DeploymentEnv[Deployment Environment]
|
71 |
+
DeploymentEnv -.->|Runs| FastAPI_Service
|
72 |
+
```
|
73 |
+
|
74 |
+
<figcaption style="text-align: center">Figure 1. User & Developer Workflow</figcaption>
|
75 |
+
|
76 |
+
**Figure 1** shows the workflow:
|
77 |
+
|
78 |
+
- A **User** interacts with a client application, providing an image and a text prompt, and optionally the specific PaliGemma 2 model.
|
79 |
+
- The **Client Application** sends an HTTP POST request to our **FastAPI Service**.
|
80 |
+
- The **FastAPI Service** preprocesses the input and feeds it to the **PaliGemma Model**.
|
81 |
+
- PaliGemma generates segmentation tokens, which are then passed to the **VAE Model**.
|
82 |
+
- The VAE Model decodes these into pixel-level masks and, along with bounding boxes, sends them back to the API.
|
83 |
+
- The API returns a JSON response to the client.
|
84 |
+
|
85 |
+
To visualize the precise sequence of operations when a user requests an image segmentation, the following diagram details the interactions between the core components:
|
86 |
+
|
87 |
+
```{mermaid}
|
88 |
+
sequenceDiagram
|
89 |
+
participant User
|
90 |
+
participant Client
|
91 |
+
participant FastAPI
|
92 |
+
participant SegmentationPy as segmentation.py
|
93 |
+
participant PaliGemma as PaliGemma Model
|
94 |
+
participant VAE as VAE Model
|
95 |
+
|
96 |
+
User->>+Client: Upload image & prompt
|
97 |
+
Client->>+FastAPI: POST /segment
|
98 |
+
FastAPI->>+SegmentationPy: call segment_image()
|
99 |
+
SegmentationPy->>+PaliGemma: infer with PaliGemma
|
100 |
+
PaliGemma-->>-SegmentationPy: (tokens/features)
|
101 |
+
SegmentationPy->>+VAE: generate masks
|
102 |
+
VAE-->>-SegmentationPy: (pixel masks)
|
103 |
+
SegmentationPy-->>-FastAPI: return mask & coords
|
104 |
+
FastAPI-->>-Client: JSON response
|
105 |
+
Client-->>-User: display results
|
106 |
+
```
|
107 |
+
<figcaption style="text-align: center">Figure 2. Segmentation process</figcaption>
|
108 |
+
|
109 |
+
This sequence highlights how FastAPI acts as the entry point, delegating the complex segmentation logic to the `segmentation.py` module, which in turn leverages the PaliGemma and VAE models to produce the desired output.
|
110 |
+
|
111 |
+
For **developers**, pushing code to GitHub triggers **GitHub Actions**, which automatically builds a Docker image and pushes it to a **Container Registry** (Docker Hub), ready for deployment.
|
112 |
+
|
113 |
+
## Inside the Application
|
114 |
+
|
115 |
+
Within the Docker container, the application is neatly structured:
|
116 |
+
|
117 |
+
```{mermaid}
|
118 |
+
graph TD
|
119 |
+
subgraph "Docker Container"
|
120 |
+
subgraph "app/"
|
121 |
+
main[main.py
|
122 |
+
FastAPI Application]
|
123 |
+
segmentation[segmentation.py
|
124 |
+
Image Segmentation Logic]
|
125 |
+
main -->|imports| segmentation
|
126 |
+
end
|
127 |
+
|
128 |
+
subgraph "External Dependencies"
|
129 |
+
NP[numpy]
|
130 |
+
TR[transformers]
|
131 |
+
PT[PyTorch]
|
132 |
+
JF[JAX/Flax]
|
133 |
+
end
|
134 |
+
|
135 |
+
subgraph "Models"
|
136 |
+
PaliGemma[PaliGemma 2 mix]
|
137 |
+
VAE[VAE Checkpoint]
|
138 |
+
end
|
139 |
+
|
140 |
+
segmentation -->|uses| TR
|
141 |
+
segmentation -->|uses| PT
|
142 |
+
segmentation -->|uses| JF
|
143 |
+
segmentation -->|uses| NP
|
144 |
+
|
145 |
+
main -->|loads| PaliGemma
|
146 |
+
segmentation -->|loads| VAE
|
147 |
+
end
|
148 |
+
|
149 |
+
Client[Client Application] -->|HTTP Requests| main
|
150 |
+
|
151 |
+
subgraph "API Endpoints"
|
152 |
+
segment[POST /segment/]
|
153 |
+
root[GET /]
|
154 |
+
end
|
155 |
+
|
156 |
+
main -->|defines| segment
|
157 |
+
main -->|defines| root
|
158 |
+
Client -->|calls| segment
|
159 |
+
Client -->|calls| root
|
160 |
+
|
161 |
+
style main fill:#c2e0ff,stroke:#0078d7
|
162 |
+
style segmentation fill:#c2e0ff,stroke:#0078d7
|
163 |
+
style Client fill:#ffd7b5,stroke:#ff8c00
|
164 |
+
style segment fill:#d5e8d4,stroke:#82b366
|
165 |
+
style root fill:#d5e8d4,stroke:#82b366
|
166 |
+
```
|
167 |
+
|
168 |
+
<figcaption style="text-align: center">Figure 3. Application Architecture</figcaption>
|
169 |
+
|
170 |
+
- `app/main.py`: This is the heart of our API, built using FastAPI. It defines the API endpoints like `/` (for a welcome message) and `/segment` (for the actual segmentation).
|
171 |
+
- `app/segmentation.py`: This module contains all the core logic for image processing, interacting with the PaliGemma and VAE models, and generating the final masks and coordinates. It uses the libraries **JAX**, **Flax** and **Transformers** for efficient model execution and inference.
|
172 |
+
|
173 |
+
# Technology Stack Overview
|
174 |
+
|
175 |
+
Let’s first understand the key technologies used in the project.
|
176 |
+
|
177 |
+
## PaliGemma 2 mix
|
178 |
+
|
179 |
+
[PaliGemma 2 mix](https://developers.googleblog.com/en/introducing-paligemma-2-mix/) is a state-of-the-art vision-language model from Google that can comprehend both images and text. It represents a significant advancement in multimodal AI, allowing for natural language-guided image understanding. Once PaliGemma identifies the segments (as tokens), a Variational Autoencoder (VAE) model steps in. It decodes these segmentation tokens into the detailed, pixel-level masks that you see as the output. This particular service uses the `vae-oid.npz` VAE checkpoint for this task. A comprehensive overview of the image segmentation process with PaliGemma 2 mix can be found in one of my previous posts [here](https://joejoe1313.github.io/2025-04-15-paligemma-2-mix.html).
|
180 |
+
|
181 |
+
## FastAPI
|
182 |
+
|
183 |
+
[FastAPI](https://fastapi.tiangolo.com) is a web framework for building APIs with Python. We chose this framework for several compelling reasons:
|
184 |
+
|
185 |
+
- Automatic API documentation with **Swagger UI** and ReDoc
|
186 |
+
- Data validation and serialization through **Pydantic** models
|
187 |
+
- Asynchronous request handling with support for `async`/`await` syntax
|
188 |
+
- Excellent performance comparable to Node.js and Go
|
189 |
+
- Built-in dependency injection system
|
190 |
+
- Security and authentication features out of the box
|
191 |
+
- WebSocket support and background tasks
|
192 |
+
|
193 |
+
In our project, `main.py` uses FastAPI to define the `/segment` endpoint which accepts form data including the prompt, and optionally an image URL or an uploaded image file, and the desired `model_id`.
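
To make the endpoint shape concrete, here is a minimal sketch of what such a route could look like. The form parameters match the API description above, but the `segment_image` helper and its exact signature are assumptions for illustration, not the project’s actual implementation:

```python
from typing import Optional

from fastapi import FastAPI, File, Form, UploadFile

app = FastAPI()


@app.post("/segment")
async def segment(
    prompt: str = Form(...),  # text description of what to segment
    image_url: Optional[str] = Form(None),  # optional image URL
    image_file: Optional[UploadFile] = File(None),  # optional uploaded image
    model_id: Optional[str] = Form(None),  # optional PaliGemma variant override
):
    # Hypothetical helper living in app/segmentation.py: it would load the image,
    # run PaliGemma, decode the masks with the VAE, and return base64 data + boxes.
    from app.segmentation import segment_image  # assumed import

    image_bytes = await image_file.read() if image_file else None
    return segment_image(
        prompt, image_url=image_url, image_bytes=image_bytes, model_id=model_id
    )
```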
|
194 |
+
|
195 |
+
## Transformers
|
196 |
+
|
197 |
+
**Transformers** is a library by **Hugging Face** that provides state-of-the-art pre-trained models for natural language processing and computer vision. For our project, it’s essential because:
|
198 |
+
|
199 |
+
- It provides easy access to the PaliGemma models
|
200 |
+
- Offers a unified API for loading and using different model architectures
|
201 |
+
- Handles model preprocessing and tokenization
|
202 |
+
- Supports efficient model inference
|
203 |
+
- Enables fine-tuning of models if needed
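
As a rough sketch of what this access looks like in practice (the exact loading code in `segmentation.py` may differ, and the checkpoint is gated, so a Hugging Face token is required), loading the default PaliGemma 2 mix checkpoint with Transformers and running a segmentation prompt could look like this:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-mix-448"  # default model used by this service
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Segmentation prompts follow the "segment <object>" convention used throughout this post.
inputs = processor(text="segment left wheel", images=image, return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=64)

# Drop the prompt tokens and keep only the generated <loc...>/<seg...> sequence.
output_ids = generated[0][inputs["input_ids"].shape[-1]:]
print(processor.decode(output_ids, skip_special_tokens=True))
```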
|
204 |
+
|
205 |
+
## JAX/Flax
|
206 |
+
|
207 |
+
JAX is a high-performance numerical computing library developed by Google. Flax is a neural network library built on top of JAX. Together, they provide:
|
208 |
+
|
209 |
+
- Accelerated computation on GPUs and TPUs
|
210 |
+
- Just-in-time compilation for optimized performance
|
211 |
+
- Automatic differentiation capabilities
|
212 |
+
- Functional programming approach to machine learning
|
213 |
+
|
214 |
+
In our app, **JAX/Flax and Transformers** are used for scalable model execution and inference; JAX/Flax specifically powers the VAE model, which decodes segmentation tokens into pixel-level masks.
|
215 |
+
|
216 |
+
## Docker & Docker Compose
|
217 |
+
|
218 |
+
### What is Docker?
|
219 |
+
|
220 |
+
Docker is a platform that uses OS-level virtualization to deliver software in packages called **containers**. A container is a standalone, executable package of software that includes everything needed to run an application: code, runtime, system tools, system libraries, and settings.
|
221 |
+
|
222 |
+
### Why Docker for This Project?
|
223 |
+
|
224 |
+
- **Consistency**: _“It works on my machine”_ is a phrase Docker aims to eliminate. By packaging the application together with its dependencies (like specific versions of PyTorch, JAX, and Transformers), we ensure it runs identically everywhere, from a developer’s laptop to a production server.
|
225 |
+
- **Isolation**: Containers run in isolation, preventing conflicts between different applications or dependencies on the same host system.
|
226 |
+
- **Simplified Deployment**: Docker abstracts away the underlying infrastructure. With a simple `docker-compose up` command, anyone can get the service running without manually installing Python, various libraries, or configuring complex environments.
|
227 |
+
- **Scalability**: Dockerized applications are inherently easier to scale. Orchestration tools like Kubernetes can manage multiple instances of our container to handle increased load.
|
228 |
+
- **Multi-Architecture Support**: Our Docker setup supports both `amd64` (common for desktops and servers) and `arm64` (increasingly used in cloud instances and devices like Raspberry Pi) architectures, broadening its usability.
|
229 |
+
|
230 |
+
### Key Docker Components Used:
|
231 |
+
|
232 |
+
- **`Dockerfile`**
|
233 |
+
|
234 |
+
This is the **recipe for building our Docker image**. It specifies the base image (e.g., a Python image), copies our application code (`app/` directory, `requirements.txt`, etc.) into the image, installs all necessary Python packages, and defines the command to run when the container starts (e.g., `uvicorn app.main:app`).
|
235 |
+
|
236 |
+
```Dockerfile
|
237 |
+
FROM python:3.11-slim
|
238 |
+
|
239 |
+
WORKDIR /app
|
240 |
+
|
241 |
+
COPY requirements.txt .
|
242 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
243 |
+
|
244 |
+
COPY . .
|
245 |
+
|
246 |
+
ENV MODEL_ID=google/paligemma2-3b-mix-448
|
247 |
+
ENV MODELS_DIR=/app/models
|
248 |
+
|
249 |
+
EXPOSE 8000
|
250 |
+
|
251 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
252 |
+
```
|
253 |
+
|
254 |
+
- **`docker-compose.yml`**
|
255 |
+
|
256 |
+
While a `Dockerfile` builds a single image, Docker Compose is used to define and run multi-container Docker applications. In our case, it simplifies running our FastAPI service. It can also manage networks, volumes, and other aspects of the application stack. For this project, it handles setting up the service and importantly, the volume mounting for model persistence.
|
257 |
+
|
258 |
+
```yaml
|
259 |
+
services:
|
260 |
+
paligemma-api:
|
261 |
+
image: joejoe1313/paligemma-image-segmentation:latest
|
262 |
+
ports:
|
263 |
+
- "8000:8000"
|
264 |
+
environment:
|
265 |
+
- MODEL_ID=google/paligemma2-3b-mix-448
|
266 |
+
- MODELS_DIR=/app/models
|
267 |
+
secrets:
|
268 |
+
- hf_token
|
269 |
+
volumes:
|
270 |
+
- $HOME/.cache/huggingface/hub:/app/models/huggingface
|
271 |
+
restart: unless-stopped
|
272 |
+
|
273 |
+
secrets:
|
274 |
+
hf_token:
|
275 |
+
file: $HOME/.cache/huggingface/token
|
276 |
+
```
|
277 |
+
|
278 |
+
- **`.dockerignore`:**
|
279 |
+
|
280 |
+
Similar to `.gitignore`, this file lists files and directories that should not be copied into the Docker image during the build process (e.g., local virtual environments, `.git` directory). This keeps the image lean and build times faster.
|
281 |
+
|
282 |
+
### Smart Model Management with Volume Mounting
|
283 |
+
|
284 |
+
Models, especially large ones like PaliGemma, take time to download. We use Docker’s **volume mounting** feature to map a directory on the host machine to a directory inside the container. This means:
|
285 |
+
|
286 |
+
- Models are downloaded once and **persisted** on your host machine.
|
287 |
+
- They are **reused** across container restarts.
|
288 |
+
- They can be **shared** if you run multiple instances.
|
289 |
+
- Updating models becomes easier as you can manage them directly on your host.
|
290 |
+
|
291 |
+
The default volume mapping in our `docker-compose.yml` file is `$HOME/.cache/huggingface/hub:/app/models/huggingface`, which maps the local `$HOME/.cache/huggingface/hub` directory to `/app/models/huggingface` in the container.
|
292 |
+
|
293 |
+
## CI/CD with GitHub Actions
|
294 |
+
|
295 |
+
GitHub Actions provides automated workflows for continuous integration and continuous delivery (CI/CD), which is crucial for modern software development, so we have implemented a CI/CD pipeline using **GitHub Actions**. Our workflow `.github/workflows/docker-build.yml` is shown below:
|
296 |
+
|
297 |
+
```yaml
|
298 |
+
name: Docker Build and Push
|
299 |
+
|
300 |
+
on:
|
301 |
+
push:
|
302 |
+
branches: main
|
303 |
+
pull_request:
|
304 |
+
branches: main
|
305 |
+
|
306 |
+
jobs:
|
307 |
+
build:
|
308 |
+
runs-on: ubuntu-latest
|
309 |
+
steps:
|
310 |
+
- name: Login to Docker Hub
|
311 |
+
uses: docker/login-action@v3
|
312 |
+
with:
|
313 |
+
username: ${{ vars.DOCKERHUB_USERNAME }}
|
314 |
+
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
315 |
+
|
316 |
+
- name: Set up QEMU
|
317 |
+
uses: docker/setup-qemu-action@v3
|
318 |
+
|
319 |
+
- name: Set up Docker Buildx
|
320 |
+
uses: docker/setup-buildx-action@v3
|
321 |
+
|
322 |
+
- name: Build and push
|
323 |
+
uses: docker/build-push-action@v6
|
324 |
+
with:
|
325 |
+
push: ${{ github.event_name != 'pull_request' }}
|
326 |
+
platforms: linux/amd64,linux/arm64
|
327 |
+
tags: ${{ vars.DOCKERHUB_USERNAME }}/paligemma-image-segmentation:latest
|
328 |
+
```
|
329 |
+
|
330 |
+
**What does it do?**
|
331 |
+
|
332 |
+
- **Trigger**: Every time code is pushed to the `main` branch (or a pull request is made to `main`), the GitHub Actions workflow is automatically triggered.
|
333 |
+
- **Build**: The workflow checks out the code and builds a multi-architecture (`amd64`/`arm64`) Docker image using the `Dockerfile`.
|
334 |
+
- **Push**: The newly built image is then pushed to a container registry (like Docker Hub).
|
335 |
+
- **Tag**: The image is tagged with `latest` (for convenience); tagging with the commit SHA as well would add per-build traceability.
|
336 |
+
|
337 |
+
**Figure 4** below shows a high level overview of the workflow:
|
338 |
+
|
339 |
+
```{mermaid}
|
340 |
+
sequenceDiagram
|
341 |
+
participant G as GitHub Repo
|
342 |
+
participant A as GitHub Actions
|
343 |
+
participant D as Docker Registry
|
344 |
+
|
345 |
+
G->>A: on: push / pull_request to main
|
346 |
+
activate A
|
347 |
+
A-->>A: Login to Docker Registry
|
348 |
+
A-->>A: Set up QEMU (for multi-arch)
|
349 |
+
A-->>A: Set up Docker Buildx
|
350 |
+
A-->>A: Build Docker image (multi-arch)
|
351 |
+
A-->>D: Push image (tagged with SHA & latest)
|
352 |
+
deactivate A
|
353 |
+
```
|
354 |
+
|
355 |
+
<figcaption style="text-align: center">Figure 4. GitHub Actions workflow</figcaption>
|
356 |
+
|
357 |
+
**How to use it in your fork**
|
358 |
+
|
359 |
+
If you fork this repository, you can set up your own CI/CD pipeline by adding the following to your GitHub repository settings (the workflow reads the username from repository variables via `vars` and the token from secrets):
|
360 |
+
|
361 |
+
- `DOCKERHUB_USERNAME`: Your Docker Hub username.
|
362 |
+
- `DOCKERHUB_TOKEN`: Your Docker Hub access token (not your password!).
|
363 |
+
- You should also update the image in `docker-compose.yml` with your username: `{DOCKERHUB_USERNAME}/paligemma-image-segmentation:latest`
|
364 |
+
|
365 |
+
This automated pipeline ensures that a fresh, deployable image is always available.
|
366 |
+
|
367 |
+
# Getting Started: Installation and Setup
|
368 |
+
|
369 |
+
Ready to try it yourself? Here’s how:
|
370 |
+
|
371 |
+
## Prerequisites
|
372 |
+
|
373 |
+
- **Docker**: Ensure Docker Desktop or Docker Engine is installed and running.
|
374 |
+
- **Hugging Face Token**: To access gated models like PaliGemma, you’ll need a Hugging Face token. Make sure it’s stored at `$HOME/.cache/huggingface/token` on your host machine, or adjust the path accordingly. The application will use this token when downloading models.
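
If you have not stored a token yet, one way to create that file (a sketch assuming the `huggingface_hub` package is installed locally) is to log in once from Python, which saves the token to the default cache location:

```python
from huggingface_hub import login

# Prompts for your Hugging Face token and stores it under ~/.cache/huggingface/token
login()
```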
|
375 |
+
|
376 |
+
## Setup with Docker Compose:
|
377 |
+
|
378 |
+
- Clone the repository:
|
379 |
+
|
380 |
+
```bash
|
381 |
+
git clone https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation.git
|
382 |
+
cd PaliGemma-Image-Segmentation
|
383 |
+
```
|
384 |
+
|
385 |
+
- Ensure your **Hugging Face token** is in place as mentioned above. The `docker-compose.yml` is set up to mount your local Hugging Face cache, including the token.
|
386 |
+
- Run the application:
|
387 |
+
|
388 |
+
```bash
|
389 |
+
docker-compose up -d
|
390 |
+
```
|
391 |
+
|
392 |
+
> **Warning**: It’s generally good practice to avoid running Docker commands as the root user unless necessary. Docker Compose should typically be run by a user in the `docker` group.
|
393 |
+
|
394 |
+
This command will pull the pre-built Docker image and start the FastAPI service on `http://localhost:8000`. The `-d` flag runs it in detached mode.
|
395 |
+
|
396 |
+
## Choosing Your PaliGemma Model
|
397 |
+
|
398 |
+
You have two ways to specify which PaliGemma model variant to use:
|
399 |
+
|
400 |
+
- At **runtime via API**: Pass the `model_id` parameter in your POST request to the `/segment` endpoint.
|
401 |
+
- Via **Docker Environment Variable**: Set the `MODEL_ID` environment variable when starting the container.
|
402 |
+
|
403 |
+
> **Note**: If both are set, the `model_id` parameter in the API request takes precedence.
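
For example, overriding the model at request time is just another form field. The sketch below reuses the service default model ID; any other PaliGemma 2 mix checkpoint you have access to should work the same way:

```python
import requests

data = {
    "prompt": "segment left wheel",
    "image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg",
    "model_id": "google/paligemma2-3b-mix-448",  # runtime override, takes precedence over MODEL_ID
}

response = requests.post("http://localhost:8000/segment", data=data)
print(response.status_code)
```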
|
404 |
+
|
405 |
+
If a specified model isn’t found in the local cache (`/app/models/huggingface` inside the container, which maps to your `$HOME/.cache/huggingface/hub`), the application will attempt to download it from Hugging Face. **Figure 5** below shows a comprehensive overview of the possible steps.
|
406 |
+
|
407 |
+
```{mermaid}
|
408 |
+
flowchart TD
|
409 |
+
A[API Request] --> B{Check Local Cache}
|
410 |
+
B -->|Found| H[Load from Local Cache]
|
411 |
+
B -->|Not Found| C{Has HF Token?}
|
412 |
+
|
413 |
+
%% Style definitions
|
414 |
+
classDef process fill:#e0e0ff,stroke:#9999ff,color:black
|
415 |
+
classDef decision1 fill:#ffe0b0,stroke:#ffbb66,color:black
|
416 |
+
classDef decision2 fill:#d0f0d0,stroke:#aaddaa,color:black
|
417 |
+
classDef cache fill:#d0e0ff,stroke:#aabbee,color:black
|
418 |
+
|
419 |
+
%% Apply styles
|
420 |
+
class A,D,F,I,E,Z process
|
421 |
+
class B decision1
|
422 |
+
class C decision2
|
423 |
+
class G,H cache
|
424 |
+
|
425 |
+
C -->|Yes| D[Authenticate with HF]
|
426 |
+
C -->|No| E[Try Loading Public Model]
|
427 |
+
D --> F[Download Model]
|
428 |
+
F --> G[Save to Cache]
|
429 |
+
E -->|Success| G
|
430 |
+
E -->|Failure| Z[Auth Error]
|
431 |
+
G --> H
|
432 |
+
H --> I[Use Model]
|
433 |
+
```
|
434 |
+
|
435 |
+
<figcaption style="text-align: center">Figure 5. Model download scheme</figcaption>
|
436 |
+
|
437 |
+
# Examples: Putting the API to the Test
|
438 |
+
|
439 |
+
Once the service is running (default: `http://localhost:8000`):
|
440 |
+
|
441 |
+
- **API Docs**: Visit `http://localhost:8000/docs` for interactive API documentation.
|
442 |
+
|
443 |
+
## Health Check (GET /)
|
444 |
+
|
445 |
+
Verify the API is up and running.
|
446 |
+
|
447 |
+
**Request:**
|
448 |
+
|
449 |
+
```python
|
450 |
+
import requests
|
451 |
+
|
452 |
+
response = requests.get("http://localhost:8000/")
|
453 |
+
print(response.json())
|
454 |
+
```
|
455 |
+
|
456 |
+
**Response:**
|
457 |
+
|
458 |
+
```json
|
459 |
+
{
|
460 |
+
"message": "Welcome to the PaliGemma Segmentation API!"
|
461 |
+
}
|
462 |
+
```
|
463 |
+
|
464 |
+
## Segmenting an Image (POST /segment/)
|
465 |
+
|
466 |
+
Form Parameters:
|
467 |
+
|
468 |
+
- `prompt` (str, required): Text description of objects to segment (e.g., "segment the red car").
|
469 |
+
- `image_url` (str, optional): URL of the image to segment.
|
470 |
+
- `image_file` (UploadFile, optional): Uploaded image file to segment.
|
471 |
+
- `model_id` (str, optional): Specific PaliGemma model ID to use.
|
472 |
+
|
473 |
+
### Using an Image URL
|
474 |
+
|
475 |
+
```{mermaid}
|
476 |
+
sequenceDiagram
|
477 |
+
participant C as Client
|
478 |
+
participant S as /segment Endpoint
|
479 |
+
C->>S: POST Request with JSON body: { "image_url": "your_image_url.jpg", "prompt": "object to segment" }
|
480 |
+
S-->>S: Download Image
|
481 |
+
S-->>S: Process with PaliGemma & VAE
|
482 |
+
S-->>C: JSON Response: { "image": "base64_input_image", "masks": [ { "mask": "base64_mask_data", "coordinates": [x_min,y_min,x_max,y_max] } ] }
|
483 |
+
```
|
484 |
+
<figcaption style="text-align: center">Figure 6. Segment request</figcaption>
|
485 |
+
|
486 |
+
**Request:**
|
487 |
+
|
488 |
+
```python
|
489 |
+
import requests
|
490 |
+
|
491 |
+
data = {
|
492 |
+
"prompt": "segment left wheel",
|
493 |
+
"image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
|
494 |
+
}
|
495 |
+
|
496 |
+
response = requests.post("http://localhost:8000/segment", data=data)
|
497 |
+
print(response.json())
|
498 |
+
```
|
499 |
+
|
500 |
+
### Uploading an Image File
|
501 |
+
|
502 |
+
**Request:**
|
503 |
+
|
504 |
+
```python
|
505 |
+
import os
|
506 |
+
import requests
|
507 |
+
|
508 |
+
segm_url = "http://localhost:8000/segment"
|
509 |
+
image_path = "your_image.png"
|
510 |
+
|
511 |
+
with open(image_path, "rb") as image_file:
|
512 |
+
data = {
|
513 |
+
"prompt": "segment the main object" # Adjust prompt as needed
|
514 |
+
}
|
515 |
+
files = {
|
516 |
+
"image_file": (os.path.basename(image_path), image_file, "image/png") # Or image/jpeg
|
517 |
+
}
|
518 |
+
|
519 |
+
response = requests.post(segm_url, files=files, data=data)
|
520 |
+
print(response.json())
|
521 |
+
```
|
522 |
+
|
523 |
+
### Expected JSON Response Structure:
|
524 |
+
|
525 |
+
**Response:**
|
526 |
+
|
527 |
+
```python
|
528 |
+
{
|
529 |
+
"image": "base64_encoded_model_input_image_data",
|
530 |
+
"masks": [
|
531 |
+
{
|
532 |
+
"mask": "base64_encoded_mask_data_for_object_1",
|
533 |
+
"coordinates": [x_min, y_min, x_max, y_max]
|
534 |
+
}
|
535 |
+
# ...more masks if multiple instances of the prompt are found
|
536 |
+
]
|
537 |
+
}
|
538 |
+
```
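
Since both the input image and the masks come back base64 encoded, a small helper like the one below turns the response into images you can inspect. This is a sketch using Pillow and assumes the base64 payloads decode to standard image bytes; see the repository’s `example.ipynb` for the exact decoding used there.

```python
import base64
import io

from PIL import Image


def b64_to_image(data: str) -> Image.Image:
    """Decode a base64 string into a PIL image (assumes PNG/JPEG bytes)."""
    return Image.open(io.BytesIO(base64.b64decode(data)))


result = response.json()  # response from one of the /segment requests above
input_image = b64_to_image(result["image"])

for i, entry in enumerate(result["masks"]):
    mask = b64_to_image(entry["mask"])
    print(f"Mask {i}: bounding box {entry['coordinates']}, size {mask.size}")
    mask.save(f"mask_{i}.png")
```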
|
539 |
+
|
540 |
+
See [example.ipynb](https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation/blob/main/example.ipynb) for a demonstration of the segmentation pipeline.
|
541 |
+
|
542 |
+
# Conclusion
|
543 |
+
|
544 |
+
In this guide, we explored an image segmentation API using PaliGemma 2 mix, FastAPI, Transformers, and Docker. This service enables users to segment objects in images using natural language prompts, opening up a wide range of applications in computer vision, image editing, and data analysis. The containerized application provides a flexible, scalable solution that can be easily deployed across various environments. The combination of FastAPI’s performance and Docker’s portability makes this architecture ideal for production ML applications.
|
545 |
+
|
546 |
+
_Happy coding!_
|