import os
import spaces
import gradio as gr
import numpy as np
import gc
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from matplotlib import cm

from model_utils.extractor_dino import ViTExtractor
from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork

def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
    """Resize `img` so its longer side equals `target_res`, then center it on a
    square canvas. With `edge=False` the borders are zero-padded; with
    `edge=True` the edge pixels are replicated instead."""
    filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
    original_width, original_height = img.size
    original_channels = len(img.getbands())
    if not edge:
        canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
        if original_channels == 1:
            canvas = np.zeros([target_res, target_res], dtype=np.uint8)
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[(width - height) // 2: (width + height) // 2] = img
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[:, (height - width) // 2: (height + width) // 2] = img
    else:
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            top_pad = (target_res - height) // 2
            bottom_pad = target_res - height - top_pad
            img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            left_pad = (target_res - width) // 2
            right_pad = target_res - width - left_pad
            img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
        canvas = img
    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas
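
# A minimal usage sketch (the file name is hypothetical): letterbox an image to
# the 420x420 square resolution used by this demo.
#   img = resize(Image.open("example.jpg"), target_res=420, resize=True, to_pil=True)
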
# ─── Feature extraction ──────────────────────────────────────────
def get_processed_features_dino(num_patches, img, use_dummy):
    with torch.no_grad():
        batch = extractor_vit.preprocess_pil(img)
        features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
            .permute(0, 1, 3, 2) \
            .reshape(1, -1, num_patches, num_patches)
        # Raw DINOv2 features, or features refined by the aggregation network
        if use_dummy == "DINOv2":
            desc = aggre_net_dummy(features_dino)
        else:
            desc = aggre_net(features_dino)
        # L2-normalize along the channel dimension
        norms = torch.linalg.norm(desc, dim=1, keepdim=True)
        desc = desc / (norms + 1e-8)
    desc = desc.cpu().detach()
    del batch, features_dino
    gc.collect()
    torch.cuda.empty_cache()
    return desc  # shape [1, C, num_patches, num_patches]
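
# Example (shape values assume the configuration below): with num_patches = 30
# and projection_dim = 768, the descriptor has shape [1, 768, 30, 30].
#   feat = get_processed_features_dino(30, img=img, use_dummy="DIY-SC")
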
# ─── Similarity computation ──────────────────────────────────────
def get_sim(
    coord: tuple[int, int],
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    img_size: int = 420
) -> np.ndarray:
    """
    Upsamples the DINO features to `img_size`, then computes cosine similarity
    between the feature at `coord` in the source and every spatial location in
    the target.
    """
    y, x = coord  # row, col
    # Upsample both feature maps to [1, C, img_size, img_size]
    src_ft = upsampler(feat1)
    trg_ft = upsampler(feat2)
    # Extract the C-dim vector at the clicked location
    C = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].view(1, C, 1, 1)  # [1, C, 1, 1]
    # Cosine similarity along the channel dim
    cos = nn.CosineSimilarity(dim=1)
    cos_map = cos(src_vec, trg_ft)[0]  # [img_size, img_size]
    return cos_map.cpu().numpy()
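
# Note: the [1, C, 1, 1] source vector broadcasts against the [1, C, H, W]
# target map, so CosineSimilarity yields one score per target pixel in a
# single call; no explicit loop over locations is needed.
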
# ─── Drawing helper ──────────────────────────────────────────────
def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255, 0, 0)) -> np.ndarray:
    pil = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(pil)
    r = size // 2
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    return np.array(pil)
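
# For example, draw_point(arr, x=100, y=50, size=10) draws a red dot of
# roughly 10 px diameter centered at column 100, row 50.
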
# ─── Feature-updating callback ───────────────────────────────────
def update_features(
    img: Image.Image,
    num_patches,
    use_dummy
):
    """
    Given a PIL image, returns:
      1) the resized PIL image (so it can be displayed),
      2) its descriptor tensor, stored in a gr.State,
      3) an unmarked copy of the resized image, kept so markers can later be
         redrawn on a clean background.
    """
    torch.cuda.empty_cache()
    if img is None:
        return None, None, None
    img = resize(img, target_res=target_res, resize=True, to_pil=True)
    feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
    return img, feat.cpu(), Image.fromarray(np.array(img))
# ─── Click handler ───────────────────────────────────────────────
def on_select(
    source_pil: Image.Image,
    target_pil: Image.Image,
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    alpha: float,
    scatter_size: int,
    or_tgt_img: Image.Image,
    or_src_img: Image.Image,
    sel: gr.SelectData
):
    # Start from the unmarked copies so previous markers are discarded
    src_arr = np.array(or_src_img)
    tgt_arr = np.array(or_tgt_img)
    # Gradio reports the click as (x, y), i.e. (col, row)
    x, y = sel.index
    src_marked = draw_point(src_arr, x, y, scatter_size)
    # Compute the similarity map and normalize it to [0, 1]
    sim_map = get_sim((y, x), feat1, feat2, img_size=target_res)
    mn, mx = sim_map.min(), sim_map.max()
    sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)
    # Build an RGBA heatmap (H×W×4)
    heat = cm.viridis(sim_norm)
    heat[..., 3] = sim_norm * alpha  # alpha channel
    # Composite the heatmap over a fresh copy of the target
    tgt_f = tgt_arr.astype(np.float32) / 255.0
    comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
    overlay = (comp * 255).astype(np.uint8)
    # Draw a red dot at the best match
    my, mx_ = np.unravel_index(sim_map.argmax(), sim_map.shape)
    overlay_marked = draw_point(overlay, mx_, my, scatter_size)
    return src_marked, overlay_marked
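
# The compositing above is standard per-pixel alpha blending:
#   out = heat_rgb * a + target_rgb * (1 - a),  with a = sim_norm * alpha,
# so regions with high similarity show the colormap more strongly.
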
def reload_img(
    or_src_img: Image.Image,
    or_tgt_img: Image.Image,
):
    return or_src_img, or_tgt_img
# ─── Configuration ───────────────────────────────────────────────
num_patches = 30
target_res = num_patches * 14  # 420 px: 30 patches at DINOv2's 14-px patch size
ckpt_file = "./ckpts/dino_spair_0300.pth"

# ─── Model setup ─────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
aggre_net_dummy = DummyAggregationNetwork()
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
aggre_net = aggre_net.eval()
extractor_vit.model.eval()
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
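
# Hypothetical sanity check: with identical source and target descriptors,
# the similarity map should peak at (or very near) the clicked pixel itself.
#   f = get_processed_features_dino(num_patches, img=some_img, use_dummy="DIY-SC")
#   sim = get_sim((10, 10), f, f, img_size=target_res)
#   np.unravel_index(sim.argmax(), sim.shape)  # expected at/near (10, 10)
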
# ─── Build Gradio UI ─────────────────────────────────────────────
with gr.Blocks() as demo:
    # Hidden states to hold features and unmarked image copies
    feat1_state = gr.State()
    feat2_state = gr.State()
    or_tgt_img = gr.State()
    or_src_img = gr.State()

    # Introduction text box
    intro_text = gr.Markdown("""
## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
[Project Page](https://genintel.github.io/DIY-SC) | [GitHub Repository](https://github.com/odunkel/DIY-SC)

Welcome to the DIY-SC demo!
Upload two images and select a keypoint in the source image. The demo computes and visualizes the feature similarity map and a corresponding point in the target image.
You can choose between the DIY-SC (DINOv2) and the DINOv2 feature extractor.
""")

    # Image upload / display components
    with gr.Row():
        src = gr.Image(interactive=True, type="pil", label="Source Image")
        tgt = gr.Image(interactive=True, type="pil", label="Target Image")

    # Controls
    alpha = gr.State(0.7)
    scatter = gr.State(10)
    use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")
    # Recompute features whenever a new image is uploaded
    src.input(
        fn=update_features,
        inputs=[src, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    tgt.input(
        fn=update_features,
        inputs=[tgt, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    # Switching the extractor recomputes features from the cached originals
    use_dummy.change(
        fn=update_features,
        inputs=[or_src_img, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    use_dummy.change(
        fn=update_features,
        inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    # Clicking the source image triggers the correspondence visualization
    src.select(
        fn=on_select,
        inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
        outputs=[src, tgt]
    )

if __name__ == "__main__":
    demo.launch()