JoeJoe1313 commited on
Commit
e9dcf74
·
1 Parent(s): 0361bfb
src/posts/2025-02-13-qwen2_5-vl-mlx-vlm/index.qmd CHANGED
@@ -1,5 +1,6 @@
1
  ---
2
  title: "Qwen2.5-vl with MLX-VLM"
3
+ author: "Joana Levtcheva"
4
  date: "2025-02-13"
5
  categories: [Machine Learning, mlx, vlm]
6
  draft: false
src/posts/2025-04-06-fine-tuning-function-calling/index.qmd ADDED
@@ -0,0 +1,368 @@
1
+ ---
2
+ title: "Fine-Tuning a Model for Function-Calling with MLX-LM"
3
+ author: "Joana Levtcheva"
4
+ date: "2025-04-06"
5
+ categories: [Machine Learning, mlx, llm]
6
+ draft: false
7
+ ---
8
+
9
+ In this post, we explore the process of fine-tuning a language model for function-calling using [MLX-LM](https://github.com/ml-explore/mlx-lm). Following the Hugging Face Agents course [notebook](https://huggingface.co/agents-course/notebooks/blob/main/bonus-unit1/bonus-unit1.ipynb), we’ll walk through the steps from setting up the environment to training the model with LoRA adapters. The goal is to empower the model with the ability to intelligently plan and generate function calls, making it a versatile tool for interactive applications. The accompanying Medium post can be found [here](https://medium.com/@levchevajoana/fine-tuning-a-model-for-function-calling-with-mlx-lm-d00d587e2559).
10
+
11
+ ## Introduction
12
+
13
+ Modern AI models can do much more than generate plain text — they can now integrate with external tools by “calling” functions based on user intent. In this tutorial, we demonstrate how to adapt a pre-trained model (in our case, the [gemma-2-2b-it-4bit](https://huggingface.co/mlx-community/gemma-2-2b-it-4bit) model from the [MLX Community](https://huggingface.co/mlx-community)) to handle function-calling by using the `mlx-lm` package. This involves creating a specialized chat template, preprocessing a dataset of function call interactions, and applying LoRA for efficient fine-tuning.
14
+
15
+ ## Setting Up the Model and Tokenizer
16
+
17
+ We start by importing the necessary libraries and modules, including the MLX-LM package, dataset utilities, and LoRA functions.
18
+
19
+ ```python
20
+ import json
21
+ import os
22
+ from enum import Enum
23
+ from typing import Dict, List, Tuple, Union
24
+
25
+ import mlx.optimizers as optim
26
+ from datasets import load_dataset
27
+ from mlx.utils import tree_flatten
28
+ from mlx_lm import generate, load
29
+ from mlx_lm.tuner import TrainingArgs, datasets, linear_to_lora_layers, train
30
+ ```
31
+
32
+ After loading our model and tokenizer,
33
+
34
+ ```python
35
+ model_path = "mlx-community/gemma-2-2b-it-4bit"
36
+ model, tokenizer = load(model_path)
37
+ ```
38
+
39
+ we customize the tokenizer’s chat template to define the structure of our conversational interactions.
40
+
41
+ ```python
42
+ tokenizer.chat_template = (
43
+ "{{ bos_token }}"
44
+ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
45
+ "{% for message in messages %}"
46
+ "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] | trim + '<end_of_turn><eos>\n' }}"
47
+ "{% endfor %}"
48
+ "{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
49
+ )
50
+ ```
51
+
52
+ This template wraps each turn in Gemma’s special tokens (`<bos>`, `<start_of_turn>`, `<end_of_turn>`, `<eos>`), while the dataset itself supplies markers such as `<think>` and `<tool_call>`. Together they mark the different stages of the conversation - from the user’s prompt to the model’s internal reasoning and eventual function call.
53
+
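+ To see what the template produces, we can render a small dummy conversation with it. This is just a quick sanity check; the role names are arbitrary strings here because the template inserts `message['role']` verbatim:
+
+ ```python
+ example = [
+     {"role": "human", "content": "Hello"},
+     {"role": "model", "content": "Hi there!"},
+ ]
+ # Returns one formatted string with <start_of_turn>/<end_of_turn> markers around each turn
+ print(tokenizer.apply_chat_template(example, tokenize=False, add_generation_prompt=True))
+ ```
+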
54
+ ## Dataset Preparation and Preprocessing
55
+
56
+ We use the dataset [Jofthomas/hermes-function-calling-thinking-V1](https://huggingface.co/datasets/Jofthomas/hermes-function-calling-thinking-V1) which contains conversations involving function calls.
57
+
58
+ ```python
59
+ dataset_path = "Jofthomas/hermes-function-calling-thinking-V1"
60
+ ```
61
+
62
+ Let’s load the dataset.
63
+
64
+ ```python
65
+ dataset = load_dataset(dataset_path)
66
+ dataset
67
+ ```
68
+
69
+ This outputs
70
+
71
+ ```
72
+ DatasetDict({
73
+ train: Dataset({
74
+ features: ['conversations'],
75
+ num_rows: 3570
76
+ })
77
+ })
78
+ ```
79
+
80
+ showing that the dataset originally includes a “conversations” column, and has 3570 rows. We rename this column to “messages” for consistency
81
+
82
+ ```python
83
+ dataset = dataset.rename_column("conversations", "messages")
84
+ dataset
85
+ ```
86
+
87
+ and then apply the following preprocessing function
88
+
89
+ ```python
90
+ def preprocess(sample):
91
+ messages = sample["messages"]
92
+ first_message = messages[0]
93
+
94
+ # Instead of adding a system message, we merge the content into the first user message
95
+ if first_message["role"] == "system":
96
+ system_message_content = first_message["content"]
97
+ # Merge system content with the first user message
98
+ messages[1]["content"] = (
99
+ system_message_content
100
+ + "Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>\n\n"
101
+ + messages[1]["content"]
102
+ )
103
+ # Remove the system message from the conversation
104
+ messages.pop(0)
105
+
106
+ return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
107
+ ```
108
+
109
+ to the dataset
110
+
111
+ ```python
112
+ dataset = dataset.map(preprocess, remove_columns="messages")
113
+ dataset = dataset["train"].train_test_split(0.1)
114
+ dataset
115
+ ```
116
+
117
+ This function merges any system messages into the first user message, ensuring the context is maintained without extra role annotations. This outputs
118
+
119
+ ```text
120
+ DatasetDict({
121
+ train: Dataset({
122
+ features: ['text'],
123
+ num_rows: 3213
124
+ })
125
+ test: Dataset({
126
+ features: ['text'],
127
+ num_rows: 357
128
+ })
129
+ })
130
+ ```
131
+
132
+ showing that we have successfully separated our original dataset into a train set with 3213 records, and a test set with 357 records. Each sample is now a formatted text string ready for fine-tuning. Let’s see one train example
133
+
134
+ ```text
135
+ <bos><start_of_turn>human
136
+ You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'create_todo', 'description': 'Create a new todo item', 'parameters': {'type': 'object', 'properties': {'task': {'type': 'string', 'description': 'The task description'}, 'due_date': {'type': 'string', 'format': 'date', 'description': 'The due date of the task'}, 'priority': {'type': 'integer', 'description': 'The priority of the task (1-5)'}}, 'required': ['task', 'due_date']}}}, {'type': 'function', 'function': {'name': 'calculate_bmi', 'description': 'Calculate the Body Mass Index (BMI)', 'parameters': {'type': 'object', 'properties': {'weight': {'type': 'number', 'description': 'The weight in kilograms'}, 'height': {'type': 'number', 'description': 'The height in meters'}}, 'required': ['weight', 'height']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
137
+ <tool_call>
138
+ {tool_call}
139
+ </tool_call>Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>
140
+
141
+ I need to create a new task.<end_of_turn><eos>
142
+ <start_of_turn>model
143
+ Sure, I can help with that. Could you please provide me with the task description, the due date, and the priority level?<end_of_turn><eos>
144
+ <start_of_turn>human
145
+ The task is to prepare a presentation for the annual meeting. The due date is 2022-09-15 and the priority level is 3.<end_of_turn><eos>
146
+ <start_of_turn>model
147
+ <think>Okay, so I need to figure out the reasoning that goes between the conversation and the next function call. Let me break this down step by step.
148
+
149
+ First, looking at the conversation, the user starts by saying, "I need to create a new task." The model responds by asking for the task description, due date, and priority level. The user then provides all the necessary details: the task is to prepare a presentation for the annual meeting, the due date is 2022-09-15, and the priority is 3.
150
+
151
+ Now, the model's next move should be to call the appropriate function. The available functions are 'create_todo' and 'calculate_bmi'. Since the user is talking about creating a new task, 'create_todo' is the relevant function here.
152
+
153
+ Examining the function's parameters, it requires 'task', 'due_date', and takes 'priority' as optional. The user provided all three, so we can include them in the arguments.
154
+
155
+ Therefore, the model will execute the 'create_todo' function with the provided task details. This makes sense because the conversation is about setting up a new task, and the function is designed for that exact purpose.
156
+ </think><tool_call>
157
+ {'name': 'create_todo', 'arguments': {'task': 'Prepare a presentation for the annual meeting', 'due_date': '2022-09-15', 'priority': 3}}
158
+ </tool_call><end_of_turn><eos>
159
+ <start_of_turn>tool
160
+ <tool_response>
161
+ {'status': 'success', 'message': 'Todo item successfully created', 'todo_id': '12345'}
162
+ </tool_response><end_of_turn><eos>
163
+ <start_of_turn>model
164
+ Your task has been successfully created. The ID for your new task is 12345.<end_of_turn><eos>
165
+ ```
166
+
167
+ ## Training Setup with LoRA Adapters
168
+
169
+ To efficiently fine-tune the model without retraining all of its parameters, we leverage LoRA. First, we create a directory to store adapter configurations and weights.
170
+
171
+ ```python
172
+ adapter_path = "adapters_fc"
173
+ os.makedirs(adapter_path, exist_ok=True)
174
+ adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
175
+ adapter_file_path = os.path.join(adapter_path, "adapters.safetensors")
176
+ ```
177
+
178
+ Then we define our LoRA configuration, applying adapters to 8 layers with a rank of 16, a scale of 64, and a dropout of 0.05,
179
+
180
+ ```python
181
+ lora_config = {
182
+ "num_layers": 8,
183
+ "lora_parameters": {
184
+ "rank": 16,
185
+ "scale": 64,
186
+ "dropout": 0.05,
187
+ },
188
+ }
189
+ ```
190
+
191
+ and save it as a JSON file.
192
+
193
+ ```python
194
+ with open(adapter_config_path, "w") as f:
195
+ json.dump(lora_config, f, indent=4)
196
+ ```
197
+
198
+ Next, we define the training arguments, specifically setting a single iteration,
199
+
200
+ ```python
201
+ training_args = TrainingArgs(
202
+ adapter_file=adapter_file_path,
203
+ iters=1,
204
+ steps_per_eval=50,
205
+ )
206
+ ```
207
+
208
+ and freeze the original model parameters.
209
+
210
+ ```python
211
+ _ = model.freeze()
212
+ ```
213
+
214
+ Then, we convert selected linear layers to LoRA layers to make only a small subset of parameters trainable.
215
+
216
+ ```python
217
+ linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])
218
+ ```
219
+
220
+ In our example, this results in
221
+
222
+ ```python
223
+ num_train_params = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
224
+ print(f"Number of trainable parameters: {num_train_params}")
225
+ ```
226
+
227
+ 983,040 trainable parameters (a rough sanity check of this count is sketched at the end of this section). Finally, we should not forget to activate training mode while still preserving the frozen state of the main model parameters.
228
+
229
+ ```python
230
+ _ = model.train()
231
+ ```
232
+
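+ As for the parameter count, a back-of-the-envelope check is shown below. It assumes that `linear_to_lora_layers` attaches the adapters to the attention query and value projections (a common default) and uses Gemma-2-2B’s published dimensions; treat it as an illustration of the arithmetic rather than an inspection of the library internals.
+
+ ```python
+ rank, num_layers = 16, 8
+ hidden, head_dim, q_heads, kv_heads = 2304, 256, 8, 4  # assumed Gemma-2-2B dimensions
+
+ q_out, v_out = q_heads * head_dim, kv_heads * head_dim  # 2048 and 1024
+ # Each adapted projection adds an A matrix (rank x in) and a B matrix (out x rank)
+ per_layer = rank * (hidden + q_out) + rank * (hidden + v_out)
+ print(num_layers * per_layer)  # 983040
+ ```
+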
233
+ ## Fine-Tuning Process and Metrics
234
+
235
+ With our model and dataset ready, we configure a metrics tracker to log both training and validation losses,
236
+
237
+ ```python
238
+ class Metrics:
239
+ def __init__(self) -> None:
240
+ self.train_losses: List[Tuple[int, float]] = []
241
+ self.val_losses: List[Tuple[int, float]] = []
242
+
243
+ def on_train_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
244
+ self.train_losses.append((info["iteration"], info["train_loss"]))
245
+
246
+ def on_val_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
247
+ self.val_losses.append((info["iteration"], info["val_loss"]))
248
+ ```
249
+
250
+ and create an instance of this class.
251
+
252
+ ```python
253
+ metrics = Metrics()
254
+ ```
255
+
256
+ We also create datasets in the format mlx-lm expects by first defining the following dataset configuration,
257
+
258
+ ```python
259
+ configs = {
260
+ "mask_prompt": False,
261
+ "prompt_feature": "prompt",
262
+ "text_feature": "text",
263
+ "completion_feature": "completion",
264
+ "chat_feature": "messages",
265
+ }
266
+ ```
267
+
268
+ and then create a train set with the help of the mlx-lm function `datasets.create_dataset`, passing in the configuration from above.
269
+
270
+ ```python
271
+ train_set = datasets.create_dataset(
272
+ dataset["train"],
273
+ tokenizer,
274
+ configs
275
+ )
276
+ ```
277
+
278
+ Similarly, we create our validation set.
279
+
280
+ ```python
281
+ val_set = datasets.create_dataset(
282
+ dataset["test"],
283
+ tokenizer,
284
+ configs
285
+ )
286
+ ```
287
+
288
+ Finally, we start the fine-tuning process by calling the `train()` function.
289
+
290
+ ```python
291
+ train(
292
+ model=model,
293
+ tokenizer=tokenizer,
294
+ args=training_args,
295
+ optimizer=optim.Adam(learning_rate=1e-5),
296
+ train_dataset=train_set,
297
+ val_dataset=val_set,
298
+ training_callback=metrics,
299
+ )
300
+ ```
301
+
302
+ The training logs report both training and validation losses, along with performance metrics like tokens processed per second and memory usage. After training, the adapter weights are saved and can later be reloaded to quickly deploy the fine-tuned model.
303
+
304
+ ```
305
+ Starting training..., iters: 1
306
+ Iter 1: Val loss 1.821, Val took 128.584s
307
+ Iter 1: Train loss 1.861, Learning Rate 1.000e-05, It/sec 0.430, Tokens/sec 160.427, Trained Tokens 3735, Peak mem 20.665 GB
308
+ Saved final weights to adapters_fc/adapters.safetensors.
309
+ ```
310
+
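+ Since we passed our `Metrics` instance as the training callback, the reported losses are also available programmatically, which becomes useful for plotting once you train for more than a single iteration:
+
+ ```python
+ # Lists of (iteration, loss) tuples collected by our callback,
+ # e.g. roughly [(1, 1.861)] and [(1, 1.821)] for the run above
+ print(metrics.train_losses)
+ print(metrics.val_losses)
+ ```
+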
311
+ ## Evaluating the Fine-Tuned Model
312
+
313
+ After training, we reload the model with the newly learned LoRA weights,
314
+
315
+ ```python
316
+ model_lora, _ = load(model_path, adapter_path=adapter_path)
317
+ ```
318
+
319
+ set our prompt to
320
+
321
+ ```python
322
+ prompt="""<bos><start_of_turn>human
323
+ You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'convert_currency', 'description': 'Convert from one currency to another', 'parameters': {'type': 'object', 'properties': {'amount': {'type': 'number', 'description': 'The amount to convert'}, 'from_currency': {'type': 'string', 'description': 'The currency to convert from'}, 'to_currency': {'type': 'string', 'description': 'The currency to convert to'}}, 'required': ['amount', 'from_currency', 'to_currency']}}}, {'type': 'function', 'function': {'name': 'calculate_distance', 'description': 'Calculate the distance between two locations', 'parameters': {'type': 'object', 'properties': {'start_location': {'type': 'string', 'description': 'The starting location'}, 'end_location': {'type': 'string', 'description': 'The ending location'}}, 'required': ['start_location', 'end_location']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
324
+ <tool_call>
325
+ {tool_call}
326
+ </tool_call>Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>
327
+
328
+ Hi, I need to convert 500 USD to Euros. Can you help me with that?<end_of_turn><eos>
329
+ <start_of_turn>model
330
+ <think>"""
331
+ ```
332
+
333
+ and generate a response
334
+
335
+ ```python
336
+ generate(model_lora, tokenizer, prompt=prompt, verbose=True, max_tokens=1000)
337
+ ```
338
+
339
+ which returns
340
+
341
+ ```text
342
+ ==========
343
+
344
+ To convert USD to Euros, I need to use the 'convert_currency' function from the provided tools. I need to provide the amount to convert, the currency to convert from (USD), and the currency to convert to (Euros). I should also make sure the amount is a number.
345
+ </think>
346
+
347
+ <tool_call>
348
+ {
349
+ 'name': 'convert_currency',
350
+ 'arguments': {
351
+ 'amount': 500,
352
+ 'from_currency': 'USD',
353
+ 'to_currency': 'EUR'
354
+ }
355
+ }
356
+ </tool_call>
357
+
358
+ ==========
359
+ Prompt: 460 tokens, 862.170 tokens-per-sec
360
+ Generation: 135 tokens, 68.472 tokens-per-sec
361
+ Peak memory: 20.665 GB
362
+ ```
363
+
364
+ The model first walks through a thought process before generating a function call to convert USD to Euros. This demonstrates the model’s improved ability to generate precise JSON function calls within `<tool_call>` XML tags.
365
+
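+ In an application, the natural next step is to extract the call from the generated text and dispatch it to a real implementation. Below is a minimal, illustrative sketch: the `tools` registry and the `convert_currency` implementation are hypothetical, and `ast.literal_eval` is used because the model emits Python-style single-quoted dicts rather than strict JSON.
+
+ ```python
+ import ast
+ import re
+
+
+ def run_tool_call(generated: str, tools: dict):
+     """Parse the first <tool_call> block and dispatch it to a registered function."""
+     match = re.search(r"<tool_call>\s*(.*?)\s*</tool_call>", generated, re.DOTALL)
+     if match is None:
+         return None
+     call = ast.literal_eval(match.group(1))  # {'name': ..., 'arguments': {...}}
+     return tools[call["name"]](**call["arguments"])
+
+
+ # Hypothetical implementation, purely for illustration
+ tools = {"convert_currency": lambda amount, from_currency, to_currency: amount * 0.92}
+ ```
+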
366
+ ## Conclusion
367
+
368
+ Fine-tuning a model for function-calling can significantly enhance its interactivity and real-world utility. By adapting the chat template, preprocessing the dataset, and applying LoRA adapters, we’ve demonstrated a streamlined approach to training a model that can generate executable function calls with clear reasoning. It is impressive that we achieved this by using only the `mlx-lm` package. Happy fine-tuning!
src/posts/2025-04-15-paligemma-2-mix/images/car_in.png ADDED

Git LFS Details

  • SHA256: 8efe48224397e33fd05310794f94fb8217b4436c4b4a05ee0c3b53fa95c22c87
  • Pointer size: 130 Bytes
  • Size of remote file: 24.6 kB
src/posts/2025-04-15-paligemma-2-mix/images/car_out.png ADDED

Git LFS Details

  • SHA256: 8ff2338db99e4341637ddbacae8385cedf6c22720944c55adef814795c2f53d4
  • Pointer size: 130 Bytes
  • Size of remote file: 40.2 kB
src/posts/2025-04-15-paligemma-2-mix/images/cat_in.png ADDED

Git LFS Details

  • SHA256: dff76fdc3ccf51169c4067d648da865d5a869d1cfbbd0485375f8d1dc8e4a706
  • Pointer size: 130 Bytes
  • Size of remote file: 17.4 kB
src/posts/2025-04-15-paligemma-2-mix/images/cat_out.png ADDED

Git LFS Details

  • SHA256: 8a30dca97b7a1bfe54c4952f9b3f220088c7f4cae07faf76d333b78c0e50e6a2
  • Pointer size: 130 Bytes
  • Size of remote file: 26.1 kB
src/posts/2025-04-15-paligemma-2-mix/images/cow_in.png ADDED

Git LFS Details

  • SHA256: 76390e4a95bbef72a5ad07e8dae5f3afcae8e4cecf641f37e8462236c9ced565
  • Pointer size: 130 Bytes
  • Size of remote file: 21.6 kB
src/posts/2025-04-15-paligemma-2-mix/images/cow_out.png ADDED

Git LFS Details

  • SHA256: caa438c66f931f18aa2c7dc980d35bcc0971a28d10be857b8aab872c391eb911
  • Pointer size: 130 Bytes
  • Size of remote file: 30.9 kB
src/posts/2025-04-15-paligemma-2-mix/images/input_bb.png ADDED

Git LFS Details

  • SHA256: a71c39606e26bd8e3a8e03d69529a5098e13625ab932c7d77d8d7e3f914a1fe2
  • Pointer size: 130 Bytes
  • Size of remote file: 27.8 kB
src/posts/2025-04-15-paligemma-2-mix/images/map_mask.png ADDED

Git LFS Details

  • SHA256: 0d61685c93bab19135cbee308b8f358f5da2d3d0afc6e1604946795f62537fc1
  • Pointer size: 130 Bytes
  • Size of remote file: 28.9 kB
src/posts/2025-04-15-paligemma-2-mix/images/paligemma2-architecture.png ADDED

Git LFS Details

  • SHA256: b0ddde55dab331fb1fc6bb00dd85c8b0baf3a28217afce6bc32b805cef6e2d39
  • Pointer size: 130 Bytes
  • Size of remote file: 11.3 kB
src/posts/2025-04-15-paligemma-2-mix/index.qmd ADDED
@@ -0,0 +1,643 @@
1
+ ---
2
+ title: "Image Segmentation with PaliGemma 2 Mix and MLX"
3
+ author: "Joana Levtcheva"
4
+ date: "2025-04-15"
5
+ categories: [Machine Learning, mlx, vlm]
6
+ draft: false
7
+ ---
8
+
9
+ In this post, we are going to explore Google’s [**PaliGemma 2 mix**](https://developers.googleblog.com/en/introducing-paligemma-2-mix/) vision-language model (VLM) and its ability to perform image segmentation. What’s interesting is that we are going to perform this task using only Apple’s MLX framework and MLX-VLM. This eliminates the dependency on JAX/Flax used in Google’s original segmentation [script](https://github.com/google-research/big_vision/blob/main/big_vision/evaluators/proj/paligemma/transfers/segmentation.py), and allows us to fully and seamlessly utilise Apple’s unified memory. The accompanying Medium post can be found [here](https://medium.com/@levchevajoana/image-segmentation-with-paligemma-2-mix-and-mlx-7e69e077968b).
10
+
11
+ # Introduction
12
+
13
+ ## PaliGemma 2
14
+
15
+ In December 2024 Google introduced the [PaliGemma 2](https://developers.googleblog.com/en/introducing-paligemma-2-powerful-vision-language-models-simple-fine-tuning/) vision-language models (VLMs). These are pre-trained (**pt**) models coming in three different sizes: `3B`, `10B`, and `28B`, as well as three different input resolutions for images: `224x224`, `448x448`, and `896x896` pixels. These models represent the latest evolution of vision-language models developed by Google, building upon the foundation laid by its predecessor, PaliGemma. Below, we can see the architecture of the PaliGemma 2 model.
16
+
17
+ <figure>
18
+ <img src="images/paligemma2-architecture.png" alt="PaliGemma 2 architecture" style="display: block; margin: 0 auto">
19
+ <figcaption style="text-align: center">Figure 1. PaliGemma 2 Architecture Overview <span style="font-size: 0.8em;">[<a href="https://arxiv.org/pdf/2412.03555">Source</a>]</span></figcaption>
20
+ </figure>
21
+
22
+ PaliGemma 2 processes images at resolutions of `224×224`, `448×448`, or `896×896` pixels using a **[SigLIP-400m](https://arxiv.org/abs/2303.15343) vision encoder** with a patch size of 14×14 pixels. This design yields 256, 1024, or 4096 tokens, respectively. After a linear projection, the resulting image tokens are concatenated with the input text tokens, and [**Gemma 2**](https://blog.google/technology/developers/google-gemma-2/) is used as a **text decoder** to autoregressively complete the combined prefix to generate an answer.
23
+
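+ The token counts follow directly from the patch size: the encoder cuts the image into 14×14-pixel patches, and the number of visual tokens is the squared number of patches per side. A quick check:
+
+ ```python
+ for res in (224, 448, 896):
+     patches_per_side = res // 14
+     print(res, patches_per_side**2)  # 224 -> 256, 448 -> 1024, 896 -> 4096
+ ```
+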
24
+ ## PaliGemma 2 Mix
25
+
26
+ As already mentioned, PaliGemma 2 models are pre-trained models, but they are also designed to be easy to fine-tune and adapt to various specific vision-language tasks and domains. Google wanted to demonstrate the performance of a fine-tuned version of the pt PaliGemma 2 models on downstream tasks, and thus a few months later, in February 2025, they introduced [**PaliGemma 2 mix**](https://developers.googleblog.com/en/introducing-paligemma-2-mix/). These models are fine-tuned to a mixture of vision language tasks that can be used out-of-the-box for common use cases. They are available in three sizes: `3B`, `10B`, and `28B`, and support resolutions of `224×224` and `448×448` pixels.
27
+
28
+ ### Tasks
29
+
30
+ PaliGemma 2 mix can perform the following types of tasks:
31
+
32
+ - Short and long captioning
33
+ - Optical character recognition (OCR)
34
+ - Image question answering
35
+ - (Multiple) object detection
36
+ - (Multiple) image segmentation
37
+
38
+ ### Prompting
39
+
40
+ In general, the PaliGemma models are very sensitive to the prompt’s syntax and patterns. However, according to the following Hugging Face [article](https://huggingface.co/blog/paligemma2mix), with the PaliGemma 2 mix models open-ended prompts yield better performance than the previously required task-prefixed prompts. Earlier, task-specific prefixes were essential, such as
41
+
42
+ - `"caption {lang}\n"`: Short captions
43
+ - `"describe {lang}\n"`: More descriptive captions
44
+ - `"ocr"`: Optical character recognition
45
+ - `"answer {lang} {question}\n"`: Question answering about the image contents
46
+ - `"question {lang} {answer}\n"`: Question generation for a given answer
47
+
48
+ However, two specific tasks - **object detection** and **image segmentation** - still exclusively require task prefixes:
49
+
50
+ - `"detect {object description} ; {object description} ; ...\n"`: Locate multiple objects in an image and return the bounding boxes for those objects
51
+ - `"segment {object description} ; {object description} ; ...\n"`: Locate the area occupied by multiple objects in an image and create a segmentation mask for those objects
52
+
53
+ # Image Segmentation
54
+
55
+ ## What is Image Segmentation?
56
+
57
+ Image segmentation is a key computer vision technique that divides an image into pixel groups, or segments, enabling tasks like object detection, scene understanding, and advanced image processing. Traditional methods use pixel features such as color, brightness, contrast, and intensity to separate objects from the background, often relying on simple heuristics or basic machine learning. Recently, deep learning models with complex neural networks have dramatically improved segmentation accuracy.
58
+
59
+ Unlike image classification, which labels an entire image, or object detection, which locates objects with bounding boxes, image segmentation provides detailed pixel-level annotations. This approach assigns every pixel to a specific category, with variants including semantic segmentation (classifying pixels), instance segmentation (distinguishing between instances of the same object), and panoptic segmentation (combining both methods).
60
+
61
+ ## Image Segmentation with VLMs
62
+
63
+ VLMs enhance traditional image segmentation by enabling open-vocabulary segmentation through textual instructions, moving away from closed-set methods that rely on predefined categories. By merging text and image data into a common feature space, these models reduce adaptation costs and excel at tasks like referring expression segmentation. For example, a user might prompt the model to *“segment the cat sitting on the chair”*, and the VLM would identify and segment the pixels corresponding to that specific cat.
64
+
65
+ To achieve this, VLMs harness visual features from encoders like CNNs or Vision Transformers, using cross-attention to focus on image regions relevant to the text. Some models are fine-tuned to produce bounding boxes or segmentation masks directly, and careful prompting guides them to accurately segment based on the integrated understanding of visual content and language.
66
+
67
+ ## Image Segmentation the PaliGemma 2 Way
68
+
69
+ Earlier, in **Figure 1** we saw that PaliGemma 2’s architecture combines a Transformer decoder based on the Gemma 2 language model with a Vision Transformer image encoder initialised from SigLIP-So400m/14. The SigLIP encoder divides input images into `14x14` pixel patches to generate “soft tokens” that capture spatial relationships. Then, a linear projection layer is used to map the visual tokens into the same dimensional space as the input embeddings of the Gemma 2 language model. This projection ensures that the visual information can be seamlessly combined with textual information for processing by the language model.
70
+
71
+ The Gemma 2 language model functions as the decoder, processing concatenated image tokens and text tokens to produce autoregressive text output, predicting one token at a time based on the preceding context. To enhance its capabilities for vision-language tasks, PaliGemma extends the vocabulary of the standard Gemma tokenizer (which has 256,000 tokens) with additional special tokens. These include 1024 tokens representing coordinates in a normalised image space, denoted as `<loc0000>` through `<loc1023>`, and another 128 tokens, `<seg000>` through `<seg127>`, which are codewords used by a lightweight vector-quantized approach to referring-expression segmentation.
72
+
73
+ ### Segmentation Output
74
+
75
+ When processing a segmentation prompt, PaliGemma 2 mix produces a sequence that begins with four location tokens defining the bounding box for the segmented object. These four tokens specify the bounding box coordinates in the normalized image space. This is followed by 16 segmentation tokens, which can be decoded via a learned codebook into a binary segmentation mask confined within the identified region. Below is an example output:
76
+
77
+ ```text
78
+ <loc0336><loc0049><loc0791><loc0941><seg106><seg074><seg114><seg081><seg082><seg028><seg018><seg037><seg120><seg073><seg061><seg125><seg045><seg059><seg052><seg084>
79
+ ```
80
+
81
+ ### Segmentation Mask
82
+
83
+ If we want to further process the 16 segmentation tokens to generate a binary segmentation mask within the identified bounding box, we have to decode them with the mask decoder from Google’s big_vision repository for the PaliGemma models, available in the following [script](https://github.com/google-research/big_vision/blob/main/big_vision/evaluators/proj/paligemma/transfers/segmentation.py). As we can see, the script uses JAX and Flax, and the Metal plug-in for JAX is still not fully supported, as stated in the [Accelerated JAX on Mac](https://developer.apple.com/metal/jax/) article. In the next part of this post, we are going to show not only how to reconstruct the binary mask with the help of the above script, but also how to translate the JAX/Flax code to [mlx](https://github.com/ml-explore/mlx) so that we can fully utilise the unified memory in Apple’s chips.
84
+
85
+ # Tutorial
86
+
87
+ In this section, we are going to generate a segmentation mask with the PaliGemma 2 mix model, specifically [mlx-community/paligemma2–10b-mix-448–8bit](https://huggingface.co/mlx-community/paligemma2-10b-mix-448-8bit), by using only the packages `mlx-vlm` and `mlx`. We are also going to overlay the mask on top of the image we are segmenting.
88
+
89
+ ## Overview of the Process
90
+
91
+ Let’s first begin by outlining the steps of the process for generating a segmentation mask. An illustrative diagram can be seen in **Figure 2**.
92
+
93
+ - We start with passing a **prompt** to the model of the form *"segment cat\n"*, and the image we want to segment. This is our **original image** with dimensions $x_{\text{orig}}$ by $y_{\text{orig}}$.
94
+ - Then, the model’s image processor (SiglipImageProcessor) produces an **input image** with dimensions $x_{\text{input}}$ by $y_{\text{input}}$. In the PaliGemma 2 mix case this would be either `224x224` or `448x448`, depending on the model we have chosen to use. In our case, it would be `448x448`.
95
+ - The model generates an output with 4 location tokens and 16 segmentation tokens. The `<locXXXX><locXXXX><locXXXX><locXXXX>` sequence corresponds to the $y_{\text{min}}$, $x_{\text{min}}$, $y_{\text{max}}$, $x_{\text{max}}$ coordinates defining the **bounding box**. These coordinates are expressed on a normalised `1024x1024` grid, so they have to be rescaled to the input image dimensions to obtain the bounding box of the object we want to segment (see the short example after Figure 2).
96
+
97
+ <figure>
98
+ <img src="images/input_bb.png" alt="Model input and bounding box" style="display: block; margin: 0 auto">
99
+ <figcaption style="text-align: center">Figure 2. Model input and bounding box coordinates</figcaption>
100
+ </figure>
101
+
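+ As a concrete example, taking the location tokens from the sample output shown earlier (`<loc0336><loc0049><loc0791><loc0941>`) and a `448x448` input, the rescaling from the 1024-bin grid to pixel coordinates is simply:
+
+ ```python
+ y_min, x_min, y_max, x_max = 336, 49, 791, 941  # from <loc0336><loc0049><loc0791><loc0941>
+ width = height = 448                            # input image size of the 448 models
+
+ x_min_px = int(x_min / 1024 * width)   # 21
+ y_min_px = int(y_min / 1024 * height)  # 147
+ x_max_px = int(x_max / 1024 * width)   # 411
+ y_max_px = int(y_max / 1024 * height)  # 346
+ ```
+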
102
+ Now that we’ve defined the bounding box by its coordinates, let’s zoom in on its details as shown in **Figure 3**, and discuss how we would overlay the segmentation mask on top of the image we are segmenting.
103
+
104
+ - The model has returned the 16 segmentation tokens of the form `<segXXX>`. After decoding them via the codebook we end up reconstructing the **segmentation mask**. This mask has a size of `64x64` pixels.
105
+ - Next, we need to map the segmentation mask onto the bounding box that was previously defined. This is accomplished using classical interpolation techniques to scale the mask to the bounding box’s dimensions.
106
+
107
+ <figure>
108
+ <img src="images/map_mask.png" alt="Mapping mask to bounding box" style="display: block; margin: 0 auto">
109
+ <figcaption style="text-align: center">Figure 3. Mapping the 64x64 mask to the bounding box</figcaption>
110
+ </figure>
111
+
112
+ - Once resized, the mask is aligned to fit within the bounding box. To overlay this mask on the original image, we create an empty array matching the dimensions of the input image and then replace the array values corresponding to the bounding box coordinates with those from the interpolated segmentation mask.
113
+
114
+ # MLX
115
+
116
+ Finally, it’s time to dive into the coding section of this blog and focus specifically on the `mlx` components. The code can be found in [GitHub](https://github.com/JoeJoe1313/LLMs-Journey/blob/main/VLMs/paligemma_segmentation_mlx.py).
117
+
118
+ We begin by importing the necessary libraries and modules,
119
+
120
+ ```python
121
+ import argparse
122
+ import functools
123
+ import logging
124
+ import re
125
+ from typing import Callable, List, Tuple
126
+
127
+ import cv2
128
+ import matplotlib.pyplot as plt
129
+ import mlx.core as mx
130
+ import mlx.nn as nn
131
+ import numpy as np
132
+ from mlx_vlm import apply_chat_template, generate, load
133
+ from mlx_vlm.utils import load_image
134
+ from tensorflow.io import gfile
135
+ ```
136
+
137
+ then, we establish the paths for the models and image resources. The `MODEL_PATH` points to the specific PaliGemma model that we are going to use for segmentation tasks. The `IMAGE_PATH` is the location of the image that we will process, and the `_KNOWN_MODELS` dictionary provides a reference to the VAE checkpoint needed for mask reconstruction.
138
+
139
+ ```python
140
+ MODEL_PATH = "mlx-community/paligemma2-10b-mix-448-8bit"
141
+ IMAGE_PATH = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
142
+ _KNOWN_MODELS = {"oi": "gs://big_vision/paligemma/vae-oid.npz"}
143
+ ```
144
+
145
+ Before diving into the core functionality, we set up logging to keep track of the execution flow and for debugging purposes. The following snippet initializes Python’s built-in logging system:
146
+
147
+ ```python
148
+ logging.basicConfig()
149
+ log = logging.getLogger(__name__)
150
+ log.setLevel(logging.INFO)
151
+ ```
152
+
153
+ The `ResBlock` class implements a basic residual block typical for convolutional architectures. It comprises three convolution layers:
154
+
155
+ - Two `3x3` convolutions with ReLU activations, which process the input.
156
+ - One `1x1` convolution to adjust dimensions if needed.
157
+
158
+ The output of the block is computed by summing the result of the convolutions with the original input. This residual connection helps maintain gradient flow during training and preserves information across layers.
159
+
160
+ ```python
161
+ class ResBlock(nn.Module):
162
+ def __init__(self, features: int):
163
+ super().__init__()
164
+ self.conv1 = nn.Conv2d(
165
+ in_channels=features, out_channels=features, kernel_size=3, padding=1
166
+ )
167
+ self.conv2 = nn.Conv2d(
168
+ in_channels=features, out_channels=features, kernel_size=3, padding=1
169
+ )
170
+ self.conv3 = nn.Conv2d(
171
+ in_channels=features, out_channels=features, kernel_size=1, padding=0
172
+ )
173
+
174
+ def __call__(self, x: mx.array) -> mx.array:
175
+ original_x = x
176
+ x = nn.relu(self.conv1(x))
177
+ x = nn.relu(self.conv2(x))
178
+ x = self.conv3(x)
179
+ return x + original_x
180
+ ```
181
+
182
+ The `Decoder` class takes quantized vectors (obtained from segmentation tokens) and upscales them to produce a mask:
183
+
184
+ - An initial convolution reduces the channel dimension.
185
+ - A series of configurable residual blocks further process the features.
186
+ - Multiple transpose convolution layers (upsample layers) scale the feature maps until the desired resolution is reached.
187
+ - A final convolution produces the output mask.
188
+
189
+ ```python
190
+ class Decoder(nn.Module):
191
+ """Decoder that upscales quantized vectors to produce a mask.
192
+ The architecture is parameterized to avoid hardcoded layer definitions.
193
+ Takes channels-last input data (B, H, W, C).
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ in_channels: int = 512,
199
+ res_channels: int = 128,
200
+ out_channels: int = 1,
201
+ num_res_blocks: int = 2,
202
+ upsample_channels: Tuple[int, ...] = (128, 64, 32, 16),
203
+ ):
204
+ super().__init__()
205
+ self.conv_in = nn.Conv2d(
206
+ in_channels=in_channels, out_channels=res_channels, kernel_size=1, padding=0
207
+ )
208
+ self.res_blocks = [
209
+ ResBlock(features=res_channels) for _ in range(num_res_blocks)
210
+ ]
211
+ self.upsample_layers = []
212
+ out_up_ch = res_channels
213
+ for ch in upsample_channels:
214
+ self.upsample_layers.append(
215
+ nn.ConvTranspose2d(
216
+ in_channels=out_up_ch,
217
+ out_channels=ch,
218
+ kernel_size=4,
219
+ stride=2,
220
+ padding=1,
221
+ )
222
+ )
223
+ out_up_ch = ch
224
+ self.conv_out = nn.Conv2d(
225
+ in_channels=upsample_channels[-1],
226
+ out_channels=out_channels,
227
+ kernel_size=1,
228
+ padding=0,
229
+ )
230
+
231
+ def __call__(self, x: mx.array) -> mx.array:
232
+ x = nn.relu(self.conv_in(x))
233
+ for block in self.res_blocks:
234
+ x = block(x)
235
+ for layer in self.upsample_layers:
236
+ x = nn.relu(layer(x))
237
+
238
+ return self.conv_out(x)
239
+ ```
240
+
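+ A quick shape check (with freshly initialised, untrained weights, so the values themselves are meaningless) confirms where the `64x64` mask size mentioned earlier comes from: the `4x4` grid of quantized vectors is doubled by each of the four transpose convolutions.
+
+ ```python
+ decoder = Decoder()
+ out = decoder(mx.zeros((1, 4, 4, 512)))  # channels-last input (B, H, W, C)
+ print(out.shape)  # (1, 64, 64, 1): 4 -> 8 -> 16 -> 32 -> 64
+ ```
+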
241
+ The helper function `_get_params` is designed to convert a PyTorch checkpoint into a format that MLX can work with. It does so by
242
+
243
+ - Transposing kernel weights from PyTorch’s (Out, In, H, W) layout to the (Out, H, W, In) layout MLX expects.
244
+ - Organizing the parameters into a structured dictionary that reflects the architecture of the decoder, including the convolutional layers, residual blocks, and upsample layers.
245
+
246
+ This organized set of parameters is then used to initialize the decoder network.
247
+
248
+ ```python
249
+ def _get_params(checkpoint: dict) -> dict:
250
+ """Converts PyTorch checkpoint to MLX params (nested dict).
251
+ Uses transpositions yielding (Out, H, W, In) format weights."""
252
+
253
+ def transp(kernel: np.ndarray) -> mx.array:
254
+ return mx.transpose(mx.array(kernel), (0, 2, 3, 1))
255
+
256
+ def transp_transpose(kernel: np.ndarray) -> mx.array:
257
+ intermediate = mx.transpose(mx.array(kernel), (1, 0, 2, 3))
258
+
259
+ return mx.transpose(intermediate, (0, 2, 3, 1))
260
+
261
+ def conv(name: str) -> dict:
262
+ return {
263
+ "bias": mx.array(checkpoint[f"{name}.bias"]),
264
+ "weight": transp(checkpoint[f"{name}.weight"]),
265
+ }
266
+
267
+ def conv_transpose(name: str) -> dict:
268
+ return {
269
+ "bias": mx.array(checkpoint[f"{name}.bias"]),
270
+ "weight": transp_transpose(checkpoint[f"{name}.weight"]),
271
+ }
272
+
273
+ def resblock(name: str) -> dict:
274
+ return {
275
+ "conv1": conv(f"{name}.0"),
276
+ "conv2": conv(f"{name}.2"),
277
+ "conv3": conv(f"{name}.4"),
278
+ }
279
+
280
+ params = {
281
+ "_embeddings": mx.array(checkpoint["_vq_vae._embedding"]),
282
+ "conv_in": conv("decoder.0"),
283
+ "res_blocks": [
284
+ resblock("decoder.2.net"),
285
+ resblock("decoder.3.net"),
286
+ ],
287
+ "upsample_layers": [
288
+ conv_transpose("decoder.4"),
289
+ conv_transpose("decoder.6"),
290
+ conv_transpose("decoder.8"),
291
+ conv_transpose("decoder.10"),
292
+ ],
293
+ "conv_out": conv("decoder.12"),
294
+ }
295
+
296
+ return params
297
+ ```
298
+
299
+ The function `_quantized_values_from_codebook_indices` takes the segmentation tokens (represented as codebook indices) and uses the embeddings from the codebook to retrieve the corresponding encoded representations. These values are reshaped to fit the expected dimensions (batch, height, width, channels) so that they are ready for processing by the decoder.
300
+
301
+ ```python
302
+ def _quantized_values_from_codebook_indices(
303
+ codebook_indices: mx.array, embeddings: mx.array
304
+ ) -> mx.array:
305
+ batch_size, num_tokens = codebook_indices.shape
306
+ expected_tokens = 16
307
+ if num_tokens != expected_tokens:
308
+ log.error(f"Expected {expected_tokens} tokens, got {codebook_indices.shape}")
309
+
310
+ encodings = mx.take(embeddings, codebook_indices.reshape((-1,)), axis=0)
311
+
312
+ return encodings.reshape((batch_size, 4, 4, embeddings.shape[1]))
313
+ ```
314
+
315
+ The `get_reconstruct_masks` function loads the VAE checkpoint and initializes the decoder with the appropriate parameters. By extracting and setting up the necessary embeddings and decoder weights, this function returns another function (`reconstruct_masks`) that, when given segmentation tokens, decodes them into a binary segmentation mask.
316
+
317
+ ```python
318
+ @functools.cache
319
+ def get_reconstruct_masks(model: str) -> Callable[[mx.array], mx.array]:
320
+ """Loads the checkpoint and returns a function that reconstructs masks
321
+ from codebook indices using a preloaded MLX decoder.
322
+ """
323
+ checkpoint_path = _KNOWN_MODELS.get(model, model)
324
+ with gfile.GFile(checkpoint_path, "rb") as f:
325
+ checkpoint_data = dict(np.load(f))
326
+
327
+ params = _get_params(checkpoint_data)
328
+ embeddings = params.pop("_embeddings")
329
+ log.info(f"VAE embedding dimension: {embeddings.shape[1]}")
330
+
331
+ decoder = Decoder()
332
+ decoder.update(params)
333
+
334
+ def reconstruct_masks(codebook_indices: mx.array) -> mx.array:
335
+ quantized = _quantized_values_from_codebook_indices(
336
+ codebook_indices, embeddings
337
+ )
338
+ return decoder(quantized)
339
+
340
+ return reconstruct_masks
341
+ ```
342
+
343
+ The function `extract_and_create_arrays` parses a given string pattern for segmentation tokens. It extracts these token numbers, converts them into integers, and then wraps them in MLX arrays for further mask reconstruction processing.
344
+
345
+ ```python
346
+ def extract_and_create_arrays(pattern: str) -> List[mx.array]:
347
+ """Extracts segmentation tokens from each object in the pattern and returns a list of MLX arrays."""
348
+ object_strings = [obj.strip() for obj in pattern.split(";") if obj.strip()]
349
+
350
+ seg_tokens_arrays = []
351
+ for obj in object_strings:
352
+ seg_tokens = re.findall(r"<seg(\d{3})>", obj)
353
+ if seg_tokens:
354
+ seg_numbers = [int(token) for token in seg_tokens]
355
+ seg_tokens_arrays.append(mx.array(seg_numbers))
356
+
357
+ return seg_tokens_arrays
358
+ ```
359
+
360
+ The `parse_bbox` function interprets the model's output string to extract bounding box coordinates. Each detected object's location is denoted by four tokens of the form `<locXXXX>`. This function finds four numbers per object, corresponding to the box boundaries, and aggregates them into a list of bounding boxes.
361
+
362
+ ```python
363
+ def parse_bbox(model_output: str):
364
+ entries = model_output.split(";")
365
+
366
+ results = []
367
+ for entry in entries:
368
+ entry = entry.strip()
369
+ numbers = re.findall(r"<loc(\d+)>", entry)
370
+ if len(numbers) == 4:
371
+ bbox = [int(num) for num in numbers]
372
+ results.append(bbox)
373
+
374
+ return results
375
+ ```
376
+
377
+ The `gather_masks` function combines the reconstruction and bounding box parsing steps. For each object:
378
+
379
+ - It reconstructs the mask from its codebook indices.
380
+ - It obtains the corresponding bounding box coordinates.
381
+ - It normalizes these coordinates relative to a target image resolution (448×448 in this example).
382
+
383
+ Each mask is then paired with its coordinates and stored in a list, making it straightforward to later overlay these onto the original image.
384
+
385
+ ```python
386
+ def gather_masks(output, codes_list, reconstruct_fn):
387
+ masks_list = []
388
+
389
+ target_width, target_height = 448, 448
390
+ for i, codes in enumerate(codes_list):
391
+ codes_batch = codes[None, :]
392
+ masks = reconstruct_fn(codes_batch)
393
+ mask_np = np.array(masks[0, :, :, 0], copy=False)
394
+
395
+ y_min, x_min, y_max, x_max = parse_bbox(output)[i]
396
+ x_min_norm = int(x_min / 1024 * target_width)
397
+ y_min_norm = int(y_min / 1024 * target_height)
398
+ x_max_norm = int(x_max / 1024 * target_width)
399
+ y_max_norm = int(y_max / 1024 * target_height)
400
+
401
+ masks_list.append(
402
+ {
403
+ "mask": mask_np,
404
+ "coordinates": (x_min_norm, y_min_norm, x_max_norm, y_max_norm),
405
+ }
406
+ )
407
+
408
+ return masks_list
409
+ ```
410
+
411
+ The function `plot_masks` handles the visualization of the segmentation outcomes. It loads the original image and processes it for display. Two types of visualizations are provided:
412
+
413
+ - **Composite Overlay**: All masks are combined and overlaid on the original image.
414
+ - **Reconstructed Mask**: Each reconstructed mask is plotted next to the composite overlay.
415
+
416
+ Using OpenCV for mask resizing and Matplotlib for plotting, the function creates a series of subplots to clearly display both composite and individual mask overlays.
417
+
418
+ ```python
419
+ def plot_masks(args, processor, masks_list):
420
+
421
+ image = load_image(args.image_path)
422
+ img_array = processor.image_processor(image)["pixel_values"][0].transpose(1, 2, 0)
423
+ img_array = (img_array * 0.5 + 0.5).clip(0, 1)
424
+
425
+ full = np.ones((448, 448, 1)) * (-1)
426
+ for mask_info in masks_list:
427
+ mask_np = mask_info["mask"]
428
+ x_min_norm, y_min_norm, x_max_norm, y_max_norm = mask_info["coordinates"]
429
+
430
+ width = x_max_norm - x_min_norm
431
+ height = y_max_norm - y_min_norm
432
+
433
+ resized_mask = cv2.resize(
434
+ mask_np, (width, height), interpolation=cv2.INTER_NEAREST
435
+ )
436
+ resized_mask = resized_mask.reshape((height, width, 1))
437
+
438
+ full[y_min_norm:y_max_norm, x_min_norm:x_max_norm] = resized_mask
439
+
440
+ n_masks = len(masks_list)
441
+ _, axs = plt.subplots(1, n_masks + 1, figsize=(5 * (n_masks + 1), 6))
442
+
443
+ axs[0].imshow(img_array)
444
+ axs[0].imshow(full, alpha=0.5)
445
+ axs[0].set_title("Mask Overlay")
446
+ axs[0].axis("on")
447
+
448
+ for i, mask_info in enumerate(masks_list, start=1):
449
+ mask_np = mask_info["mask"]
450
+ axs[i].imshow(mask_np)
451
+ axs[i].set_title(f"Reconstructed Mask {i}")
452
+ axs[i].axis("on")
453
+
454
+ plt.tight_layout()
455
+ plt.show()
456
+ ```
457
+
458
+ The `main` function ties all the pieces together. It performs the following steps:
459
+
460
+ - **Loading**: Reads the specified PaliGemma model and image.
461
+ - **Setup**: Initializes the VAE checkpoint and extracts the reconstruction function.
462
+ - **Prompting**: Formats the prompt using the processor and generates a segmentation output via the model.
463
+ - **Processing**: Extracts segmentation tokens, reconstructs masks, and parses bounding box coordinates.
464
+ - **Visualization**: Finally, it calls the plotting function to display the results.
465
+
466
+ This function serves as the central point where data processing, model inference, mask reconstruction, and visualization are integrated into one complete pipeline.
467
+
468
+ ```python
469
+ def main(args) -> None:
470
+ log.info(f"Loading PaliGemma model: {args.model_path}")
471
+ model, processor = load(args.model_path)
472
+ config = model.config
473
+
474
+ image = load_image(args.image_path)
475
+ log.info(f"Image size: {image.size}")
476
+
477
+ vae_path = _KNOWN_MODELS.get(args.vae_checkpoint_path, args.vae_checkpoint_path)
478
+ reconstruct_fn = get_reconstruct_masks(vae_path)
479
+
480
+ prompt = args.prompt.strip() + "\n"
481
+ log.info(f"Using prompt: '{prompt.strip()}'")
482
+ formatted_prompt = apply_chat_template(processor, config, prompt, num_images=1)
483
+
484
+ log.info("Generating segmentation output...")
485
+ output = generate(model, processor, formatted_prompt, image, verbose=False)
486
+ log.info(f"Model output: {output}")
487
+
488
+ codes_list = extract_and_create_arrays(output)
489
+ log.info(f"Extracted codes: {codes_list}")
490
+
491
+ log.info("Reconstructing mask from codes...")
492
+ masks_list = gather_masks(output, codes_list, reconstruct_fn)
493
+
494
+ log.info("Plotting masks...")
495
+     plot_masks(args, processor, masks_list)
496
+ ```
497
+
498
+ Finally, the script includes an entry point that parses command-line arguments. Users can specify the image path, the prompt for the segmentation task, the model path, and the VAE checkpoint path. Once these are provided via `argparse`, the `main` function is invoked to start the processing pipeline.
499
+
500
+ ```python
501
+ if __name__ == "__main__":
502
+ parser = argparse.ArgumentParser(description="Vision tasks using PaliGemma 2 mix.")
503
+ parser.add_argument(
504
+ "--image_path", type=str, default=IMAGE_PATH, help="Path to the input image."
505
+ )
506
+ parser.add_argument(
507
+ "--prompt", type=str, required=True, help="Prompt for the model."
508
+ )
509
+ parser.add_argument(
510
+ "--model_path", type=str, default=MODEL_PATH, help="Path to the mlx model."
511
+ )
512
+ parser.add_argument(
513
+ "--vae_checkpoint_path", type=str, default="oi", help="Path to the .npz file."
514
+ )
515
+
516
+ cli_args = parser.parse_args()
517
+ main(cli_args)
518
+ ```
519
+
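+ Assuming the script is saved as `paligemma_segmentation_mlx.py` (as in the linked repository), a segmentation run then looks like the following; only `--prompt` is required, since the other arguments fall back to the defaults defined above:
+
+ ```
+ python paligemma_segmentation_mlx.py --prompt "segment cat" --image_path path/to/image.png
+ ```
+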
520
+ # Results
521
+
522
+ Let’s take a look at some examples and the segmentations we obtained from the model.
523
+
524
+ ## Single Object Segmentation
525
+
526
+ In this section, we are going to show two examples of single object segmentation.
527
+
528
+ ---
529
+
530
+ **Prompt:**
531
+
532
+ ```text
533
+ "segment cow"
534
+ ```
535
+
536
+ **Image:**
537
+
538
+ <figure>
539
+ <img src="images/cow_in.png" alt="Cow input" style="display: block; margin: 0 auto">
540
+ <figcaption style="text-align: center">Figure 4. Original image of size 400x400</figcaption>
541
+ </figure>
542
+
543
+ **Model output:**
544
+
545
+ ```text
546
+ <loc0410><loc0528><loc0884><loc1023><seg072><seg055><seg062><seg079><seg104><seg009><seg104><seg096><seg068><seg041><seg103><seg019><seg100><seg004><seg091><seg067>
547
+ ```
548
+
549
+ **Mask overlay and reconstructed mask:**
550
+
551
+ <figure>
552
+ <img src="images/cow_out.png" alt="Cow output" style="display: block; margin: 0 auto">
553
+ <figcaption style="text-align: center">Figure 5. Left: mask overlay onto the input image of size 448x448 | Right: reconstructed mask of size 64x64</figcaption>
554
+ </figure>
555
+
556
+ **Observation:**
557
+
558
+ Based on the overlay image, the model manages to detect the precise location of the cow but struggles a bit with the detailed outlines of the cow. Looking only at the reconstructed mask would not persuade me that this is a cow.
559
+
560
+ ---
561
+
562
+ **Prompt:**
563
+
564
+ ```text
565
+ "segment cat"
566
+ ```
567
+
568
+ **Image:**
569
+
570
+ <figure>
571
+ <img src="images/cat_in.png" alt="Cat input" style="display: block; margin: 0 auto">
572
+ <figcaption style="text-align: center">Figure 6. Original image of size 400x400</figcaption>
573
+ </figure>
574
+
575
+ **Model output:**
576
+
577
+ ```text
578
+ <loc0060><loc0000><loc0920><loc0879><seg039><seg107><seg018><seg006><seg056><seg120><seg058><seg042><seg079><seg094><seg009><seg099><seg074><seg010><seg078><seg012>
579
+ ```
580
+
581
+ **Mask overlay and reconstructed mask:**
582
+
583
+ <figure>
584
+ <img src="images/cat_out.png" alt="Cat output" style="display: block; margin: 0 auto">
585
+ <figcaption style="text-align: center">Figure 7. Left: mask overlay onto the input image of size 448x448 | Right: reconstructed mask of size 64x64</figcaption>
586
+ </figure>
587
+
588
+ **Observation:**
589
+
590
+ Based on the overlay image, the model manages to detect the precise location of the cat, and is generally doing a good job with the cat’s outlines.
591
+
592
+ ## Multiple Object Segmentation
593
+
594
+ It was tricky to find a working example for segmenting multiple objects, so there is only one example in this section. My observation is that the PaliGemma models are indeed very sensitive to the prompt formatting, and the 448–10B-8bit model might just not be powerful enough for the task of segmenting multiple objects.
595
+
596
+ ---
597
+
598
+ **Prompt:**
599
+
600
+ ```text
601
+ "segment left wheel ; right wheel"
602
+ ```
603
+
604
+ **Image:**
605
+
606
+ <figure>
607
+ <img src="images/car_in.png" alt="Car input" style="display: block; margin: 0 auto">
608
+ <figcaption style="text-align: center">Figure 8. Original image of size 640x480</figcaption>
609
+ </figure>
610
+
611
+ **Model output:**
612
+
613
+ ```text
614
+ <loc0591><loc0157><loc0794><loc0311> <seg092><seg004><seg044><seg092><seg120><seg061><seg029><seg120><seg090><seg023><seg021><seg090><seg015><seg041><seg044><seg073> right wheel ; <loc0586><loc0728><loc0794><loc0882> <seg092><seg004><seg089><seg092><seg120><seg048><seg054><seg038><seg119><seg029><seg021><seg090><seg095><seg041><seg044><seg073> right wheel
615
+ ```
616
+
617
+ **Mask overlay and reconstructed mask:**
618
+
619
+ <figure>
620
+ <img src="images/car_out.png" alt="Car output" style="display: block; margin: 0 auto">
621
+ <figcaption style="text-align: center">Figure 9. Left: masks overlay onto the input image of size 448x448 | Right: reconstructed masks of size 64x64</figcaption>
622
+ </figure>
623
+
624
+ **Observation:**
625
+
626
+ Looking at the model output, we can see that both segmentations are labeled as "right wheel". Despite this, based on the overlay image, the model manages to detect the precise locations of the wheels and their outlines.
627
+
628
+ # Conclusion
629
+
630
+ In summary, we implemented a unified segmentation pipeline by combining Google’s PaliGemma 2 Mix model with Apple’s MLX framework. Our workflow involved formatting segmentation prompts, preprocessing images, extracting segmentation tokens and bounding box coordinates, decoding these tokens into segmentation masks, and finally overlaying the masks on the original images.
631
+
632
+ For single object segmentation, the model generally performed well: it accurately localised the object areas, as evidenced by both the “cat” and the “cow” examples. However, the segmentation for the “cow” revealed some challenges with capturing fine details, indicating areas for potential refinement.
633
+
634
+ The multiple object segmentation proved challenging, as we struggled to find more than one working example that produced multiple segmentations. In our single example, the model demonstrated the ability to detect the general locations of objects — successfully identifying both wheels — but it also suffered from issues such as prompt sensitivity and duplicate labelling. This difficulty may be attributed to the inherent prompt sensitivity of the model or potential limitations of the specific model variant, particularly the 448–10B-8bit configuration. These observations suggest that either refining prompt structures or exploring more powerful models may be essential for reliably handling segmentation tasks involving multiple objects.
635
+
636
+ # References
637
+
638
+ - [Introducing PaliGemma 2 mix: A vision-language model for multiple tasks](https://developers.googleblog.com/en/introducing-paligemma-2-mix/)
639
+ - [PaliGemma prompt and system instructions](https://ai.google.dev/gemma/docs/paligemma/prompt-system-instructions)
640
+ - [PaliGemma 2 Mix — New Instruction Vision Language Models by Google](https://huggingface.co/blog/paligemma2mix)
641
+ - [Welcome PaliGemma 2 — New vision language models by Google](https://huggingface.co/blog/paligemma2)
642
+ - [Introducing PaliGemma 2: Powerful Vision-Language Models, Simple Fine-Tuning](https://developers.googleblog.com/en/introducing-paligemma-2-powerful-vision-language-models-simple-fine-tuning/)
643
+ - [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://arxiv.org/abs/2412.03555)
src/posts/2025-05-06-chat-qwen3-ios/images/3_1.png ADDED

Git LFS Details

  • SHA256: 4a40d8c3abbf2e264b259685dc7eb0d3a527d2f505f6b2a40f2be90176908a53
  • Pointer size: 130 Bytes
  • Size of remote file: 21.2 kB
src/posts/2025-05-06-chat-qwen3-ios/images/3_2.png ADDED

Git LFS Details

  • SHA256: 8cae2f6c86012024bdd78c123a988aa750b3f0a74ba7c551647d39cf798b47d7
  • Pointer size: 130 Bytes
  • Size of remote file: 39.2 kB
src/posts/2025-05-06-chat-qwen3-ios/images/4_1.png ADDED

Git LFS Details

  • SHA256: 98701b1f841973a46d67b4d4f6016e4686f4e6cbf86155d83973b4a1cc82c0e7
  • Pointer size: 130 Bytes
  • Size of remote file: 48.5 kB
src/posts/2025-05-06-chat-qwen3-ios/images/4_2.png ADDED

Git LFS Details

  • SHA256: 78c73607f34bffd0212ad453912387e1ef61e324bf24b8be0a864d71da86be0d
  • Pointer size: 130 Bytes
  • Size of remote file: 14.3 kB
src/posts/2025-05-06-chat-qwen3-ios/images/5_1.png ADDED

Git LFS Details

  • SHA256: 568147e84421667c375278a6e1783c780836897240388dbf8d059300d5ba9361
  • Pointer size: 130 Bytes
  • Size of remote file: 95 kB
src/posts/2025-05-06-chat-qwen3-ios/images/5_2.png ADDED

Git LFS Details

  • SHA256: 3c5f9ff0ae3f813c5bcfae5eafdc93cdb3e436f6d2270a8c265dc3e83068b1ab
  • Pointer size: 130 Bytes
  • Size of remote file: 78.2 kB
src/posts/2025-05-06-chat-qwen3-ios/images/config_scheme_dest.png ADDED

Git LFS Details

  • SHA256: 14a613619d05b614d95fa74f44a2b88542cad6f6fcacc8646faaf6a8cf2dfed4
  • Pointer size: 130 Bytes
  • Size of remote file: 52.4 kB
src/posts/2025-05-06-chat-qwen3-ios/images/developer_team.png ADDED

Git LFS Details

  • SHA256: f70f550d8704ed8001fec5230eb56e34405a2791353f7cada725ddde29ebd4b1
  • Pointer size: 130 Bytes
  • Size of remote file: 72.1 kB
src/posts/2025-05-06-chat-qwen3-ios/index.qmd ADDED
@@ -0,0 +1,162 @@
1
+ ---
2
+ title: "Chat with Qwen3 on your iPhone: A Step-by-Step Guide"
3
+ author: "Joana Levtcheva"
4
+ date: "2025-05-06"
5
+ categories: [Machine Learning, mlx, swift, llm, ios]
6
+ draft: false
7
+ ---
8
+
9
+ Have you ever wanted to run a powerful large language model directly on your iPhone without sending your data to the cloud? Thanks to Apple’s [MLX Swift](https://github.com/ml-explore/mlx-swift) framework, you can now run the remarkably capable [Qwen3](https://qwenlm.github.io/blog/qwen3/) models right on your device.
10
+
11
+ What makes this exciting is how capable these new models are with as few as 4B parameters. Based on Qwen’s benchmark reports, Qwen3-4B can rival the performance of Qwen2.5-72B-Instruct, delivering impressive results with just a fraction of the parameters. This development means truly powerful AI can now fit in your pocket.
12
+
13
+ This blog post will guide you through the entire process, from setting up your development environment to testing the model on your device. You can embrace true privacy and take ownership of your AI experience. Your data stays on your device, where it belongs, and, most importantly, you have access to this knowledge powerhouse even without an internet connection.
14
+
15
+ Medium post can be found [here](https://medium.com/@levchevajoana/chat-with-qwen3-on-your-iphone-a-step-by-step-guide-515bb957cd02).
16
+
17
+ # Main Steps
18
+
19
+ The whole process consists of three main steps:
20
+
21
+ - Enable Developer Mode on your iPhone
22
+ - Clone Apple’s MLX Swift Examples [repository](https://github.com/ml-explore/mlx-swift-examples)
23
+ - Configure and deploy the Xcode project to your device
24
+
25
+ # Enable Developer Mode on iPhone
26
+
27
+ In order to deploy custom apps to your iPhone, you need to enable **Developer Mode**. On your iPhone:
28
+
29
+ - Go to **Settings &#8594; Privacy & Security &#8594; Developer Mode**
30
+ - Toggle Developer Mode to **On**
31
+ - Restart your iPhone when prompted
32
+
33
+ > **NOTE:** Keep in mind that your device’s security is reduced while Developer Mode is enabled.
34
+
35
+ # Apple’s MLX Swift Examples Repository
36
+
37
+ Now that your iPhone is ready for development, let’s get the necessary code. We are going to run the **MLXChatExample**, an example chat app supporting LLMs and VLMs for iOS and macOS, which can be found in the Applications folder in the [mlx-swift-examples](https://github.com/ml-explore/mlx-swift-examples) repository. So, the first step is to clone the repository:
38
+
39
+ ```bash
40
+ git clone [email protected]:ml-explore/mlx-swift-examples.git
41
+ ```
42
+
43
+ Then, navigate to the cloned directory:
44
+
45
+ ```bash
46
+ cd mlx-swift-examples
47
+ ```
48
+
49
+ And finally, open the Xcode project:
50
+
51
+ ```bash
52
+ open mlx-swift-examples.xcodeproj
53
+ ```
54
+
55
+ The above command should open the project directly in Xcode.
56
+
57
+ # Configure and Deploy the Xcode Project
58
+
59
+ With the project open in Xcode, you should connect your iPhone to the Mac, and then make a few adjustments.
60
+
61
+ ## Configure the Project
62
+
63
+ We are interested in deploying the **MLXChatExample** app on your connected iPhone, so you should:
64
+
65
+ - Change your scheme to **MLXChatExample** by going to **Product &#8594; Scheme &#8594; Choose Scheme** (this opens a dropdown menu from Xcode’s top bar, which can also be accessed directly from the bar)
66
+ - Set the destination to your connected iPhone by going to **Product &#8594; Destination &#8594; Choose Destination**
67
+
68
+ <figure>
69
+ <img src="images/config_scheme_dest.png" alt="Configure scheme and destination" style="display: block; margin: 0 auto">
70
+ <figcaption style="text-align: center">Figure 1. Configure scheme and destination</figcaption>
71
+ </figure>
72
+
73
+ ## Add Developer Team
74
+
75
+ In order to build and run the app we have to assign a development team:
76
+
77
+ - Select the **mlx-swift-examples** project in the Project Navigator (on the left)
78
+ - Go to the **Signing & Capabilities** tab
79
+ - Choose your **Team** from the dropdown menu
80
+
81
+ <figure>
82
+ <img src="images/developer_team.png" alt="Developer team" style="display: block; margin: 0 auto">
83
+ <figcaption style="text-align: center">Figure 2. Add developer team</figcaption>
84
+ </figure>
85
+
86
+ # Deploy to Your iPhone
87
+
88
+ Now we are ready to deploy **MLXChatExample** to our actual device:
89
+
90
+ - Click the Run button (▶) on the left to run the project, or alternatively choose **Product &#8594; Run**
91
+ - You might be asked to authorize Xcode development on your device
92
+ - If successful, the app will be downloaded to your iPhone
93
+
94
+ ## Trust Developer Certificate
95
+
96
+ If this is your first app from this developer team, you will get a message on your iPhone that the app is from an **Untrusted Developer** and cannot be opened yet. To allow applications from this developer, on your iPhone:
97
+
98
+ - Go to **Settings &#8594; General &#8594; VPN & Device Management**
99
+ - Under **Developer App**, tap on your app
100
+ - Select **Trust “[Developer Name]”** and confirm
101
+
102
+ ## Using Qwen3 on Your iPhone
103
+
104
+ With that, we are ready to test the MLXChatExample app. Let’s open it:
105
+
106
+ <figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
107
+ <div>
108
+ <img src="images/3_1.png" alt="App launch 1" />
109
+ </div>
110
+ <div>
111
+ <img src="images/3_2.png" alt="App launch 2" />
112
+ </div>
113
+ </figure>
114
+ <figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
115
+ Figure 3. Left: MLXChatExample app default screen | Right: models dropdown menu
116
+ </figcaption>
117
+
118
+ By default, the app loads with the **llama3.2:1b** model. You can select a different model from the dropdown menu in the upper right corner (see **Figure 3**). Let’s choose **qwen3:4b**, which runs successfully on an iPhone 14 Pro, and type our first query.
119
+
120
+ If this is the first run for a given model, we have to wait for the model to download from Hugging Face to our cache. Below, in **Figure 4**, we can see our first query and how it triggers the model download (the arrow at the top); clicking on the arrow shows the download progress.
121
+
122
+ <figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
123
+ <div>
124
+ <img src="images/4_1.png" alt="Init query" />
125
+ </div>
126
+ <div>
127
+ <img src="images/4_2.png" alt="Download model" />
128
+ </div>
129
+ </figure>
130
+ <figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
131
+ Figure 4. Left: first query for a model triggering model download | Right: model downloading progress
132
+ </figcaption>
133
+
134
+ The Qwen3 models can have thinking enabled or disabled. At the moment the app doesn’t provide a toggle to manually switch between the two modes. The default behaviour is to use thinking (the thought process is encapsulated between the tags `<think></think>`), and in order to disable thinking we have to add **/no_think** at the end of our prompt. In that case the think tags appear with nothing between them. Refer to **Figure 5** below.
135
+
136
+ <figure style="display: flex; justify-content: center; gap: 1rem; margin: 0;">
137
+ <div>
138
+ <img src="images/5_1.png" alt="Gen 1" />
139
+ </div>
140
+ <div>
141
+ <img src="images/5_2.png" alt="Gen 2" />
142
+ </div>
143
+ </figure>
144
+ <figcaption style="text-align: center; width: 100%; margin-top: 0.5rem;">
145
+ Figure 5. Left: query with default thinking behaviour | Right: query with disabled thinking
146
+ </figcaption>
147
+
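+ If you want to post-process such output outside the app, for example in a Python script that talks to a Qwen3 model through another runtime, a minimal sketch for separating the reasoning from the final answer could look like this (purely illustrative, not code from the MLXChatExample app):
+
+ ```python
+ import re
+
+ def split_thinking(text: str) -> tuple[str, str]:
+     """Separate the <think>...</think> block from the final answer.
+
+     Assumes at most one think block, as produced by Qwen3-style models.
+     """
+     match = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
+     thinking = match.group(1).strip() if match else ""
+     answer = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
+     return thinking, answer
+
+ thinking, answer = split_thinking("<think>Consider 2 + 2.</think>The answer is 4.")
+ print(thinking)  # Consider 2 + 2.
+ print(answer)    # The answer is 4.
+ ```
+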
148
+ # The Benefits of On-Device AI
149
+
150
+ Running models like Qwen3 directly on your device offers several significant advantages:
151
+
152
+ - **Complete Privacy:** Your conversations never leave your device. No server logs, no data collection, just you and your AI.
153
+ - **Works Offline:** You have all the world’s knowledge in your pocket wherever you are, without depending on an internet connection. Need AI assistance while traveling, hiking, or in areas with poor connectivity? Local models work anywhere, anytime.
154
+ - **No Subscription Fees:** Once you’ve set up the model, there are no ongoing costs.
155
+ - **No Rate Limits:** Chat as much as you want without hitting quotas.
156
+ - **Reduced Environmental Impact:** On-device inference eliminates the carbon footprint associated with massive data centers processing your requests.
157
+
158
+ # Conclusion
159
+
160
+ While cloud-based solutions may offer larger models with excellent performance, the gap is closing rapidly as efficiency improvements make smaller models surprisingly capable. The Qwen3 models are an excellent example of this, and living proof of the rapid advancements in the field of AI. The MLX Swift framework from Apple makes powerful machine learning accessible to everyday users on their personal devices. This shift from cloud-dependent to local-first AI is just beginning.
161
+
162
+ ***Own your AI. Own your data. Go local.***
src/posts/2025-05-23-app-docker-fastapi/index.qmd ADDED
@@ -0,0 +1,546 @@
1
+ ---
2
+ title: "Image Segmentation with PaliGemma 2 mix, Transformers, Docker, FastAPI, and GitHub Actions"
3
+ author: "Joana Levtcheva"
4
+ date: "2025-05-23"
5
+ categories: [Machine Learning, FastAPI, Docker, LLM, GitHub-Actions]
6
+ draft: false
7
+ ---
8
+
9
+ In today’s fast-paced machine learning landscape, deploying AI models is just as important as developing them. In this blog post, we are going to walk through an image segmentation application using Google’s **PaliGemma 2 Mix** model and **transformers**, containerized with **Docker**, and served through a **FastAPI** backend. We are also going to discuss the CI/CD pipeline using **GitHub Actions** to automate building the Docker image and pushing it to Docker Hub. Let’s explore this service, why we chose these technologies, and how you can get started and use the service yourself!
10
+
11
+ The complete code is available on [GitHub](https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation).
12
+
13
+ # What is This Project All About?
14
+
15
+ At its core, this project provides a **FastAPI service** that allows you to perform image segmentation using natural language. You simply provide as input to the REST API:
16
+
17
+ - A **text prompt** describing what to segment
18
+ - An **image** via URL or file upload
19
+ - The specific **PaliGemma 2 model** to perform image segmentation
20
+
21
+ The service then returns:
22
+
23
+ - The base64 encoded **model input image**
24
+ - The base64 encoded segmentation **masks** clearly outlining the desired objects
25
+ - The **bounding box coordinates** for each segmented object
26
+
27
+ The **FastAPI** application is also containerized with **Docker** for consistent deployment across environments. A CI/CD pipeline with **GitHub Actions** is created for automated container builds and registry publishing to Docker Hub.
28
+
29
+ # Architectural Blueprint: How It All Works Together
30
+
31
+ Understanding the flow of data and the interaction of components is key. Let’s first take a look at our project structure:
32
+
33
+ ```plaintext
34
+ project_folder/
35
+ ├── app/
36
+ │ ├── __init__.py
37
+ │ ├── main.py # FastAPI application and endpoints
38
+ │ └── segmentation.py # Image segmentation logic
39
+ ├── models/
40
+ │ ├── huggingface/ # Cache directory for Hugging Face models
41
+ │ └── vae-oid.npz # VAE model for mask generation
42
+ ├── .dockerignore
43
+ ├── .github/
44
+ │ └── workflows/ # GitHub Actions for Docker build and push
45
+ │ └── docker-build.yml # Workflow to build and push Docker images
46
+ ├── .gitignore
47
+ ├── docker-compose.yml
48
+ ├── Dockerfile
49
+ ├── README.md
50
+ └── requirements.txt
51
+ ```
52
+
53
+ ## User & Developer Workflow
54
+
55
+ Our system is designed with both the end-user and the developer in mind.
56
+
57
+ ```{mermaid}
58
+ graph LR
59
+ User([User]) -->|Provides Image & Prompt| ClientApp[Client Application]
60
+ ClientApp -->|POST Request| FastAPI_Service[FastAPI Service]
61
+ FastAPI_Service -->|Process Input| PaliGemma_Model[PaliGemma Model]
62
+ PaliGemma_Model -->|Generate Segmentation Tokens| VAE_Model[VAE Model]
63
+ VAE_Model -->|Decode Masks| FastAPI_Service
64
+ FastAPI_Service -->|JSON Response | ClientApp
65
+ ClientApp -->|Display Results| User
66
+
67
+ Developer([Developer]) -->|Push Code| GitHubRepo[GitHub Repository]
68
+ GitHubRepo -->|Trigger| GitHubActions[GitHub Actions]
69
+ GitHubActions -->|Build & Push Image| DockerRegistry[Docker Hub]
70
+ DockerRegistry -->|Pull Image| DeploymentEnv[Deployment Environment]
71
+ DeploymentEnv -.->|Runs| FastAPI_Service
72
+ ```
73
+
74
+ <figcaption style="text-align: center">Figure 1. User & Developer Workflow</figcaption>
75
+
76
+ **Figure 1** shows the workflow:
77
+
78
+ - A **User** interacts with a client application, providing an image and a text prompt, and optionally the specific PaliGemma 2 model.
79
+ - The **Client Application** sends an HTTP POST request to our **FastAPI Service**.
80
+ - The **FastAPI Service** preprocesses the input and feeds it to the **PaliGemma Model**.
81
+ - PaliGemma generates segmentation tokens, which are then passed to the **VAE Model**.
82
+ - The VAE Model decodes these into pixel-level masks and, along with bounding boxes, sends them back to the API.
83
+ - The API returns a JSON response to the client.
84
+
85
+ To visualize the precise sequence of operations when a user requests an image segmentation, the following diagram details the interactions between the core components:
86
+
87
+ ```{mermaid}
88
+ sequenceDiagram
89
+ participant User
90
+ participant Client
91
+ participant FastAPI
92
+ participant SegmentationPy as segmentation.py
93
+ participant PaliGemma as PaliGemma Model
94
+ participant VAE as VAE Model
95
+
96
+ User->>+Client: Upload image & prompt
97
+ Client->>+FastAPI: POST /segment
98
+ FastAPI->>+SegmentationPy: call segment_image()
99
+ SegmentationPy->>+PaliGemma: infer with PaliGemma
100
+ PaliGemma-->>-SegmentationPy: (tokens/features)
101
+ SegmentationPy->>+VAE: generate masks
102
+ VAE-->>-SegmentationPy: (pixel masks)
103
+ SegmentationPy-->>-FastAPI: return mask & coords
104
+ FastAPI-->>-Client: JSON response
105
+ Client-->>-User: display results
106
+ ```
107
+ <figcaption style="text-align: center">Figure 2. Segmentation process</figcaption>
108
+
109
+ This sequence highlights how FastAPI acts as the entry point, delegating the complex segmentation logic to the `segmentation.py` module, which in turn leverages the PaliGemma and VAE models to produce the desired output.
110
+
111
+ For **developers**, pushing code to GitHub triggers **GitHub Actions**, which automatically builds a Docker image and pushes it to a **Container Registry** (Docker Hub), ready for deployment.
112
+
113
+ ## Inside the Application
114
+
115
+ Within the Docker container, the application is neatly structured:
116
+
117
+ ```{mermaid}
118
+ graph TD
119
+ subgraph "Docker Container"
120
+ subgraph "app/"
121
+ main[main.py
122
+ FastAPI Application]
123
+ segmentation[segmentation.py
124
+ Image Segmentation Logic]
125
+ main -->|imports| segmentation
126
+ end
127
+
128
+ subgraph "External Dependencies"
129
+ NP[numpy]
130
+ TR[transformers]
131
+ PT[PyTorch]
132
+ JF[JAX/Flax]
133
+ end
134
+
135
+ subgraph "Models"
136
+ PaliGemma[PaliGemma 2 mix]
137
+ VAE[VAE Checkpoint]
138
+ end
139
+
140
+ segmentation -->|uses| TR
141
+ segmentation -->|uses| PT
142
+ segmentation -->|uses| JF
143
+ segmentation -->|uses| NP
144
+
145
+ main -->|loads| PaliGemma
146
+ segmentation -->|loads| VAE
147
+ end
148
+
149
+ Client[Client Application] -->|HTTP Requests| main
150
+
151
+ subgraph "API Endpoints"
152
+ segment[POST /segment/]
153
+ root[GET /]
154
+ end
155
+
156
+ main -->|defines| segment
157
+ main -->|defines| root
158
+ Client -->|calls| segment
159
+ Client -->|calls| root
160
+
161
+ style main fill:#c2e0ff,stroke:#0078d7
162
+ style segmentation fill:#c2e0ff,stroke:#0078d7
163
+ style Client fill:#ffd7b5,stroke:#ff8c00
164
+ style segment fill:#d5e8d4,stroke:#82b366
165
+ style root fill:#d5e8d4,stroke:#82b366
166
+ ```
167
+
168
+ <figcaption style="text-align: center">Figure 3. Application Architecture</figcaption>
169
+
170
+ - `app/main.py`: This is the heart of our API, built using FastAPI. It defines the API endpoints like `/` (for a welcome message) and `/segment` (for the actual segmentation).
171
+ - `app/segmentation.py`: This module contains all the core logic for image processing, interacting with the PaliGemma and VAE models, and generating the final masks and coordinates. It uses the libraries **JAX**, **Flax** and **Transformers** for efficient model execution and inference.
172
+
173
+ # Technology Stack Overview
174
+
175
+ Let’s first understand the key technologies used in the project.
176
+
177
+ ## PaliGemma 2 mix
178
+
179
+ [PaliGemma 2 mix](https://developers.googleblog.com/en/introducing-paligemma-2-mix/) is a state-of-the-art vision-language model from Google that can comprehend both images and text. It represents a significant advancement in multimodal AI, allowing for natural language-guided image understanding. Once PaliGemma identifies the segments (as tokens), a Variational Autoencoder (VAE) model steps in. It decodes these segmentation tokens into the detailed, pixel-level masks that you see as the output. This particular service uses a dedicated VAE checkpoint, `vae-oid.npz`, for this task. A comprehensive overview of the image segmentation process with PaliGemma 2 mix can be found in one of my previous posts [here](https://joejoe1313.github.io/2025-04-15-paligemma-2-mix.html).
180
+
181
+ ## FastAPI
182
+
183
+ [FastAPI](https://fastapi.tiangolo.com) is a web framework for building APIs with Python. We chose this framework for several compelling reasons:
184
+
185
+ - Automatic API documentation with **Swagger UI** and ReDoc
186
+ - Data validation and serialization through **Pydantic** models
187
+ - Asynchronous request handling with support for `async`/`await` syntax
188
+ - Excellent performance comparable to Node.js and Go
189
+ - Built-in dependency injection system
190
+ - Security and authentication features out of the box
191
+ - WebSocket support and background tasks
192
+
193
+ In our project, `main.py` uses FastAPI to define the `/segment` endpoint, which accepts form data: the `prompt`, optionally an `image_url` or an uploaded `image_file`, and optionally the desired `model_id`.
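+
+ To make this concrete, a minimal sketch of such an endpoint is shown below. This is a simplified illustration rather than the repository’s actual `app/main.py`, and the stubbed return value stands in for the real segmentation logic:
+
+ ```python
+ from typing import Optional
+
+ from fastapi import FastAPI, File, Form, UploadFile
+
+ app = FastAPI()
+
+ @app.post("/segment")
+ async def segment(
+     prompt: str = Form(...),                        # e.g. "segment left wheel"
+     image_url: Optional[str] = Form(None),          # either a URL...
+     image_file: Optional[UploadFile] = File(None),  # ...or an uploaded file
+     model_id: Optional[str] = Form(None),           # optional PaliGemma variant
+ ):
+     # In the real app this delegates to app/segmentation.py, which runs
+     # PaliGemma and the VAE, then returns base64 masks plus coordinates.
+     return {"image": "<base64 input image>", "masks": []}
+ ```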
194
+
195
+ ## Transformers
196
+
197
+ **Transformers** is a library by **Hugging Face** that provides state-of-the-art pre-trained models for natural language processing and computer vision. For our project, it’s essential because:
198
+
199
+ - It provides easy access to the PaliGemma models
200
+ - Offers a unified API for loading and using different model architectures
201
+ - Handles model preprocessing and tokenization
202
+ - Supports efficient model inference
203
+ - Enables fine-tuning of models if needed
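+
+ As a rough illustration of these points, loading a PaliGemma 2 mix checkpoint and running a single generation with Transformers can look roughly like the sketch below. This is not the project’s actual `segmentation.py` code, and it assumes you have been granted access to the gated model on Hugging Face:
+
+ ```python
+ import requests
+ from PIL import Image
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ model_id = "google/paligemma2-3b-mix-448"
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ # Segmentation prompts follow the "segment <object>" convention.
+ inputs = processor(text="segment left wheel", images=image, return_tensors="pt")
+ output = model.generate(**inputs, max_new_tokens=128)
+
+ # The decoded text contains <locXXXX> and <segXXX> tokens that still need
+ # to be post-processed into masks and bounding boxes.
+ print(processor.decode(output[0], skip_special_tokens=True))
+ ```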
204
+
205
+ ## JAX/Flax
206
+
207
+ JAX is a high-performance numerical computing library developed by Google. Flax is a neural network library built on top of JAX. Together, they provide:
208
+
209
+ - Accelerated computation on GPUs and TPUs
210
+ - Just-in-time compilation for optimized performance
211
+ - Automatic differentiation capabilities
212
+ - Functional programming approach to machine learning
213
+
214
+ In our app, **JAX/Flax and Transformers** are used for scalable model execution and inference; JAX/Flax in particular powers the VAE model, which decodes segmentation tokens into pixel-level masks.
215
+
216
+ ## Docker & Docker Compose
217
+
218
+ ### What is Docker?
219
+
220
+ Docker is a platform that uses OS-level virtualization to deliver software in packages called **containers**. A container is a standalone, executable package of software that includes everything needed to run an application: code, runtime, system tools, system libraries, and settings.
221
+
222
+ ### Why Docker for This Project?
223
+
224
+ - **Consistency**: _“It works on my machine”_ is a phrase Docker aims to eliminate. By packaging the application together with its dependencies (like specific versions of PyTorch, JAX, and Transformers), we ensure it runs identically everywhere, from a developer’s laptop to a production server.
225
+ - **Isolation**: Containers run in isolation, preventing conflicts between different applications or dependencies on the same host system.
226
+ - **Simplified Deployment**: Docker abstracts away the underlying infrastructure. With a simple `docker-compose up` command, anyone can get the service running without manually installing Python, various libraries, or configuring complex environments.
227
+ - **Scalability**: Dockerized applications are inherently easier to scale. Orchestration tools like Kubernetes can manage multiple instances of our container to handle increased load.
228
+ - **Multi-Architecture Support**: Our Docker setup supports both `amd64` (common for desktops and servers) and `arm64` (increasingly used in cloud instances and devices like Raspberry Pi) architectures, broadening its usability.
229
+
230
+ ### Key Docker Components Used:
231
+
232
+ - **`Dockerfile`**
233
+
234
+ This is the **recipe for building our Docker image**. It specifies the base image (e.g., a Python image), copies our application code (`app/` directory, `requirements.txt`, etc.) into the image, installs all necessary Python packages, and defines the command to run when the container starts (e.g., `uvicorn app.main:app`).
235
+
236
+ ```Dockerfile
237
+ FROM python:3.11-slim
238
+
239
+ WORKDIR /app
240
+
241
+ COPY requirements.txt .
242
+ RUN pip install --no-cache-dir -r requirements.txt
243
+
244
+ COPY . .
245
+
246
+ ENV MODEL_ID=google/paligemma2-3b-mix-448
247
+ ENV MODELS_DIR=/app/models
248
+
249
+ EXPOSE 8000
250
+
251
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
252
+ ```
253
+
254
+ - **`docker-compose.yml`**
255
+
256
+ While a `Dockerfile` builds a single image, Docker Compose is used to define and run multi-container Docker applications. In our case, it simplifies running our FastAPI service. It can also manage networks, volumes, and other aspects of the application stack. For this project, it handles setting up the service and importantly, the volume mounting for model persistence.
257
+
258
+ ```yaml
259
+ services:
260
+ paligemma-api:
261
+ image: joejoe1313/paligemma-image-segmentation:latest
262
+ ports:
263
+ - "8000:8000"
264
+ environment:
265
+ - MODEL_ID=google/paligemma2-3b-mix-448
266
+ - MODELS_DIR=/app/models
267
+ secrets:
268
+ - hf_token
269
+ volumes:
270
+ - $HOME/.cache/huggingface/hub:/app/models/huggingface
271
+ restart: unless-stopped
272
+
273
+ secrets:
274
+ hf_token:
275
+ file: $HOME/.cache/huggingface/token
276
+ ```
277
+
278
+ - **`.dockerignore`:**
279
+
280
+ Similar to `.gitignore`, this file lists files and directories that should not be copied into the Docker image during the build process (e.g., local virtual environments, `.git` directory). This keeps the image lean and build times faster.
281
+
282
+ ### Smart Model Management with Volume Mounting
283
+
284
+ Models, especially large ones like PaliGemma, take time to download. We use Docker’s **volume mounting** feature to map a directory on the host machine to a directory inside the container. This means:
285
+
286
+ - Models are downloaded once and **persisted** on your host machine.
287
+ - They are **reused** across container restarts.
288
+ - They can be **shared** if you run multiple instances.
289
+ - Updating models becomes easier as you can manage them directly on your host.
290
+
291
+ The default mount point in our `docker-compose.yml` file is `$HOME/.cache/huggingface/hub:/app/models/huggingface`, which maps the local `$HOME/.cache/huggingface/hub` directory to `/app/models/huggingface` in the container.
292
+
293
+ ## CI/CD with GitHub Actions
294
+
295
+ GitHub Actions provides automated workflows for continuous integration and continuous delivery (CI/CD). This is crucial for modern software development, so we have implemented a CI/CD pipeline using **GitHub Actions**. Our workflow, `.github/workflows/docker-build.yml`, is shown below:
296
+
297
+ ```yaml
298
+ name: Docker Build and Push
299
+
300
+ on:
301
+ push:
302
+ branches: main
303
+ pull_request:
304
+ branches: main
305
+
306
+ jobs:
307
+ build:
308
+ runs-on: ubuntu-latest
309
+ steps:
310
+ - name: Login to Docker Hub
311
+ uses: docker/login-action@v3
312
+ with:
313
+ username: ${{ vars.DOCKERHUB_USERNAME }}
314
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
315
+
316
+ - name: Set up QEMU
317
+ uses: docker/setup-qemu-action@v3
318
+
319
+ - name: Set up Docker Buildx
320
+ uses: docker/setup-buildx-action@v3
321
+
322
+ - name: Build and push
323
+ uses: docker/build-push-action@v6
324
+ with:
325
+ push: ${{ github.event_name != 'pull_request' }}
326
+ platforms: linux/amd64,linux/arm64
327
+ tags: ${{ vars.DOCKERHUB_USERNAME }}/paligemma-image-segmentation:latest
328
+ ```
329
+
330
+ **What does it do?**
331
+
332
+ - **Trigger**: Every time code is pushed to the `main` branch (or a pull request is made to `main`), the GitHub Actions workflow is automatically triggered.
333
+ - **Build**: The workflow checks out the code and builds a multi-architecture (`amd64`/`arm64`) Docker image using the `Dockerfile`.
334
+ - **Push**: The newly built image is then pushed to a container registry (like Docker Hub).
335
+ - **Tag**: The image is tagged with the unique commit SHA (for traceability) and also with `latest` (for convenience).
336
+
337
+ **Figure 4** below shows a high level overview of the workflow:
338
+
339
+ ```{mermaid}
340
+ sequenceDiagram
341
+ participant G as GitHub Repo
342
+ participant A as GitHub Actions
343
+ participant D as Docker Registry
344
+
345
+ G->>A: on: push / pull_request to main
346
+ activate A
347
+ A-->>A: Login to Docker Registry
348
+ A-->>A: Set up QEMU (for multi-arch)
349
+ A-->>A: Set up Docker Buildx
350
+ A-->>A: Build Docker image (multi-arch)
351
+     A-->>D: Push image (tagged latest)
352
+ deactivate A
353
+ ```
354
+
355
+ <figcaption style="text-align: center">Figure 4. GitHub Actions workflow</figcaption>
356
+
357
+ **How to use it in your fork**
358
+
359
+ If you fork this repository, you can set up your own CI/CD pipeline by adding the following configuration to your GitHub repository settings:
360
+
361
+ - `DOCKERHUB_USERNAME`: Your Docker Hub username, stored as a repository **variable** (the workflow reads it via `vars.DOCKERHUB_USERNAME`).
362
+ - `DOCKERHUB_TOKEN`: Your Docker Hub access token (not your password!), stored as a repository **secret**.
363
+ - You should also update the image in `docker-compose.yml` with your username: `{DOCKERHUB_USERNAME}/paligemma-image-segmentation:latest`
364
+
365
+ This automated pipeline ensures that a fresh, deployable image is always available.
366
+
367
+ # Getting Started: Installation and Setup
368
+
369
+ Ready to try it yourself? Here’s how:
370
+
371
+ ## Prerequisites
372
+
373
+ - **Docker**: Ensure Docker Desktop or Docker Engine is installed and running.
374
+ - **Hugging Face Token**: To access gated models like PaliGemma, you’ll need a Hugging Face token. Make sure it’s stored at `$HOME/.cache/huggingface/token` on your host machine, or adjust the path accordingly. The application will use this token when downloading models.
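+
+ If you have not created that token file yet, one convenient option is to log in once with the `huggingface_hub` library, which stores the token in the default cache location (a quick sketch; running `huggingface-cli login` in a terminal achieves the same thing):
+
+ ```python
+ from huggingface_hub import login
+
+ # Prompts for a Hugging Face access token and saves it under the default
+ # cache directory, e.g. $HOME/.cache/huggingface/token.
+ login()
+ ```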
375
+
376
+ ## Setup with Docker Compose:
377
+
378
+ - Clone the repository:
379
+
380
+ ```bash
381
+ git clone https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation.git
382
+ cd PaliGemma-Image-Segmentation
383
+ ```
384
+
385
+ - Ensure your **Hugging Face token** is in place as mentioned above. The `docker-compose.yml` is set up to mount your local Hugging Face cache, including the token.
386
+ - Run the application:
387
+
388
+ ```bash
389
+ docker-compose up -d
390
+ ```
391
+
392
+ > **Warning**: It’s generally good practice to avoid running Docker commands as the root user unless necessary. Docker Compose should typically be run by a user that is a member of the `docker` group.
393
+
394
+ This command will pull the pre-built Docker image and start the FastAPI service on `http://localhost:8000`. The `-d` flag runs it in detached mode.
395
+
396
+ ## Choosing Your PaliGemma Model
397
+
398
+ You have two ways to specify which PaliGemma model variant to use:
399
+
400
+ - At **runtime via API**: Pass the `model_id` parameter in your POST request to the `/segment` endpoint.
401
+ - Via **Docker Environment Variable**: Set the `MODEL_ID` environment variable when starting the container.
402
+
403
+ > **Note**: If both are set, the `model_id` parameter in the API request takes precedence.
404
+
405
+ If a specified model isn’t found in the local cache (`/app/models/huggingface` inside the container, which maps to your `$HOME/.cache/huggingface/hub`), the application will attempt to download it from Hugging Face. **Figure 5** below shows a comprehensive overview of the possible steps.
406
+
407
+ ```{mermaid}
408
+ flowchart TD
409
+ A[API Request] --> B{Check Local Cache}
410
+ B -->|Found| H[Load from Local Cache]
411
+ B -->|Not Found| C{Has HF Token?}
412
+
413
+ %% Style definitions
414
+ classDef process fill:#e0e0ff,stroke:#9999ff,color:black
415
+ classDef decision1 fill:#ffe0b0,stroke:#ffbb66,color:black
416
+ classDef decision2 fill:#d0f0d0,stroke:#aaddaa,color:black
417
+ classDef cache fill:#d0e0ff,stroke:#aabbee,color:black
418
+
419
+ %% Apply styles
420
+ class A,D,F,I,E,Z process
421
+ class B decision1
422
+ class C decision2
423
+ class G,H cache
424
+
425
+ C -->|Yes| D[Authenticate with HF]
426
+ C -->|No| E[Try Loading Public Model]
427
+ D --> F[Download Model]
428
+ F --> G[Save to Cache]
429
+ E -->|Success| G
430
+ E -->|Failure| Z[Auth Error]
431
+ G --> H
432
+ H --> I[Use Model]
433
+ ```
434
+
435
+ <figcaption style="text-align: center">Figure 5. Model download scheme</figcaption>
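+
+ Purely as an illustration of the flow in **Figure 5** (not the application’s actual code), the cache-or-download step can be expressed with the `huggingface_hub` library roughly like this:
+
+ ```python
+ from typing import Optional
+
+ from huggingface_hub import snapshot_download
+
+ def ensure_model(model_id: str, cache_dir: str, token: Optional[str] = None) -> str:
+     # Returns the local path right away if the model is already cached,
+     # otherwise downloads it first (authenticating with the token if given).
+     return snapshot_download(repo_id=model_id, cache_dir=cache_dir, token=token)
+ ```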
436
+
437
+ # Examples: Putting the API to the Test
438
+
439
+ Once the service is running (default: `http://localhost:8000`):
440
+
441
+ - **API Docs**: Visit `http://localhost:8000/docs` for interactive API documentation.
442
+
443
+ ## Health Check (GET /)
444
+
445
+ Verify the API is up and running.
446
+
447
+ **Request:**
448
+
449
+ ```python
450
+ import requests
451
+
452
+ response = requests.get("http://localhost:8000/")
453
+ print(response.json())
454
+ ```
455
+
456
+ **Response:**
457
+
458
+ ```json
459
+ {
460
+ "message": "Welcome to the PaliGemma Segmentation API!"
461
+ }
462
+ ```
463
+
464
+ ## Segmenting an Image (POST /segment/)
465
+
466
+ Form Parameters:
467
+
468
+ - `prompt` (str, required): Text description of objects to segment (e.g., "segment the red car").
469
+ - `image_url` (str, optional): URL of the image to segment.
470
+ - `image_file` (UploadFile, optional): Uploaded image file to segment.
471
+ - `model_id` (str, optional): Specific PaliGemma model ID to use.
472
+
473
+ ### Using an Image URL
474
+
475
+ ```{mermaid}
476
+ sequenceDiagram
477
+ participant C as Client
478
+ participant S as /segment Endpoint
479
+ C->>S: POST Request with JSON body: { "image_url": "your_image_url.jpg", "prompt": "object to segment" }
480
+ S-->>S: Download Image
481
+ S-->>S: Process with PaliGemma & VAE
482
+ S-->>C: JSON Response: { "image": "base64_input_image", "masks": [ { "mask": "base64_mask_data", "coordinates": [x_min,y_min,x_max,y_max] } ] }
483
+ ```
484
+ <figcaption style="text-align: center">Figure 6. Segment request</figcaption>
485
+
486
+ **Request:**
487
+
488
+ ```python
489
+ import requests
490
+
491
+ data = {
492
+ "prompt": "segment left wheel",
493
+ "image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
494
+ }
495
+
496
+ response = requests.post("http://localhost:8000/segment", data=data)
497
+ print(response.json())
498
+ ```
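+
+ You can also pin the model variant per request through the optional `model_id` form field (here `google/paligemma2-10b-mix-448` is just an example; any PaliGemma 2 mix checkpoint you have access to should work):
+
+ ```python
+ import requests
+
+ data = {
+     "prompt": "segment left wheel",
+     "image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg",
+     "model_id": "google/paligemma2-10b-mix-448",  # optional; overrides the MODEL_ID env var
+ }
+
+ response = requests.post("http://localhost:8000/segment", data=data)
+ print(response.json())
+ ```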
499
+
500
+ ### Uploading an Image File
501
+
502
+ **Request:**
503
+
504
+ ```python
505
+ import os
506
+ import requests
507
+
508
+ segm_url = "http://localhost:8000/segment"
509
+ image_path = "your_image.png"
510
+
511
+ with open(image_path, "rb") as image_file:
512
+ data = {
513
+ "prompt": "segment the main object" # Adjust prompt as needed
514
+ }
515
+ files = {
516
+ "image_file": (os.path.basename(image_path), image_file, "image/png") # Or image/jpeg
517
+ }
518
+
519
+ response = requests.post(segm_url, files=files, data=data)
520
+ print(response.json())
521
+ ```
522
+
523
+ ### Expected JSON Response Structure:
524
+
525
+ **Response:**
526
+
527
+ ```python
528
+ {
529
+ "image": "base64_encoded_model_input_image_data",
530
+ "masks": [
531
+ {
532
+ "mask": "base64_encoded_mask_data_for_object_1",
533
+ "coordinates": [x_min, y_min, x_max, y_max]
534
+ }
535
+ # ...more masks if multiple instances of the prompt are found
536
+ ]
537
+ }
538
+ ```
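+
+ To work with the response on the client side, you can decode the base64 payloads back into images. The sketch below assumes the encoded data is a standard image format that PIL can open; see the example notebook linked below for the exact post-processing used in the project:
+
+ ```python
+ import base64
+ import io
+
+ from PIL import Image
+
+ def decode_base64_image(data: str) -> Image.Image:
+     """Decode a base64-encoded image payload into a PIL image."""
+     return Image.open(io.BytesIO(base64.b64decode(data)))
+
+ result = response.json()  # response from the /segment request above
+ input_image = decode_base64_image(result["image"])
+ for i, item in enumerate(result["masks"]):
+     mask = decode_base64_image(item["mask"])
+     print(i, item["coordinates"], mask.size)
+ ```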
539
+
540
+ See [example.ipynb](https://github.com/JoeJoe1313/PaliGemma-Image-Segmentation/blob/main/example.ipynb) for a demonstration of the segmentation pipeline.
541
+
542
+ # Conclusion
543
+
544
+ In this guide, we explored an image segmentation API using PaliGemma 2 mix, FastAPI, Transformers, and Docker. This service enables users to segment objects in images using natural language prompts, opening up a wide range of applications in computer vision, image editing, and data analysis. The containerized application provides a flexible, scalable solution that can be easily deployed across various environments. The combination of FastAPI’s performance and Docker’s portability makes this architecture ideal for production ML applications.
545
+
546
+ _Happy coding!_