import gc
from pathlib import Path

import gradio as gr
import matplotlib.cm as cm
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModel

# Device configuration with memory management
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
DEFAULT_NAME = list(MODEL_MAP.keys())[0]
MAX_IMAGE_DIM = 720  # Maximum size of the longer side, in pixels

# Global model state
processor = None
model = None


def cleanup_memory():
    """Aggressively release the current model before switching to another one."""
    global processor, model
    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def compute_dynamic_size(height, width, max_dim=720, patch_size=16):
    """
    Compute new dimensions that preserve the aspect ratio under a max_dim constraint.
    Dimensions are rounded down to multiples of patch_size for clean patch extraction.
    """
    # Scale so the longer side does not exceed max_dim (never upscale)
    scale = min(1.0, max_dim / max(height, width))

    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round down to the nearest multiple of patch_size, keeping at least one patch
    new_height = max(patch_size, (new_height // patch_size) * patch_size)
    new_width = max(patch_size, (new_width // patch_size) * patch_size)
    return new_height, new_width


def load_model(name):
    """Load the selected checkpoint, letting transformers choose the dtype."""
    global processor, model
    cleanup_memory()

    model_id = MODEL_MAP[name]
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",
    ).eval()

    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.1f}B params | Ready"


# Initialize default model
load_model(DEFAULT_NAME)


def preprocess_image(img):
    """
    Custom preprocessing that preserves the aspect ratio and uses dynamic sizing.
    DINOv3's RoPE position embeddings handle variable input sizes, so there is
    no need to constrain images to 224x224.
    """
    # Convert to RGB if needed
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Compute dynamic size
    orig_h, orig_w = img.height, img.width
    patch_size = model.config.patch_size
    new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)

    # Resize image
    img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)

    # Convert to a float array in [0, 1]
    img_array = np.array(img_resized).astype(np.float32) / 255.0

    # Normalize with the processor's statistics, falling back to ImageNet values
    mean = (
        processor.image_mean
        if hasattr(processor, "image_mean")
        else [0.485, 0.456, 0.406]
    )
    std = (
        processor.image_std
        if hasattr(processor, "image_std")
        else [0.229, 0.224, 0.225]
    )
    img_array = (img_array - mean) / std

    # Convert to a tensor with shape [1, C, H, W]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
    return img_tensor, new_h, new_w


@spaces.GPU(duration=60)
def _extract_grid(img):
    """Extract the patch-feature grid from an image at its dynamic size."""
    global model
    with torch.inference_mode():
        # Move model to GPU for this call
        model = model.to("cuda")

        # Preprocess with dynamic sizing
        pv, img_h, img_w = preprocess_image(img)
        # Match the model's device and dtype (the checkpoint may be fp32 or bf16)
        pv = pv.to(device=model.device, dtype=model.dtype)

        # Run inference; the model handles variable input sizes
        out = model(pixel_values=pv)
        last = out.last_hidden_state[0].to(torch.float32)

        # Extract patch features (skip the CLS token and any register tokens)
        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size

        # Calculate grid dimensions based on the actual (resized) image size
        gh, gw = img_h // p, img_w // p
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()

        # Move model back to CPU before the function exits
        model = model.cpu()
        torch.cuda.empty_cache()

    return feats, gh, gw, img_h, img_w


def _overlay(orig, heat01, alpha=0.55, box=None):
    """Create a heatmap overlay on top of the original image."""
    H, W = orig.height, orig.width
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
    # Use the turbo colormap - it reads well on satellite imagery
    rgba = (cm.get_cmap("turbo")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))
    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        from PIL import ImageDraw

        draw = ImageDraw.Draw(out, "RGBA")
        # Outline the selected patch: white box with a thin dark border
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out


def prepare(img):
    """Prepare an image and extract its features at the dynamic size."""
    if img is None:
        return None
    base = ImageOps.exif_transpose(img.convert("RGB"))
    feats, gh, gw, img_h, img_w = _extract_grid(base)
    return {
        "orig": base,
        "feats": feats,
        "gh": gh,
        "gw": gw,
        "processed_h": img_h,
        "processed_w": img_w,
    }


def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle click events: visualize patch similarity with progress feedback."""
    # Immediate feedback in the resolution_info box
    if img_value is not None:
        yield img_value, state, "Computing similarity..."

    # If state wasn't prepared (e.g., Example selection), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)

    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        yield img_value, state, ""
        return

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(int(x // px_x), gw - 1)
    j = min(int(y // px_y), gh - 1)

    # Cosine similarity between the clicked patch and every patch in the grid
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)

    # Report the processing resolution along with the similarity range
    info_text = (
        f"Processing at: {state['processed_w']}×{state['processed_h']} "
        f"({gh}×{gw} patches) | Patch [{i},{j}] • Range: {smin:.3f}-{smax:.3f}"
    )
    yield overlay, state, info_text


def reset():
    """Reset the interface."""
    return None, None, ""


with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """

        <div class="header">
            <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
            <p>Click any region to visualize feature similarities across the image</p>
        </div>

""" ) with gr.Row(): with gr.Column(scale=1): model_choice = gr.Dropdown( choices=list(MODEL_MAP.keys()), value=DEFAULT_NAME, label="Model Selection", info="Select a model (size/pretraining dataset)", ) status = gr.Textbox( label="Model Status", value=f"Loaded: {DEFAULT_NAME}", interactive=False, lines=1, ) resolution_info = gr.Textbox( label="Info & Status", value="", interactive=False, lines=1, ) opacity = gr.Slider( 0.0, 1.0, 0.55, step=0.05, label="Heatmap Opacity", info="Balance between image and similarity map", ) with gr.Row(): reset_btn = gr.Button("Reset", variant="secondary", scale=1) clear_btn = gr.ClearButton(value="Clear All", scale=1) with gr.Column(scale=2): img = gr.Image( type="pil", label="Interactive Canvas (Click to explore)", interactive=True, height=600, show_download_button=True, show_share_button=False, ) state = gr.State() model_choice.change( load_model, inputs=model_choice, outputs=status, show_progress="full" ) img.upload(prepare, inputs=img, outputs=state) img.select( click, inputs=[state, opacity, img], outputs=[img, state, resolution_info], show_progress="hidden", # Hide default overlay, use resolution_info for feedback ) reset_btn.click(reset, outputs=[img, state, resolution_info]) clear_btn.add([img, state, resolution_info]) # Examples from current directory example_files = [ f.name for f in Path.cwd().iterdir() if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"] ] if example_files: gr.Examples( examples=[[f] for f in example_files], inputs=img, fn=prepare, outputs=[state], label="Example Images", examples_per_page=4, cache_examples=False, ) gr.Markdown( f""" ---
**Satellite-pretrained models** are intended for remote-sensing tasks: geographic pattern analysis, land-use classification, structural analysis, and similar workloads.

**Dynamic Resolution**: Images are processed at up to {MAX_IMAGE_DIM}px on the longer side while preserving aspect ratio. DINOv3's RoPE position embeddings handle variable input sizes.
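For example, with the default 720 px cap and 16 px patches, a 1920×1080 image is resized to 720×400 pixels and processed as a 45×25 grid of patches.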

**Performance Notes**: The 7B model produces the most detailed features, at the cost of substantially higher memory usage.
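As a rough estimate, 7 billion parameters stored in 16-bit precision already amount to about 14 GB of weights, before counting activations.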
""" ) if __name__ == "__main__": demo.launch(share=False, debug=True)