# DINOv3 interactive patch-similarity demo: a Hugging Face Space running on ZeroGPU.
import gc
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from matplotlib import colormaps  # matplotlib.cm.get_cmap was removed in matplotlib 3.9
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModel

# Device configuration with memory management
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
DEFAULT_NAME = list(MODEL_MAP.keys())[0]
MAX_IMAGE_DIM = 720  # Maximum dimension for the longer side

# Global model state
processor = None
model = None

def cleanup_memory():
    """Aggressive memory cleanup for model switching"""
    global processor, model
    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def compute_dynamic_size(height, width, max_dim=720, patch_size=16):
    """
    Compute new dimensions preserving aspect ratio with a max_dim constraint.
    Ensures dimensions are divisible by patch_size for clean patch extraction.
    """
    # Scale so the longer side fits within max_dim (never upscale)
    scale = min(1.0, max_dim / max(height, width))
    new_height = int(height * scale)
    new_width = int(width * scale)
    # Round down to the nearest multiple of patch_size, keeping at least one patch
    # so very small inputs cannot collapse to a zero-sized dimension
    new_height = max(patch_size, (new_height // patch_size) * patch_size)
    new_width = max(patch_size, (new_width // patch_size) * patch_size)
    return new_height, new_width

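# Worked example (illustrative numbers, computed from the function above): a
# 1920x1080 landscape photo is scaled so its longer side fits 720, giving
# 720x405, then rounded down to multiples of the 16-px patch:
#   compute_dynamic_size(1080, 1920)  # -> (400, 720), i.e. a 25x45 patch grid
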
def load_model(name):
    """Load the selected checkpoint, using the dtype stored in its config"""
    global processor, model
    cleanup_memory()
    model_id = MODEL_MAP[name]
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",
    ).eval()
    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.1f}B params | Ready"


# Initialize default model
load_model(DEFAULT_NAME)

def preprocess_image(img):
    """
    Custom preprocessing that respects aspect ratio and uses dynamic sizing.
    DINOv3's rotary position embeddings handle variable input sizes, so there
    is no need to squash every image to a fixed square resolution.
    """
    # Convert to RGB if needed
    if img.mode != "RGB":
        img = img.convert("RGB")
    # Compute dynamic size
    orig_h, orig_w = img.height, img.width
    patch_size = model.config.patch_size
    new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)
    # Resize image
    img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)
    # Convert to a float array in [0, 1]
    img_array = np.array(img_resized).astype(np.float32) / 255.0
    # Apply the processor's normalization stats, falling back to ImageNet defaults
    mean = getattr(processor, "image_mean", None) or [0.485, 0.456, 0.406]
    std = getattr(processor, "image_std", None) or [0.229, 0.224, 0.225]
    img_array = (img_array - np.asarray(mean, dtype=np.float32)) / np.asarray(
        std, dtype=np.float32
    )
    # Convert to tensor with shape [1, C, H, W]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
    return img_tensor, new_h, new_w

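# Note (an assumption about the checkpoint's preprocessing config, not
# something this app verifies): the stock path would be
# `processor(images=img, return_tensors="pt")`, which applies whatever resize
# policy is baked into the checkpoint, typically a fixed resolution. The
# custom path above exists precisely to avoid that and preserve aspect ratio.
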
@spaces.GPU  # ZeroGPU Spaces require this decorator to attach a GPU during the call
def _extract_grid(img):
    """Extract the patch-feature grid from an image, using dynamic sizing"""
    global model
    with torch.inference_mode():
        # Move the model onto the target device for this call
        model = model.to(DEVICE)
        # Preprocess with dynamic sizing
        pv, img_h, img_w = preprocess_image(img)
        # Match the model's device and dtype (the checkpoint may be fp16/bf16)
        pv = pv.to(device=model.device, dtype=model.dtype)
        # Run inference; the model accepts variable input sizes
        out = model(pixel_values=pv)
        last = out.last_hidden_state[0].to(torch.float32)
        # Drop the CLS token and any register tokens; the rest is the patch grid
        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        # Grid dimensions follow from the actual processed image size
        gh, gw = img_h // p, img_w // p
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
    # Move the model back to CPU before the function exits
    model = model.cpu()
    torch.cuda.empty_cache()
    return feats, gh, gw, img_h, img_w

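# Example shapes, assuming a 720x400 input and ViT-L/16 (hidden size 1024;
# `num_reg` register tokens, both read from the model config at runtime):
#   out.last_hidden_state: (1, 1 + num_reg + 45 * 25, 1024)
#   feats:                 (25, 45, 1024)  # gh=25 rows, gw=45 columns
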
def _overlay(orig, heat01, alpha=0.55, box=None):
    """Create heatmap overlay"""
    H, W = orig.height, orig.width
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
    # Use the turbo colormap - better contrast for satellite imagery
    rgba = (colormaps["turbo"](np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))
    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)
    if box:
        from PIL import ImageDraw

        draw = ImageDraw.Draw(out, "RGBA")
        # Mark the clicked patch: white box with a thin black outline around it
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out

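# Design note: the composite above uses a uniform alpha, since `putalpha` sets
# the same opacity for every heatmap pixel, so low-similarity regions are
# tinted rather than transparent. A value-weighted alpha (alpha * heat01 per
# pixel) would be a straightforward variant if the uniform tint is unwanted.
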
def prepare(img):
    """Prepare image and extract features with dynamic sizing"""
    if img is None:
        return None
    base = ImageOps.exif_transpose(img.convert("RGB"))
    feats, gh, gw, img_h, img_w = _extract_grid(base)
    return {
        "orig": base,
        "feats": feats,
        "gh": gh,
        "gw": gw,
        "processed_h": img_h,
        "processed_w": img_w,
    }

def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle click events for similarity visualization"""
    # If state wasn't prepared (e.g., Example selection), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)
    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        # (three return values to match the three registered outputs)
        return img_value, state, ""
    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(int(x // px_x), gw - 1)
    j = min(int(y // px_y), gh - 1)
    d = feats.shape[-1]
    # Cosine similarity between the clicked patch and every patch in the grid
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()
    # Min-max rescale per click so the heatmap spans the full color range
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)
    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)
    # Report the resolution actually processed (width x height in both cases)
    info_text = f"Processing at: {state['processed_w']}×{state['processed_h']} ({gw}×{gh} patches)"
    return overlay, state, info_text

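# A minimal sketch of the similarity math above, on a toy 2x2 grid (the
# tensors here are made up for illustration; `feats` in the app comes from
# the model):
#   toy = torch.randn(2, 2, 8)
#   grid = F.normalize(toy.reshape(-1, 8), dim=1)    # four unit vectors
#   v = F.normalize(toy[0, 0].reshape(1, 8), dim=1)  # the "clicked" patch
#   sims = (grid @ v.T).reshape(2, 2)                # cosine, in [-1, 1]
#   # sims[0, 0] == 1.0: a patch is always maximally similar to itself
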
def reset():
    """Reset the interface"""
    return None, None, ""

with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """
        <div class="header">
            <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
            <p style="font-size: 1.1em; color: #666;">
                Click any region to visualize feature similarities across the image
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="Model Selection",
                info="Select a model (size / pretraining dataset)",
            )
            status = gr.Textbox(
                label="Model Status",
                value=f"Loaded: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            resolution_info = gr.Textbox(
                label="Processing Resolution",
                value="",
                interactive=False,
                lines=1,
            )
            opacity = gr.Slider(
                0.0,
                1.0,
                0.55,
                step=0.05,
                label="Heatmap Opacity",
                info="Balance between image and similarity map",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="Clear All", scale=1)
        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )
    state = gr.State()
    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    )
    img.upload(prepare, inputs=img, outputs=state)
    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state, resolution_info],
        show_progress="minimal",
    )
    reset_btn.click(reset, outputs=[img, state, resolution_info])
    clear_btn.add([img, state, resolution_info])
    # Examples from the current directory
    example_files = [
        f.name
        for f in Path.cwd().iterdir()
        if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
    ]
    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Example Images",
            examples_per_page=4,
            cache_examples=False,
        )
    gr.Markdown(
        f"""
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        <b>Dynamic Resolution:</b> Images are processed at up to {MAX_IMAGE_DIM}px on the longer side while preserving aspect ratio.
        DINOv3's RoPE position embeddings handle variable input sizes.
        <br><br>
        <b>Performance Notes:</b> The satellite models are intended for geographic pattern recognition, land-use classification,
        and structural analysis. The 7B model provides exceptional detail at the cost of high memory usage.
        <br><br>
        </div>
        """
    )

if __name__ == "__main__":
    demo.launch(share=False, debug=True)