import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

os.environ["HF_HUB_OFFLINE"] = "0"

# Global state holding the currently loaded model + processor
state = {
    "model_type": None,
    "model": None,
    "processor": None,
    "repo_id": None,
}


def similarity_heatmap(image):
    """
    Compute the cosine similarity between the CLS token and every patch token.
    """
    model, processor = state["model"], state["processor"]

    # Ensure a 3-channel input for the image processor
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(model.device)  # shape: (1, 3, H, W)

    # Get the ViT patch size from the model config (usually 16)
    patch_size = model.config.patch_size

    # Compute the patch grid (needed for reshaping/resizing later)
    H_patch = pixel_values.shape[2] // patch_size
    W_patch = pixel_values.shape[3] // patch_size
    num_patches = H_patch * W_patch

    with torch.no_grad():
        outputs = model(pixel_values)

    # last_hidden_state: (1, seq_len, hidden_dim)
    last_hidden_state = outputs.last_hidden_state

    cls_token = last_hidden_state[:, 0, :]  # shape: (1, hidden_dim)
    # Take the trailing num_patches tokens so models with extra special tokens
    # (e.g. DeiT's distillation token) still reshape cleanly to the patch grid
    patch_tokens = last_hidden_state[:, -num_patches:, :]  # shape: (1, num_patches, hidden_dim)

    cls_norm = cls_token / cls_token.norm(dim=-1, keepdim=True)
    patch_norm = patch_tokens / patch_tokens.norm(dim=-1, keepdim=True)

    cos_sim = torch.einsum("bd,bpd->bp", cls_norm, patch_norm)  # shape: (1, num_patches)
    cos_sim = cos_sim.reshape((H_patch, W_patch))

    # Move to CPU before converting to NumPy (the model may live on CUDA)
    return cos_sim.cpu().numpy()


def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha=0.5, colormap="viridis"):
    """
    cos_grid: (H_patch, W_patch) numpy array of cosine similarities
    image: PIL.Image
    alpha: blending factor
    colormap: matplotlib colormap name
    """
    # Normalize cosine values to [0, 1] for the colormap
    norm_grid = (cos_grid - cos_grid.min()) / (cos_grid.max() - cos_grid.min() + 1e-8)

    # Apply the colormap via plt.get_cmap (matplotlib.cm.get_cmap was removed in newer releases)
    cmap = plt.get_cmap(colormap)
    heatmap_rgba = cmap(norm_grid)  # shape: (H_patch, W_patch, 4)

    # Convert to RGB 0-255
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb)

    # Resize the heatmap to match the original image size
    heatmap_resized = heatmap_img.resize(image.size, resample=Image.BILINEAR)

    # Blend with the original image
    blended = Image.blend(image.convert("RGBA"), heatmap_resized.convert("RGBA"), alpha=alpha)
    return blended


def load_model(repo_id: str, revision: str = None):
    """
    Load a Hugging Face model + image processor from the Hub.
    Works with any public repo_id.
""" try: # Clean up inputs repo_id = repo_id.strip() if not repo_id: return "Please enter a model repo ID" if revision and revision.strip() == "": revision = None # First try without cache_dir to avoid permission issues try: model = AutoModel.from_pretrained( repo_id, revision=revision, trust_remote_code=True, use_auth_token=False # Explicitly no auth for public models ) processor = AutoImageProcessor.from_pretrained( repo_id, revision=revision, trust_remote_code=True, use_auth_token=False ) except Exception as e1: # If that fails, try with explicit cache directory model = AutoModel.from_pretrained( repo_id, revision=revision, cache_dir="/tmp/model_cache", # Use /tmp for better permissions trust_remote_code=True, use_auth_token=False, local_files_only=False # Ensure we can download ) processor = AutoImageProcessor.from_pretrained( repo_id, revision=revision, cache_dir="/tmp/model_cache", trust_remote_code=True, use_auth_token=False, local_files_only=False ) # Move to appropriate device device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() # Validate it's a Vision Transformer if not hasattr(model.config, 'patch_size'): return f"Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)" # Update global state state["model"] = model state["processor"] = processor state["repo_id"] = repo_id state["model_type"] = "custom" patch_size = model.config.patch_size return f"Successfully loaded ViT model '{repo_id}' (patch size: {patch_size}) on {device}" except Exception as e: error_str = str(e).lower() if "repository not found" in error_str or "404" in error_str: return f"Repository '{repo_id}' not found. Please check the repo ID." elif "connection" in error_str or "network" in error_str or "offline" in error_str: return f"Network error: {str(e)}" elif "permission" in error_str or "forbidden" in error_str: return f"Permission denied. This might be a private repository." else: return f"Error loading model: {str(e)}" def display_image(image: Image): """ Simply returns the uploaded image. """ return image def visualize_cosine_heatmap(image: Image): """ Generate and overlay cosine similarity heatmap on the input image. """ if state["model"] is None: return None # Return None if no model is loaded try: cos_grid = similarity_heatmap(image) blended = overlay_cosine_grid_on_image(cos_grid, image) return blended except Exception as e: print(f"Error generating heatmap: {e}") return None # Gradio UI with gr.Blocks(title="ViT CLS Visualizer") as demo: gr.Markdown("# ViT CLS-Visualizer") gr.Markdown( "Enter the Hugging Face model repo ID (must be public), upload an image, " "and visualize the cosine similarity between the CLS token and patches." ) gr.Markdown("### Popular Vision Transformer models to try:") gr.Markdown( "- `google/vit-base-patch16-224`\n" "- `facebook/deit-base-distilled-patch16-224`\n" "- `microsoft/dit-base`" ) with gr.Row(): repo_input = gr.Textbox( label="Hugging Face Model Repo ID", placeholder="e.g. 
        revision_input = gr.Textbox(
            label="Revision (optional)",
            placeholder="branch, tag, or commit hash",
        )

    load_btn = gr.Button("Load Model", variant="primary")
    load_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output = gr.Image(label="Uploaded Image")
        with gr.Column():
            compute_btn = gr.Button("Compute Heatmap", variant="primary")
            heatmap_output = gr.Image(label="Cosine Similarity Heatmap")

    # Events
    load_btn.click(
        fn=load_model,
        inputs=[repo_input, revision_input],
        outputs=load_status,
    )
    image_input.change(
        fn=display_image,
        inputs=image_input,
        outputs=image_output,
    )
    compute_btn.click(
        fn=visualize_cosine_heatmap,
        inputs=image_input,
        outputs=heatmap_output,
    )


if __name__ == "__main__":
    demo.launch()
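
# Optional quick check outside the Gradio UI (a minimal sketch, not used by the
# app). The repo ID, file names, and the assumption that this script is saved as
# app.py are illustrative only.
#
#   >>> from app import load_model, similarity_heatmap, overlay_cosine_grid_on_image
#   >>> from PIL import Image
#   >>> print(load_model("google/vit-base-patch16-224"))
#   >>> img = Image.open("example.jpg").convert("RGB")
#   >>> overlay_cosine_grid_on_image(similarity_heatmap(img), img).save("heatmap.png")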