pszemraj committed
Commit 629aa9f · verified · 1 Parent(s): f260034

Update app.py

Files changed (1): app.py (+137 -209)
app.py CHANGED
@@ -7,296 +7,224 @@ import numpy as np
Old version (lines 7–302; removed lines prefixed "-"):

  import spaces
  import torch
  import torch.nn.functional as F
- from PIL import Image, ImageDraw, ImageOps
  from transformers import AutoImageProcessor, AutoModel

- # --- Configuration ---
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  MODEL_MAP = {
-     "DINOv3 ViT-L/16 Satellite": "facebook/dinov3-vitl16-pretrain-sat493m",
-     "DINOv3 ViT-L/16 LVD (General Web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
-     # "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",  # Uncomment if using a large enough GPU
  }

- DEFAULT_MODEL_NAME = list(MODEL_MAP.keys())[0]

- # --- Global State ---
  processor = None
  model = None


- # --- Core Functions ---
  def cleanup_memory():
-     """Aggressively cleans up memory to prevent OOM errors when switching models."""
      global processor, model
      if model is not None:
          del model
      if processor is not None:
          del processor
-     processor, model = None, None
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()


- def load_model(name: str):
-     """Loads a specified model and processor, handling memory cleanup."""
      global processor, model
      try:
          cleanup_memory()
          model_id = MODEL_MAP[name]

          processor = AutoImageProcessor.from_pretrained(model_id)
          model = (
-             AutoModel.from_pretrained(model_id, torch_dtype="auto").to(DEVICE).eval()
          )

          param_count = sum(p.numel() for p in model.parameters()) / 1e9
-         dtype_str = str(next(model.parameters()).dtype).split(".")[-1]

-         return f"Loaded: {name} | {param_count:.1f}B params | {dtype_str} | {DEVICE.upper()}"
      except Exception as e:
          cleanup_memory()
-         return f"Failed to load {name}: {str(e)}"
-

- @spaces.GPU(duration=60)
- def _extract_grid(img: Image.Image):
-     """Extracts a grid of feature vectors from an image."""
-     with torch.inference_mode():
-         inputs = processor(images=img, return_tensors="pt").to(DEVICE)
-         outputs = model(**inputs)

-     last_hidden_state = outputs.last_hidden_state[0].to(torch.float32)

-     # Correctly calculate grid dimensions from model config
-     p = model.config.patch_size
-     h, w = inputs.pixel_values.shape[-2] // p, inputs.pixel_values.shape[-1] // p

-     # Exclude [CLS] token and reshape
-     # DINOv2/v3 models don't have register tokens, but this is safe
-     num_special_tokens = 1
-     features = last_hidden_state[num_special_tokens:, :].reshape(h, w, -1).cpu()

-     return features, h, w


- def _overlay_heatmap(
-     base_img: Image.Image,
-     heatmap_01: np.ndarray,
-     opacity: float,
-     colormap: str,
-     box: tuple = None,
- ):
-     """Overlays a heatmap on the base image with a specified colormap and opacity."""
-     h, w = base_img.height, base_img.width

-     # Resize heatmap and apply colormap
-     heatmap = Image.fromarray((heatmap_01 * 255).astype(np.uint8)).resize(
-         (w, h), resample=Image.LANCZOS
-     )
-     cmap_func = cm.get_cmap(colormap.lower())
-     rgba_heatmap = (cmap_func(np.asarray(heatmap) / 255.0) * 255).astype(np.uint8)

-     # Create overlay image with specified opacity
-     overlay = Image.fromarray(rgba_heatmap, "RGBA")
-     overlay.putalpha(int(opacity * 255))

-     # Composite overlay onto a copy of the base image
-     out_img = Image.alpha_composite(base_img.copy().convert("RGBA"), overlay)

-     # Draw selection box if provided
      if box:
-         draw = ImageDraw.Draw(out_img, "RGBA")
-         # Draw a thick white border with a thin black outline for visibility
-         draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
-         draw.rectangle(
-             (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
-             outline=(0, 0, 0, 200),
-             width=1,
-         )

-     return out_img.convert("RGB")


- # --- Gradio Event Handlers ---
- def prepare(img: Image.Image):
-     """Prepares the image by extracting features and storing them in the state."""
      if img is None:
-         return None, "Upload an image to begin."

-     base_img = ImageOps.exif_transpose(img.convert("RGB"))
-     feats, gh, gw = _extract_grid(base_img)

-     state = {"orig": base_img, "feats": feats, "gh": gh, "gw": gw}
-     return state, "Click on the image to compute similarity."


- def on_click(state: dict, opacity: float, colormap: str, evt: gr.SelectData):
-     """Handles click events to compute and display the similarity heatmap."""
      if not state or evt.index is None:
-         return gr.UNCHANGED, state, "Please upload an image first."

      base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

-     # Calculate patch index from click coordinates
      x, y = evt.index
-     px_w, px_h = base.width / gw, base.height / gh
-     i = min(int(x // px_w), gw - 1)
-     j = min(int(y // px_h), gh - 1)

-     # Compute similarity
      d = feats.shape[-1]
-     query_vec = F.normalize(feats[j, i].reshape(1, d), dim=1)
-     feature_grid = F.normalize(feats.reshape(-1, d), dim=1)
-     sims = (feature_grid @ query_vec.T).reshape(gh, gw).numpy()

-     # Normalize similarity map to [0, 1] for visualization
      smin, smax = float(sims.min()), float(sims.max())
-     heatmap_01 = (sims - smin) / (smax - smin + 1e-12)
-
-     # Define selection box coordinates
-     box = (int(i * px_w), int(j * px_h), int((i + 1) * px_w), int((j + 1) * px_h))
-
-     # Generate overlay image
-     output_img = _overlay_heatmap(base, heatmap_01, opacity, colormap, box)
-
-     # Create statistics string
-     stats = f"""📊 **Similarity Statistics**
- - **Min**: `{smin:.3f}`
- - **Max**: `{smax:.3f}`
- - **Range**: `{smax - smin:.3f}`
- - **Patch Index**: `({i}, {j})`
- - **Grid Size**: `{gw}×{gh}`"""
-
-     return output_img, state, stats
-
-
- def reset_overlay(state: dict):
-     """Resets the image to its original state, removing the heatmap."""
-     if state and "orig" in state:
-         return state["orig"], state, "Overlay reset. Click the image again."
-     return None, None, "Upload an image to begin."
-
-
- # --- Gradio UI ---
- with gr.Blocks(
-     theme=gr.themes.Soft(primary_hue="blue"),
-     css=".container {max-width: 1200px; margin: auto;}",
- ) as demo:
-     gr.HTML(
-         """
-         <div style="text-align: center; padding: 20px;">
-             <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
-             <p style="font-size: 1.1em; color: #666;">Explore how DINOv3 models trained on satellite imagery understand visual patterns</p>
-         </div>
-         """
-     )

-     with gr.Row():
-         with gr.Column(scale=3):
-             gr.Markdown(
-                 """
-                 ### How it works
-                 1. **Select a model** - Satellite-pretrained models are optimized for aerial/satellite imagery.
-                 2. **Upload or select an image** - Works best with satellite, aerial, or outdoor scenes.
-                 3. **Click any region** - See how similar other patches are to your selection.
-                 4. **Adjust visualization** - Fine-tune opacity and colormap for clarity.
-                 """
-             )
-         with gr.Column(scale=2):
-             gr.HTML(
-                 """
-                 <div style="background: rgba(0,0,0,0.03); border-radius: 8px; padding: 12px; border-left: 4px solid #2563eb;">
-                     <b>💡 Model Info:</b><br>
-                     • <b>Satellite models</b>: Trained on 493M satellite images.<br>
-                     • <b>LVD model</b>: Trained on 1.7B diverse images.<br>
-                     • <b>7B model</b>: Massive capacity, slower but more nuanced.
-                 </div>
-                 """
-             )

-     with gr.Row(variant="panel"):
-         with gr.Column(scale=1):
              model_choice = gr.Dropdown(
-                 choices=list(MODEL_MAP.keys()),
-                 value=DEFAULT_MODEL_NAME,
-                 label="🤖 Model Selection",
              )
              status = gr.Textbox(
-                 label="📡 Model Status",
-                 value="Loading initial model...",
-                 interactive=False,
              )

-             with gr.Accordion("Visualization Controls", open=True):
-                 opacity = gr.Slider(
-                     0.2, 0.9, 0.55, step=0.05, label="🎨 Heatmap Opacity"
-                 )
-                 colormap = gr.Dropdown(
-                     ["Turbo", "Inferno", "Viridis", "Plasma", "Jet"],
-                     value="Turbo",
-                     label="🌈 Colormap",
-                 )
-
-             info_panel = gr.Markdown(
-                 value="*Upload an image and click on it to see statistics here.*",
-                 label="Statistics",
-             )
-
-             with gr.Row():
-                 reset_btn = gr.Button("🔄 Reset Overlay")
-                 # A ClearButton is simpler for clearing multiple components
-                 clear_btn = gr.ClearButton(value="🗑️ Clear All")
-
-         with gr.Column(scale=2):
              img = gr.Image(
-                 type="pil",
-                 label="Interactive Canvas (Click to explore)",
-                 height=600,
-                 show_download_button=True,
              )

-     # Define a state object to hold persistent data (original image, features)
      state = gr.State()

-     # NOTE: For Hugging Face Spaces, list file paths explicitly.
-     # Make sure these images are in your repository.
-     gr.Examples(
-         examples=[
-             ["examples/satellite_city.jpg"],
-             ["examples/coastal_area.png"],
-             ["examples/farmland.webp"],
-         ],
-         inputs=[img],
-         outputs=[state, info_panel],
-         fn=prepare,
-         cache_examples=torch.cuda.is_available(),  # Cache on GPU instances
-     )

-     # --- Event Wiring ---
-     demo.load(lambda: load_model(DEFAULT_MODEL_NAME), outputs=status)
-
-     model_choice.change(
-         load_model, inputs=model_choice, outputs=status, show_progress="full"
-     )
-
-     img.upload(
-         prepare, inputs=img, outputs=[state, info_panel], show_progress="minimal"
-     )

      img.select(
-         on_click,
-         inputs=[state, opacity, colormap],  # REMOVED `img` from inputs to fix the error
-         outputs=[img, state, info_panel],
          show_progress="minimal",
      )

-     reset_btn.click(reset_overlay, inputs=[state], outputs=[img, state, info_panel])
-
-     # Wire the clear button to the components it should clear
-     clear_btn.add([img, state, info_panel])

- print(load_model(DEFAULT_MODEL_NAME))
- demo.launch(share=False, show_error=True)
 
New version (lines 7–230; added lines prefixed "+"):

  import spaces
  import torch
  import torch.nn.functional as F
+ from PIL import Image, ImageOps
  from transformers import AutoImageProcessor, AutoModel

+ # Device configuration with memory management
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  MODEL_MAP = {
+     "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
+     "DINOv3 ViT-L/16 LVD (1.7B)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
+     "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
  }

+ DEFAULT_NAME = list(MODEL_MAP.keys())[0]

+ # Global model state
  processor = None
  model = None


  def cleanup_memory():
+     """Aggressive memory cleanup for model switching"""
      global processor, model
+
      if model is not None:
          del model
+         model = None
+
      if processor is not None:
          del processor
+         processor = None
+
      gc.collect()
+
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
+         torch.cuda.synchronize()


+ def load_model(name):
+     """Load model with proper memory management and dtype handling"""
      global processor, model
+
      try:
+         # Clean up existing model
          cleanup_memory()
+
          model_id = MODEL_MAP[name]

+         # Load processor
          processor = AutoImageProcessor.from_pretrained(model_id)
+
+         # Load model with auto dtype for optimal performance
          model = (
+             AutoModel.from_pretrained(
+                 model_id,
+                 torch_dtype="auto",
+             )
+             .to(DEVICE)
+             .eval()
          )

+         # Get model info
          param_count = sum(p.numel() for p in model.parameters()) / 1e9

+         return f"Loaded: {name} | {param_count:.1f}B params | {DEVICE.upper()}"
+
      except Exception as e:
          cleanup_memory()
+         return f"Failed to load {name}: {str(e)}"


+ # Initialize default model
+ load_model(DEFAULT_NAME)


+ @spaces.GPU(duration=60)
+ def _extract_grid(img):
+     """Extract feature grid from image"""
+     with torch.inference_mode():
+         pv = processor(images=img, return_tensors="pt").pixel_values

+         if DEVICE == "cuda":
+             pv = pv.to(DEVICE)

+         out = model(pixel_values=pv)
+         last = out.last_hidden_state[0].to(torch.float32)

+         num_reg = getattr(model.config, "num_register_tokens", 0)
+         p = model.config.patch_size
+         _, _, Ht, Wt = pv.shape
+         gh, gw = Ht // p, Wt // p

+         feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()

+     return feats, gh, gw


+ def _overlay(orig, heat01, alpha=0.55, box=None):
+     """Create heatmap overlay"""
+     H, W = orig.height, orig.width
+     heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
+     rgba = (cm.get_cmap("inferno")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
+     ov = Image.fromarray(rgba, "RGBA")
+     ov.putalpha(int(alpha * 255))
+     base = orig.copy().convert("RGBA")
+     out = Image.alpha_composite(base, ov)
      if box:
+         from PIL import ImageDraw

+         ImageDraw.Draw(out, "RGBA").rectangle(
+             box, outline=(255, 255, 255, 220), width=2
+         )
+     return out


+ def prepare(img):
+     """Prepare image and extract features"""
      if img is None:
+         return None
+
+     base = ImageOps.exif_transpose(img.convert("RGB"))
+     feats, gh, gw = _extract_grid(base)

+     return {"orig": base, "feats": feats, "gh": gh, "gw": gw}


+ def click(state, opacity, img_value, evt: gr.SelectData):
+     """Handle click events for similarity visualization"""
+     # If state wasn't prepared (e.g., Example selection), build it now
+     if state is None and img_value is not None:
+         state = prepare(img_value)

      if not state or evt.index is None:
+         # Just show whatever is currently in the image component
+         return img_value, state

      base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

      x, y = evt.index
+     px_x, px_y = base.width / gw, base.height / gh
+     i = min(int(x // px_x), gw - 1)
+     j = min(int(y // px_y), gh - 1)

      d = feats.shape[-1]
+     grid = F.normalize(feats.reshape(-1, d), dim=1)
+     v = F.normalize(feats[j, i].reshape(1, d), dim=1)
+     sims = (grid @ v.T).reshape(gh, gw).numpy()

      smin, smax = float(sims.min()), float(sims.max())
+     heat01 = (sims - smin) / (smax - smin + 1e-12)

+     box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
+     overlay = _overlay(base, heat01, alpha=opacity, box=box)
+     return overlay, state
+
+
+ def reset():
+     """Reset the interface"""
+     return None, None

+
+ with gr.Blocks() as demo:
+     gr.Markdown("## DINOv3: patch similarity visualizer")
+     gr.Markdown(
+         "Upload an image, click on an object, and see which patches DINOv3 considers most similar, revealing how DINOv3 natively segments objects through its features."
+     )
+     gr.Markdown("There are multiple model options to pick from in the dropdown.")
+     gr.Markdown(
+         "Please click Reset before uploading another image, as this app keeps the features of the current image."
+     )
+
+     with gr.Column():
+         with gr.Row(scale=1):
              model_choice = gr.Dropdown(
+                 choices=list(MODEL_MAP.keys()), value=DEFAULT_NAME, label="Model"
              )
              status = gr.Textbox(
+                 label="Status", value=f"Loaded: {DEFAULT_NAME}", interactive=False
              )
+             opacity = gr.Slider(0.0, 1.0, 0.55, step=0.05, label="Opacity for the Map")

+         with gr.Row(scale=1):
              img = gr.Image(
+                 type="pil", label="Image", interactive=True, height=750, width=750
              )

      state = gr.State()

+     model_choice.change(load_model, inputs=model_choice, outputs=status)

+     img.upload(prepare, inputs=img, outputs=state)

      img.select(
+         click,
+         inputs=[state, opacity, img],
+         outputs=[img, state],
          show_progress="minimal",
      )

+     gr.Button("Reset").click(reset, outputs=[img, state])
+
+     # Examples from current directory
+     example_files = [
+         str(f)
+         for f in Path.cwd().iterdir()
+         if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
+     ]
+
+     if example_files:
+         gr.Examples(
+             examples=[[f] for f in example_files],
+             inputs=img,
+             fn=prepare,
+             outputs=[state],
+             label="Try an example image and then click on the objects.",
+             examples_per_page=4,
+             cache_examples=False,
+         )

+ if __name__ == "__main__":
+     demo.launch()
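
For reference, a minimal standalone sketch of the patch-similarity logic this commit lands in `_extract_grid` and `click`: embed the image, drop the CLS token plus any register tokens, reshape the remaining tokens into a patch grid, then score one patch against the whole grid by cosine similarity. The checkpoint id and "scene.jpg" are placeholders, and the center patch stands in for a user click; this is a sketch, not the app's exact code.

# Sketch of register-token-aware patch similarity (placeholder paths/ids).
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

model_id = "facebook/dinov3-vitl16-pretrain-sat493m"  # placeholder checkpoint
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()

img = Image.open("scene.jpg").convert("RGB")  # placeholder image
inputs = processor(images=img, return_tensors="pt")

with torch.inference_mode():
    out = model(**inputs)

tokens = out.last_hidden_state[0].float()  # (1 + registers + patches, d)
num_reg = getattr(model.config, "num_register_tokens", 0)  # read defensively
p = model.config.patch_size
_, _, H, W = inputs.pixel_values.shape
gh, gw = H // p, W // p

# Drop the CLS token and any register tokens, keep the patch grid.
feats = tokens[1 + num_reg:].reshape(gh, gw, -1)

# Cosine similarity of one query patch (here: the center) vs. the grid.
j, i = gh // 2, gw // 2
d = feats.shape[-1]
grid = F.normalize(feats.reshape(-1, d), dim=1)
query = F.normalize(feats[j, i].reshape(1, d), dim=1)
sims = (grid @ query.T).reshape(gh, gw)  # values in [-1, 1]

# Min-max rescale to [0, 1], as the click handler does before colormapping.
heat01 = (sims - sims.min()) / (sims.max() - sims.min() + 1e-12)
print(heat01.shape)

The min-max rescaling means the heatmap always spans the full colormap, even when the absolute similarity values are tightly clustered.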