import gc
from pathlib import Path
import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModel
# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_MAP = {
"DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
"DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
"DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
DEFAULT_NAME = list(MODEL_MAP.keys())[0]
# Global model state
processor = None
model = None
def cleanup_memory():
    """Aggressive memory cleanup for model switching"""
    global processor, model
    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # torch.cuda.synchronize()
def load_model(name):
    """Load the selected checkpoint and its image processor (dtype picked automatically)"""
    global processor, model
    cleanup_memory()
    model_id = MODEL_MAP[name]
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",
    ).eval()
    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.1f}B params | Ready"
# Initialize default model
load_model(DEFAULT_NAME)
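# Note: the feature extraction below assumes the token layout exposed by the model config:
# one [CLS] token, then `num_register_tokens` register tokens, then one token per image
# patch. Only the patch tokens are kept for the similarity map.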
@spaces.GPU(duration=60)
def _extract_grid(img):
    """Extract the (gh, gw, dim) patch-feature grid from an image"""
    global model
    with torch.inference_mode():
        # Move the model to the inference device for this call
        model = model.to(DEVICE)
        # Preprocess the image and move the pixel tensor to the model's device
        pv = processor(images=img, return_tensors="pt").pixel_values.to(model.device)
        # Run inference
        out = model(pixel_values=pv)
        last = out.last_hidden_state[0].to(torch.float32)
        # Keep only the patch tokens: drop the CLS token and any register tokens
        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        _, _, Ht, Wt = pv.shape
        gh, gw = Ht // p, Wt // p
        feats = last[1 + num_reg:, :].reshape(gh, gw, -1).cpu()
        # Move the model back to the CPU before the GPU worker exits
        model = model.cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return feats, gh, gw
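# For reference: the "/16" checkpoints use a patch size of 16, so e.g. a 1024x1024 input
# maps to a 64x64 feature grid (one vector, of the checkpoint's hidden size, per 16x16
# pixel patch).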
def _overlay(orig, heat01, alpha=0.55, box=None):
    """Blend a [0, 1] heatmap over the original image, optionally marking the clicked patch"""
    H, W = orig.height, orig.width
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
    # Turbo colormap: high contrast reads well over satellite imagery
    rgba = (matplotlib.colormaps["turbo"](np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))
    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)
    if box:
        from PIL import ImageDraw

        draw = ImageDraw.Draw(out, "RGBA")
        # White outline with a thin dark border so the box is visible on any background
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out
def prepare(img):
    """Prepare image and extract features"""
    if img is None:
        return None
    base = ImageOps.exif_transpose(img.convert("RGB"))
    feats, gh, gw = _extract_grid(base)
    return {"orig": base, "feats": feats, "gh": gh, "gw": gw}
def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle click events for similarity visualization"""
    # If state wasn't prepared (e.g., the image came from an Example), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)
    if not state or evt.index is None:
        # Nothing to compute; just echo whatever is currently in the image component
        return img_value, state
    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
    x, y = evt.index
    # Map the pixel click to a patch cell (i, j) in the feature grid
    px_x, px_y = base.width / gw, base.height / gh
    i = min(int(x // px_x), gw - 1)
    j = min(int(y // px_y), gh - 1)
    # Cosine similarity between the clicked patch and every patch in the grid
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()
    # Min-max rescale to [0, 1] for the colormap
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)
    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)
    return overlay, state
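# The heatmap is rescaled per click (min-max over the current similarity map), so colors
# show similarity relative to the selected patch rather than absolute cosine values.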
def reset():
    """Reset the interface"""
    return None, None
with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """
        <div class="header">
            <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
            <p style="font-size: 1.1em; color: #666;">
                Click any region to visualize feature similarities across the image
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="Model Selection",
                info="Select a model (size / pretraining dataset)",
            )
            status = gr.Textbox(
                label="Model Status",
                value=f"Loaded: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            opacity = gr.Slider(
                0.0,
                1.0,
                0.55,
                step=0.05,
                label="Heatmap Opacity",
                info="Balance between image and similarity map",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="Clear All", scale=1)
        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )
    state = gr.State()

    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    )
    img.upload(prepare, inputs=img, outputs=state)
    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state],
        show_progress="minimal",
    )
    reset_btn.click(reset, outputs=[img, state])
    clear_btn.add([img, state])
    # Examples from the current working directory
    example_files = [
        f.name
        for f in Path.cwd().iterdir()
        if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
    ]
    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Example Images",
            examples_per_page=4,
            cache_examples=False,
        )
    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        <b>Performance Notes:</b> The satellite-pretrained models are suited to geographic patterns,
        land-use classification, and structural analysis. The 7B model captures finer detail but requires
        substantially more GPU memory and time per image.
        <br><br>
        Built with DINOv3 | Optimized for satellite and aerial imagery analysis
        </div>
        """
    )
if __name__ == "__main__":
    demo.launch(share=False, debug=True)