Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,296 +7,224 @@ import numpy as np
|
|
| 7 |
import spaces
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
| 10 |
-
from PIL import Image,
|
| 11 |
from transformers import AutoImageProcessor, AutoModel
|
| 12 |
|
| 13 |
-
#
|
| 14 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
|
| 16 |
MODEL_MAP = {
|
| 17 |
-
"DINOv3 ViT-L/16 Satellite": "facebook/dinov3-vitl16-pretrain-sat493m",
|
| 18 |
-
"DINOv3 ViT-L/16 LVD (
|
| 19 |
-
|
| 20 |
}
|
| 21 |
|
| 22 |
-
|
| 23 |
|
| 24 |
-
#
|
| 25 |
processor = None
|
| 26 |
model = None
|
| 27 |
|
| 28 |
|
| 29 |
-
# --- Core Functions ---
|
| 30 |
def cleanup_memory():
|
| 31 |
-
"""
|
| 32 |
global processor, model
|
|
|
|
| 33 |
if model is not None:
|
| 34 |
del model
|
|
|
|
|
|
|
| 35 |
if processor is not None:
|
| 36 |
del processor
|
| 37 |
-
|
|
|
|
| 38 |
gc.collect()
|
|
|
|
| 39 |
if torch.cuda.is_available():
|
| 40 |
torch.cuda.empty_cache()
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
-
def load_model(name
|
| 44 |
-
"""
|
| 45 |
global processor, model
|
|
|
|
| 46 |
try:
|
|
|
|
| 47 |
cleanup_memory()
|
|
|
|
| 48 |
model_id = MODEL_MAP[name]
|
| 49 |
|
|
|
|
| 50 |
processor = AutoImageProcessor.from_pretrained(model_id)
|
|
|
|
|
|
|
| 51 |
model = (
|
| 52 |
-
AutoModel.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
)
|
| 54 |
|
|
|
|
| 55 |
param_count = sum(p.numel() for p in model.parameters()) / 1e9
|
| 56 |
-
dtype_str = str(next(model.parameters()).dtype).split(".")[-1]
|
| 57 |
|
| 58 |
-
return f"
|
|
|
|
| 59 |
except Exception as e:
|
| 60 |
cleanup_memory()
|
| 61 |
-
return f"
|
| 62 |
-
|
| 63 |
|
| 64 |
-
@spaces.GPU(duration=60)
|
| 65 |
-
def _extract_grid(img: Image.Image):
|
| 66 |
-
"""Extracts a grid of feature vectors from an image."""
|
| 67 |
-
with torch.inference_mode():
|
| 68 |
-
inputs = processor(images=img, return_tensors="pt").to(DEVICE)
|
| 69 |
-
outputs = model(**inputs)
|
| 70 |
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
-
# Correctly calculate grid dimensions from model config
|
| 74 |
-
p = model.config.patch_size
|
| 75 |
-
h, w = inputs.pixel_values.shape[-2] // p, inputs.pixel_values.shape[-1] // p
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
| 81 |
|
| 82 |
-
|
|
|
|
| 83 |
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
colormap: str,
|
| 90 |
-
box: tuple = None,
|
| 91 |
-
):
|
| 92 |
-
"""Overlays a heatmap on the base image with a specified colormap and opacity."""
|
| 93 |
-
h, w = base_img.height, base_img.width
|
| 94 |
|
| 95 |
-
|
| 96 |
-
heatmap = Image.fromarray((heatmap_01 * 255).astype(np.uint8)).resize(
|
| 97 |
-
(w, h), resample=Image.LANCZOS
|
| 98 |
-
)
|
| 99 |
-
cmap_func = cm.get_cmap(colormap.lower())
|
| 100 |
-
rgba_heatmap = (cmap_func(np.asarray(heatmap) / 255.0) * 255).astype(np.uint8)
|
| 101 |
|
| 102 |
-
|
| 103 |
-
overlay = Image.fromarray(rgba_heatmap, "RGBA")
|
| 104 |
-
overlay.putalpha(int(opacity * 255))
|
| 105 |
|
| 106 |
-
# Composite overlay onto a copy of the base image
|
| 107 |
-
out_img = Image.alpha_composite(base_img.copy().convert("RGBA"), overlay)
|
| 108 |
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
if box:
|
| 111 |
-
|
| 112 |
-
# Draw a thick white border with a thin black outline for visibility
|
| 113 |
-
draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
|
| 114 |
-
draw.rectangle(
|
| 115 |
-
(box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
|
| 116 |
-
outline=(0, 0, 0, 200),
|
| 117 |
-
width=1,
|
| 118 |
-
)
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
"""Prepares the image by extracting features and storing them in the state."""
|
| 126 |
if img is None:
|
| 127 |
-
return None
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
|
| 130 |
-
feats, gh, gw = _extract_grid(base_img)
|
| 131 |
|
| 132 |
-
state = {"orig": base_img, "feats": feats, "gh": gh, "gw": gw}
|
| 133 |
-
return state, "Click on the image to compute similarity."
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
def on_click(state: dict, opacity: float, colormap: str, evt: gr.SelectData):
|
| 137 |
-
"""Handles click events to compute and display the similarity heatmap."""
|
| 138 |
if not state or evt.index is None:
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
|
| 142 |
|
| 143 |
-
# Calculate patch index from click coordinates
|
| 144 |
x, y = evt.index
|
| 145 |
-
|
| 146 |
-
i = min(int(x //
|
| 147 |
-
j = min(int(y //
|
| 148 |
|
| 149 |
-
# Compute similarity
|
| 150 |
d = feats.shape[-1]
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
sims = (
|
| 154 |
|
| 155 |
-
# Normalize similarity map to [0, 1] for visualization
|
| 156 |
smin, smax = float(sims.min()), float(sims.max())
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# Define selection box coordinates
|
| 160 |
-
box = (int(i * px_w), int(j * px_h), int((i + 1) * px_w), int((j + 1) * px_h))
|
| 161 |
-
|
| 162 |
-
# Generate overlay image
|
| 163 |
-
output_img = _overlay_heatmap(base, heatmap_01, opacity, colormap, box)
|
| 164 |
-
|
| 165 |
-
# Create statistics string
|
| 166 |
-
stats = f"""📊 **Similarity Statistics**
|
| 167 |
-
- **Min**: `{smin:.3f}`
|
| 168 |
-
- **Max**: `{smax:.3f}`
|
| 169 |
-
- **Range**: `{smax - smin:.3f}`
|
| 170 |
-
- **Patch Index**: `({i}, {j})`
|
| 171 |
-
- **Grid Size**: `{gw}×{gh}`"""
|
| 172 |
-
|
| 173 |
-
return output_img, state, stats
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def reset_overlay(state: dict):
|
| 177 |
-
"""Resets the image to its original state, removing the heatmap."""
|
| 178 |
-
if state and "orig" in state:
|
| 179 |
-
return state["orig"], state, "Overlay reset. Click the image again."
|
| 180 |
-
return None, None, "Upload an image to begin."
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
# --- Gradio UI ---
|
| 184 |
-
with gr.Blocks(
|
| 185 |
-
theme=gr.themes.Soft(primary_hue="blue"),
|
| 186 |
-
css=".container {max-width: 1200px; margin: auto;}",
|
| 187 |
-
) as demo:
|
| 188 |
-
gr.HTML(
|
| 189 |
-
"""
|
| 190 |
-
<div style="text-align: center; padding: 20px;">
|
| 191 |
-
<h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
|
| 192 |
-
<p style="font-size: 1.1em; color: #666;">Explore how DINOv3 models trained on satellite imagery understand visual patterns</p>
|
| 193 |
-
</div>
|
| 194 |
-
"""
|
| 195 |
-
)
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
4. **Adjust visualization** - Fine-tune opacity and colormap for clarity.
|
| 206 |
-
"""
|
| 207 |
-
)
|
| 208 |
-
with gr.Column(scale=2):
|
| 209 |
-
gr.HTML(
|
| 210 |
-
"""
|
| 211 |
-
<div style="background: rgba(0,0,0,0.03); border-radius: 8px; padding: 12px; border-left: 4px solid #2563eb;">
|
| 212 |
-
<b>💡 Model Info:</b><br>
|
| 213 |
-
• <b>Satellite models</b>: Trained on 493M satellite images.<br>
|
| 214 |
-
• <b>LVD model</b>: Trained on 1.7B diverse images.<br>
|
| 215 |
-
• <b>7B model</b>: Massive capacity, slower but more nuanced.
|
| 216 |
-
</div>
|
| 217 |
-
"""
|
| 218 |
-
)
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
model_choice = gr.Dropdown(
|
| 223 |
-
choices=list(MODEL_MAP.keys()),
|
| 224 |
-
value=DEFAULT_MODEL_NAME,
|
| 225 |
-
label="🤖 Model Selection",
|
| 226 |
)
|
| 227 |
status = gr.Textbox(
|
| 228 |
-
label="
|
| 229 |
-
value="Loading initial model...",
|
| 230 |
-
interactive=False,
|
| 231 |
)
|
|
|
|
| 232 |
|
| 233 |
-
|
| 234 |
-
opacity = gr.Slider(
|
| 235 |
-
0.2, 0.9, 0.55, step=0.05, label="🎨 Heatmap Opacity"
|
| 236 |
-
)
|
| 237 |
-
colormap = gr.Dropdown(
|
| 238 |
-
["Turbo", "Inferno", "Viridis", "Plasma", "Jet"],
|
| 239 |
-
value="Turbo",
|
| 240 |
-
label="🌈 Colormap",
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
info_panel = gr.Markdown(
|
| 244 |
-
value="*Upload an image and click on it to see statistics here.*",
|
| 245 |
-
label="Statistics",
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
with gr.Row():
|
| 249 |
-
reset_btn = gr.Button("🔄 Reset Overlay")
|
| 250 |
-
# A ClearButton is simpler for clearing multiple components
|
| 251 |
-
clear_btn = gr.ClearButton(value="🗑️ Clear All")
|
| 252 |
-
|
| 253 |
-
with gr.Column(scale=2):
|
| 254 |
img = gr.Image(
|
| 255 |
-
type="pil",
|
| 256 |
-
label="Interactive Canvas (Click to explore)",
|
| 257 |
-
height=600,
|
| 258 |
-
show_download_button=True,
|
| 259 |
)
|
| 260 |
|
| 261 |
-
# Define a state object to hold persistent data (original image, features)
|
| 262 |
state = gr.State()
|
| 263 |
|
| 264 |
-
|
| 265 |
-
# Make sure these images are in your repository.
|
| 266 |
-
gr.Examples(
|
| 267 |
-
examples=[
|
| 268 |
-
["examples/satellite_city.jpg"],
|
| 269 |
-
["examples/coastal_area.png"],
|
| 270 |
-
["examples/farmland.webp"],
|
| 271 |
-
],
|
| 272 |
-
inputs=[img],
|
| 273 |
-
outputs=[state, info_panel],
|
| 274 |
-
fn=prepare,
|
| 275 |
-
cache_examples=torch.cuda.is_available(), # Cache on GPU instances
|
| 276 |
-
)
|
| 277 |
|
| 278 |
-
|
| 279 |
-
demo.load(lambda: load_model(DEFAULT_MODEL_NAME), outputs=status)
|
| 280 |
-
|
| 281 |
-
model_choice.change(
|
| 282 |
-
load_model, inputs=model_choice, outputs=status, show_progress="full"
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
img.upload(
|
| 286 |
-
prepare, inputs=img, outputs=[state, info_panel], show_progress="minimal"
|
| 287 |
-
)
|
| 288 |
|
| 289 |
img.select(
|
| 290 |
-
|
| 291 |
-
inputs=[state, opacity,
|
| 292 |
-
outputs=[img, state
|
| 293 |
show_progress="minimal",
|
| 294 |
)
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
#
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
demo.launch(
|
|
|
|
| 7 |
import spaces
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
| 10 |
+
from PIL import Image, ImageOps
|
| 11 |
from transformers import AutoImageProcessor, AutoModel
|
| 12 |
|
| 13 |
+
# Device configuration: use the GPU when PyTorch reports one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dropdown label -> Hugging Face Hub checkpoint id.
# Insertion order matters: the first entry is used as the default model.
MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 LVD (1.7B)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}

# First key of MODEL_MAP, taken without materialising a temporary list.
DEFAULT_NAME = next(iter(MODEL_MAP))

# Global model state, rebound by load_model() and cleared by cleanup_memory().
processor = None
model = None
|
| 27 |
|
| 28 |
|
|
|
|
| 29 |
def cleanup_memory():
    """Drop the loaded model/processor and reclaim memory before a model switch.

    Rebinds the module-level ``model`` and ``processor`` to ``None``, runs the
    garbage collector, and, when CUDA is available, empties the CUDA caching
    allocator and synchronizes the device.

    Returns:
        None.
    """
    global processor, model

    # Rebinding removes this module's reference; a separate ``del`` first is
    # redundant because the name is immediately reassigned.
    model = None
    processor = None

    # Collect reference cycles now so the freed tensors can actually be released.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
|
| 46 |
|
| 47 |
|
| 48 |
+
def load_model(name):
    """Load the checkpoint registered under ``name`` and return a status string.

    The previously loaded model and processor are released first so that two
    models are never resident at the same time.

    Args:
        name: A key of ``MODEL_MAP`` (the label shown in the model dropdown).

    Returns:
        A human-readable status line: ``"Loaded: ..."`` on success or
        ``"Failed to load ..."`` on any error. This function never raises; the
        string is meant to be shown in the UI status box.
    """
    global processor, model

    # Validate before tearing anything down: an unknown label should not
    # unload a model that is currently working.
    model_id = MODEL_MAP.get(name)
    if model_id is None:
        return f"Failed to load {name}: unknown model"

    try:
        # Clean up existing model
        cleanup_memory()

        # Load processor
        processor = AutoImageProcessor.from_pretrained(model_id)

        # Load model with the checkpoint's own dtype ("auto"), move it to the
        # target device and switch to eval mode.
        model = (
            AutoModel.from_pretrained(
                model_id,
                torch_dtype="auto",
            )
            .to(DEVICE)
            .eval()
        )

        # Parameter count in billions, for the status line only.
        param_count = sum(p.numel() for p in model.parameters()) / 1e9

        return f"Loaded: {name} | {param_count:.1f}B params | {DEVICE.upper()}"

    except Exception as e:
        # UI boundary: report every failure (download, OOM, ...) as text rather
        # than crashing the app, and make sure no half-loaded state remains.
        cleanup_memory()
        return f"Failed to load {name}: {e}"
|
|
|
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
# Initialize default model at import time. load_model() reports failures only
# through its return value, so print it to make start-up problems visible in
# the logs instead of silently discarding the status.
print(load_model(DEFAULT_NAME))
|
| 83 |
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
@spaces.GPU(duration=60)
def _extract_grid(img):
    """Run the loaded model on ``img`` and return its per-patch feature grid.

    Args:
        img: Image accepted by the loaded image processor (callers in this
            file pass an RGB PIL image).

    Returns:
        A tuple ``(feats, gh, gw)`` where ``feats`` is a float32 CPU tensor
        reshaped to ``(gh, gw, -1)`` and ``gh``/``gw`` are the number of patch
        rows/columns, computed from the processed input size and
        ``model.config.patch_size``.

    Raises:
        RuntimeError: If no model is currently loaded (for example because a
            previous ``load_model`` call failed).
    """
    if processor is None or model is None:
        # load_model() swallows errors and leaves both globals as None; fail
        # with a clear message instead of "'NoneType' object is not callable".
        raise RuntimeError("No model is loaded; select a model and try again.")

    with torch.inference_mode():
        pv = processor(images=img, return_tensors="pt").pixel_values

        if DEVICE == "cuda":
            pv = pv.to(DEVICE)

        out = model(pixel_values=pv)
        # First (only) batch item, upcast so later similarity math is float32.
        last = out.last_hidden_state[0].to(torch.float32)

        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        _, _, Ht, Wt = pv.shape
        gh, gw = Ht // p, Wt // p

        # Skip the leading 1 + num_reg tokens (presumably the class token and
        # register tokens -- confirm against the model docs); the remainder is
        # laid out as a gh x gw grid.
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()

    return feats, gh, gw
|
|
|
|
|
|
|
| 105 |
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
def _overlay(orig, heat01, alpha=0.55, box=None):
    """Blend an "inferno" heatmap over ``orig`` and optionally outline a box.

    Args:
        orig: PIL image used as the background; it is not modified.
        heat01: NumPy array of heat values (assumed 2-D and already scaled to
            [0, 1] by the caller -- values are multiplied by 255 and cast to
            uint8 without clipping).
        alpha: Uniform overlay opacity; ``int(alpha * 255)`` becomes the alpha
            channel of the heatmap layer.
        box: Optional ``(x0, y0, x1, y1)`` rectangle in image pixels to outline.

    Returns:
        A new RGBA PIL image with the same size as ``orig``.
    """
    from PIL import ImageDraw

    # ``cm.get_cmap`` is gone in newer Matplotlib releases; prefer the
    # colormap registry and fall back for older versions that lack it.
    try:
        import matplotlib

        cmap = matplotlib.colormaps["inferno"]
    except (ImportError, AttributeError):
        cmap = cm.get_cmap("inferno")

    H, W = orig.height, orig.width

    # Upsample the coarse heat grid to the full image size.
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))

    # Colormap lookup yields floats in [0, 1] with 4 channels; convert to bytes.
    rgba = (cmap(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))

    # convert() already returns a new image, so no explicit copy is needed.
    base = orig.convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        ImageDraw.Draw(out, "RGBA").rectangle(
            box, outline=(255, 255, 255, 220), width=2
        )
    return out
|
| 123 |
|
| 124 |
|
| 125 |
+
def prepare(img):
    """Build the per-image state used by ``click``.

    Args:
        img: Uploaded PIL image, or ``None`` when the image component is empty.

    Returns:
        ``None`` if ``img`` is ``None``; otherwise a dict with keys ``"orig"``
        (the upright RGB image), ``"feats"`` (feature grid from
        ``_extract_grid``), ``"gh"`` and ``"gw"`` (grid height and width).
    """
    if img is None:
        return None

    # Apply the EXIF orientation while the source image's metadata is still
    # attached, then normalise the mode; converting first risks losing it.
    base = ImageOps.exif_transpose(img).convert("RGB")
    feats, gh, gw = _extract_grid(base)

    return {"orig": base, "feats": feats, "gh": gh, "gw": gw}
|
|
|
|
| 134 |
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle a click on the image and render the patch-similarity overlay.

    Args:
        state: Dict produced by ``prepare`` (or ``None`` if it has not run).
        opacity: Heatmap opacity forwarded to ``_overlay`` as ``alpha``.
        img_value: Current value of the image component.
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.

    Returns:
        ``(image, state)``: the overlay image and the (possibly newly built)
        state, or the unchanged ``img_value`` when there is nothing to compute.
    """
    # If state wasn't prepared (e.g., Example selection), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)

    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        return img_value, state

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

    # Map the clicked pixel to a patch column/row; clamp on both sides so an
    # edge or out-of-range coordinate can never index outside the grid.
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = max(0, min(int(x // px_x), gw - 1))
    j = max(0, min(int(y // px_y), gh - 1))

    # Cosine similarity of every patch to the clicked one: L2-normalise the
    # flattened grid once and reuse the clicked patch's normalised row.
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = grid[j * gw + i].reshape(1, d)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    # Min-max scale to [0, 1]; the epsilon guards a constant similarity map.
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    # Pixel rectangle of the selected patch, outlined on the overlay.
    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)
    return overlay, state
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def reset():
    """Reset the interface.

    Returns:
        ``(None, None)``: one ``None`` per wired output, which clears the image
        component and the stored state.
    """
    return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
|
| 171 |
+
# --- Gradio UI ---
# NOTE(review): the nesting below is reconstructed; the original indentation
# was not preserved, so confirm the layout (which controls share a row).
with gr.Blocks() as demo:
    gr.Markdown("## DINOv3: patch similarity visualizer")
    gr.Markdown(
        "This is an app where you can upload an image, click on an object in the image and get the most similar patches to it according to DINOv3, revealing the way DINOv3 segments objects through features natively."
    )
    gr.Markdown("There's multiple model options you can pick from the dropdown.")
    gr.Markdown(
        "Please click Reset before you want to upload another image, as this app keeps features of the images."
    )

    with gr.Column():
        with gr.Row(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()), value=DEFAULT_NAME, label="Model"
            )
            status = gr.Textbox(
                label="Status",
                # Reflect the real outcome of the start-up load instead of
                # always claiming success.
                value=(
                    f"Loaded: {DEFAULT_NAME}"
                    if model is not None
                    else f"Failed to load {DEFAULT_NAME}"
                ),
                interactive=False,
            )
            opacity = gr.Slider(0.0, 1.0, 0.55, step=0.05, label="Opacity for the Map")

        with gr.Row(scale=1):
            img = gr.Image(
                type="pil", label="Image", interactive=True, height=750, width=750
            )

    # Per-session cache: original image plus its feature grid (see prepare()).
    state = gr.State()

    model_choice.change(load_model, inputs=model_choice, outputs=status)

    img.upload(prepare, inputs=img, outputs=state)

    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state],
        show_progress="minimal",
    )

    gr.Button("Reset").click(reset, outputs=[img, state])

    # Examples from current directory; sorted so the gallery order is stable
    # across restarts (directory iteration order is not guaranteed).
    example_files = sorted(
        str(f)
        for f in Path.cwd().iterdir()
        if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
    )

    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Try an example image and then click on the objects.",
            examples_per_page=4,
            cache_examples=False,
            # Without caching, fn is not run on selection unless requested;
            # run it so the state always matches the example being shown.
            run_on_click=True,
        )

if __name__ == "__main__":
    demo.launch()
|