import gc
from pathlib import Path
import gradio as gr
from matplotlib import colormaps
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModel
# Device configuration (informational; the model is moved to GPU inside the @spaces.GPU call)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_MAP = {
"DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
"DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
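# Note: the ViT-7B checkpoint has roughly 7B parameters (~13 GB in bf16), so
# switching to it incurs a long download and a large memory footprint.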
DEFAULT_NAME = list(MODEL_MAP.keys())[0]
MAX_IMAGE_DIM = 720 # Maximum dimension for longer side
# Global model state
processor = None
model = None
def cleanup_memory():
"""Aggressive memory cleanup for model switching"""
global processor, model
if model is not None:
del model
model = None
if processor is not None:
del processor
processor = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def compute_dynamic_size(height, width, max_dim=720, patch_size=16):
"""
Compute new dimensions preserving aspect ratio with max_dim constraint.
Ensures dimensions are divisible by patch_size for clean patch extraction.
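    Example (patch_size=16): a 960×1280 (H×W) input scales to 540×720, then
    rounds down to 528×720, i.e. a 33×45 grid of 16 px patches.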
"""
    # Scale the longer side down to max_dim (never upscale)
    scale = min(1.0, max_dim / max(height, width))
    new_height = int(height * scale)
    new_width = int(width * scale)
    # Round down to the nearest multiple of patch_size, keeping at least one patch
    new_height = max(patch_size, (new_height // patch_size) * patch_size)
    new_width = max(patch_size, (new_width // patch_size) * patch_size)
return new_height, new_width
def load_model(name):
"""Load model with CORRECT dtype"""
global processor, model
cleanup_memory()
model_id = MODEL_MAP[name]
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(
model_id,
torch_dtype="auto",
).eval()
param_count = sum(p.numel() for p in model.parameters()) / 1e9
return f"Loaded: {name} | {param_count:.1f}B params | Ready"
# Initialize default model
load_model(DEFAULT_NAME)
def preprocess_image(img):
"""
Custom preprocessing that respects aspect ratio and uses dynamic sizing.
DINOv3's RoPE handles variable sizes beautifully - no need to constrain to 224x224!
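    Only the processor's normalization stats are reused; resizing is done manually.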
"""
# Convert to RGB if needed
if img.mode != "RGB":
img = img.convert("RGB")
# Compute dynamic size
orig_h, orig_w = img.height, img.width
patch_size = model.config.patch_size
new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)
# Resize image
img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)
# Convert to tensor and normalize using the processor's normalization params
img_array = np.array(img_resized).astype(np.float32) / 255.0
# Apply ImageNet normalization (from processor config)
mean = (
processor.image_mean
if hasattr(processor, "image_mean")
else [0.485, 0.456, 0.406]
)
std = (
processor.image_std
if hasattr(processor, "image_std")
else [0.229, 0.224, 0.225]
)
img_array = (img_array - mean) / std
# Convert to tensor with correct shape: [1, C, H, W]
img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
return img_tensor, new_h, new_w
@spaces.GPU(duration=60)
def _extract_grid(img):
"""Extract feature grid from image - now with dynamic sizing!"""
global model
with torch.inference_mode():
# Move model to GPU for this call
model = model.to("cuda")
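        # (on ZeroGPU, the device is attached only for the duration of this decorated call)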
# Preprocess with dynamic sizing
pv, img_h, img_w = preprocess_image(img)
        # Match the model's device and dtype (torch_dtype="auto" may yield non-fp32 weights)
        pv = pv.to(device=model.device, dtype=model.dtype)
# Run inference - the model handles variable sizes perfectly!
out = model(pixel_values=pv)
last = out.last_hidden_state[0].to(torch.float32)
        # Drop the CLS token and any register tokens, keeping only patch tokens
        num_reg = getattr(model.config, "num_register_tokens", 0)
p = model.config.patch_size
# Calculate grid dimensions based on actual image size
gh, gw = img_h // p, img_w // p
feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
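        # feats: [gh, gw, hidden_dim] patch features, kept on CPU for later lookups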
# Move model back to CPU before function exits
model = model.cpu()
torch.cuda.empty_cache()
return feats, gh, gw, img_h, img_w
def _overlay(orig, heat01, alpha=0.55, box=None):
"""Create heatmap overlay"""
H, W = orig.height, orig.width
heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
# Use turbo colormap - better for satellite imagery
    rgba = (colormaps["turbo"](np.asarray(heat) / 255.0) * 255).astype(np.uint8)
ov = Image.fromarray(rgba, "RGBA")
ov.putalpha(int(alpha * 255))
base = orig.copy().convert("RGBA")
out = Image.alpha_composite(base, ov)
if box:
from PIL import ImageDraw
draw = ImageDraw.Draw(out, "RGBA")
# Enhanced box visualization
draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
draw.rectangle(
(box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
outline=(0, 0, 0, 200),
width=1,
)
return out
def prepare(img):
"""Prepare image and extract features with dynamic sizing"""
if img is None:
return None
base = ImageOps.exif_transpose(img.convert("RGB"))
feats, gh, gw, img_h, img_w = _extract_grid(base)
return {
"orig": base,
"feats": feats,
"gh": gh,
"gw": gw,
"processed_h": img_h,
"processed_w": img_w,
}
def click(state, opacity, img_value, evt: gr.SelectData):
"""Handle click events for similarity visualization"""
# If state wasn't prepared (e.g., Example selection), build it now
if state is None and img_value is not None:
state = prepare(img_value)
    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        return img_value, state, ""
base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
x, y = evt.index
px_x, px_y = base.width / gw, base.height / gh
i = min(int(x // px_x), gw - 1)
j = min(int(y // px_y), gh - 1)
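    # Cosine similarity: L2-normalize every patch feature and the clicked patch,
    # then a grid-vector dot product yields one similarity score per patch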
d = feats.shape[-1]
grid = F.normalize(feats.reshape(-1, d), dim=1)
v = F.normalize(feats[j, i].reshape(1, d), dim=1)
sims = (grid @ v.T).reshape(gh, gw).numpy()
smin, smax = float(sims.min()), float(sims.max())
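    # Min-max rescale to [0, 1]; the epsilon guards against a constant map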
heat01 = (sims - smin) / (smax - smin + 1e-12)
box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
overlay = _overlay(base, heat01, alpha=opacity, box=box)
# Add info about resolution being processed
info_text = f"Processing at: {state['processed_w']}×{state['processed_h']} ({gh}×{gw} patches)"
return overlay, state, info_text
def reset():
"""Reset the interface"""
return None, None, ""
with gr.Blocks(
theme=gr.themes.Citrus(),
css="""
.container {max-width: 1200px; margin: auto;}
.header {text-align: center; padding: 20px;}
.info-box {
background: rgba(0,0,0,0.03);
border-radius: 8px;
padding: 12px;
margin: 10px 0;
border-left: 4px solid #2563eb;
}
""",
) as demo:
gr.HTML(
"""
<div class="header">
<h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
<p style="font-size: 1.1em; color: #666;">
Click any region to visualize feature similarities across the image
</p>
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
model_choice = gr.Dropdown(
choices=list(MODEL_MAP.keys()),
value=DEFAULT_NAME,
label="Model Selection",
info="Select a model (size/pretraining dataset)",
)
status = gr.Textbox(
label="Model Status",
value=f"Loaded: {DEFAULT_NAME}",
interactive=False,
lines=1,
)
resolution_info = gr.Textbox(
label="Processing Resolution",
value="",
interactive=False,
lines=1,
)
opacity = gr.Slider(
0.0,
1.0,
0.55,
step=0.05,
label="Heatmap Opacity",
info="Balance between image and similarity map",
)
with gr.Row():
reset_btn = gr.Button("Reset", variant="secondary", scale=1)
clear_btn = gr.ClearButton(value="Clear All", scale=1)
with gr.Column(scale=2):
img = gr.Image(
type="pil",
label="Interactive Canvas (Click to explore)",
interactive=True,
height=600,
show_download_button=True,
show_share_button=False,
)
state = gr.State()
model_choice.change(
load_model, inputs=model_choice, outputs=status, show_progress="full"
)
img.upload(prepare, inputs=img, outputs=state)
img.select(
click,
inputs=[state, opacity, img],
outputs=[img, state, resolution_info],
show_progress="minimal",
)
reset_btn.click(reset, outputs=[img, state, resolution_info])
clear_btn.add([img, state, resolution_info])
# Examples from current directory
example_files = [
f.name
for f in Path.cwd().iterdir()
if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
]
if example_files:
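        # Example clicks only set the image; `click` builds the feature state lazily on first select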
gr.Examples(
examples=[[f] for f in example_files],
inputs=img,
fn=prepare,
outputs=[state],
label="Example Images",
examples_per_page=4,
cache_examples=False,
)
gr.Markdown(
f"""
---
<div style="text-align: center; color: #666; font-size: 0.9em;">
<b>Dynamic Resolution:</b> Images are processed at up to {MAX_IMAGE_DIM}px (longer side) while preserving aspect ratio.
    DINOv3's RoPE position embeddings handle variable input sizes natively.
<br><br>
    <b>Performance Notes:</b> The satellite-pretrained models target geographic patterns, land-use classification,
    and structural analysis. The 7B model provides exceptional detail at the cost of high memory usage.
<br><br>
</div>
"""
)
if __name__ == "__main__":
demo.launch(share=False, debug=True)