Spaces:
Sleeping
Sleeping
import os | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
import numpy as np | |
import gradio as gr | |
from transformers import AutoModel, AutoImageProcessor | |
from PIL import Image | |
import torch | |
# Ask the Hugging Face Hub client to run in online mode so models can be downloaded.
# NOTE(review): transformers (and huggingface_hub) are imported above this line; if
# they read HF_HUB_OFFLINE at import time this assignment comes too late to matter —
# consider moving it before those imports.
os.environ["HF_HUB_OFFLINE"] = "0"

# Module-level cache for the currently loaded model. Every key starts as None;
# load_model fills in model/processor/repo_id (and sets model_type to "custom"),
# while similarity_heatmap and visualize_cosine_heatmap read model/processor.
state = dict.fromkeys(("model_type", "model", "processor", "repo_id"))
def similarity_heatmap(image):
    """
    Compute cosine similarity between the CLS token and every patch token.

    Uses the model and processor currently stored in the module-level ``state``
    dict (populated by ``load_model``).

    Args:
        image: Input image, passed unchanged to the loaded image processor.

    Returns:
        A numpy float array of shape (H_patch, W_patch): for each patch, the cosine
        similarity between its final-layer embedding and the CLS embedding.
    """
    model, processor = state["model"], state["processor"]
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(model.device)  # shape: (1, 3, H, W)

    # patch_size is usually an int (e.g. 16) but some configs store an (h, w) pair.
    patch_size = model.config.patch_size
    if isinstance(patch_size, (list, tuple)):
        patch_h, patch_w = patch_size
    else:
        patch_h = patch_w = patch_size

    # Patch grid implied by the processed image size (used for the reshape below).
    h_patch = pixel_values.shape[2] // patch_h
    w_patch = pixel_values.shape[3] // patch_w
    num_patches = h_patch * w_patch

    with torch.no_grad():
        outputs = model(pixel_values)
    last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden_dim)

    cls_token = last_hidden_state[:, 0, :]  # (1, hidden_dim)
    # Take the LAST num_patches tokens rather than everything after CLS: some
    # models insert extra tokens between CLS and the patches (e.g. the
    # distillation token in distilled DeiT), which would break the grid reshape.
    patch_tokens = last_hidden_state[:, -num_patches:, :]  # (1, num_patches, hidden_dim)

    cls_norm = cls_token / cls_token.norm(dim=-1, keepdim=True)
    patch_norm = patch_tokens / patch_tokens.norm(dim=-1, keepdim=True)
    cos_sim = torch.einsum("bd,bpd->bp", cls_norm, patch_norm)  # (1, num_patches)

    # Convert via float32 on the CPU: numpy cannot read CUDA tensors directly and
    # has no bfloat16 dtype.
    return cos_sim.reshape(h_patch, w_patch).float().cpu().numpy()
def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha=0.5, colormap="viridis"):
    """
    Colour-map a patch-level cosine-similarity grid and blend it over an image.

    Args:
        cos_grid: (H_patch, W_patch) numpy array of cosine similarities.
        image: PIL image to draw the heatmap over.
        alpha: Blend factor forwarded to ``Image.blend`` (0.0 keeps only the
            image, 1.0 keeps only the heatmap).
        colormap: Name of a matplotlib colormap.

    Returns:
        An RGBA PIL image with the same size as ``image``.
    """
    # Min-max normalise to [0, 1] so the whole colormap range is used; the epsilon
    # avoids a division by zero when the grid is constant.
    norm_grid = (cos_grid - cos_grid.min()) / (cos_grid.max() - cos_grid.min() + 1e-8)

    # pyplot.get_cmap works across matplotlib versions, whereas
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap(colormap)
    heatmap_rgba = cmap(norm_grid)  # shape: (H_patch, W_patch, 4), floats in [0, 1]

    # Drop alpha and convert to 8-bit RGB for PIL.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb)

    # Upsample the coarse patch grid to the original image resolution.
    heatmap_resized = heatmap_img.resize(image.size, resample=Image.BILINEAR)

    # Image.blend requires identical modes and sizes, hence both converted to RGBA.
    return Image.blend(image.convert("RGBA"), heatmap_resized.convert("RGBA"), alpha=alpha)
def _load_model_and_processor(repo_id, revision, **extra_kwargs):
    """Load the model and its image processor from ``repo_id`` using identical kwargs."""
    # SECURITY(review): trust_remote_code=True executes Python code shipped inside
    # the repository. repo_id comes straight from the UI, so any visitor can make
    # this process run arbitrary code; consider False or an allowlist of repos.
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
    # of `token`; switch once the installed transformers version is confirmed.
    kwargs = dict(
        revision=revision,
        trust_remote_code=True,
        use_auth_token=False,  # Explicitly no auth for public models
        **extra_kwargs,
    )
    model = AutoModel.from_pretrained(repo_id, **kwargs)
    processor = AutoImageProcessor.from_pretrained(repo_id, **kwargs)
    return model, processor


def load_model(repo_id: str, revision: str = None):
    """
    Load a Hugging Face model + image processor from the Hub into ``state``.

    Works with any public repo_id whose model config exposes ``patch_size``.

    Args:
        repo_id: Hub repository id, e.g. ``"google/vit-base-patch16-224"``.
            Surrounding whitespace is ignored; ``None`` is treated as empty.
        revision: Optional branch, tag or commit hash. ``None``, ``""`` and
            whitespace-only values all mean "use the default revision".

    Returns:
        A human-readable status string for the UI. Failures are reported through
        this string rather than raised.
    """
    try:
        repo_id = (repo_id or "").strip()
        if not repo_id:
            return "Please enter a model repo ID"
        # Gradio sends "" for an empty textbox; normalise any blank revision to
        # None instead of forwarding an empty revision string to the Hub.
        revision = (revision or "").strip() or None

        try:
            # First try the default cache location.
            model, processor = _load_model_and_processor(repo_id, revision)
        except Exception:
            # Retry with an explicit cache under /tmp in case the default cache
            # directory is not writable (e.g. permission issues).
            model, processor = _load_model_and_processor(
                repo_id,
                revision,
                cache_dir="/tmp/model_cache",
                local_files_only=False,  # Ensure we can download
            )

        # Validate it's a Vision Transformer before spending device memory on it.
        if not hasattr(model.config, "patch_size"):
            return f"Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

        # Publish the new model only after everything above succeeded.
        state["model"] = model
        state["processor"] = processor
        state["repo_id"] = repo_id
        state["model_type"] = "custom"

        patch_size = model.config.patch_size
        return f"Successfully loaded ViT model '{repo_id}' (patch size: {patch_size}) on {device}"
    except Exception as e:
        # Map common Hub failures to friendlier messages by inspecting the text.
        error_str = str(e).lower()
        if "repository not found" in error_str or "404" in error_str:
            return f"Repository '{repo_id}' not found. Please check the repo ID."
        elif "connection" in error_str or "network" in error_str or "offline" in error_str:
            return f"Network error: {str(e)}"
        elif "permission" in error_str or "forbidden" in error_str:
            return "Permission denied. This might be a private repository."
        else:
            return f"Error loading model: {str(e)}"
def display_image(image: "Image.Image"):
    """
    Return the uploaded image unchanged.

    Bound to ``image_input.change`` so the upload is mirrored into the
    "Uploaded Image" component. ``None`` (no image) is passed through as-is.
    """
    # Annotated with the PIL Image *class*; the bare name `Image` is the module.
    return image
def visualize_cosine_heatmap(image: "Image.Image"):
    """
    Generate the CLS/patch cosine-similarity heatmap and overlay it on the input image.

    Args:
        image: PIL image from the upload component, or ``None`` if nothing is uploaded.

    Returns:
        The blended PIL image, or ``None`` when no image was given, no model is
        loaded, or heatmap generation fails (the error is printed).
    """
    # Nothing uploaded yet: there is nothing to visualize, so skip the model
    # call instead of letting the processor fail on None.
    if image is None:
        return None
    if state["model"] is None:
        return None  # Return None if no model is loaded
    try:
        cos_grid = similarity_heatmap(image)
        return overlay_cosine_grid_on_image(cos_grid, image)
    except Exception as e:
        # Best-effort: report on stdout and leave the output component empty.
        print(f"Error generating heatmap: {e}")
        return None
# Gradio UI
# Layout (top to bottom): title/instructions, example repo ids, model repo +
# revision inputs, a Load button with a status box, then the upload column and
# the heatmap column. Component order inside each `with` determines the layout.
with gr.Blocks(title="ViT CLS Visualizer") as demo:
    gr.Markdown("# ViT CLS-Visualizer")
    gr.Markdown(
        "Enter the Hugging Face model repo ID (must be public), upload an image, "
        "and visualize the cosine similarity between the CLS token and patches."
    )
    gr.Markdown("### Popular Vision Transformer models to try:")
    gr.Markdown(
        "- `google/vit-base-patch16-224`\n"
        "- `facebook/deit-base-distilled-patch16-224`\n"
        "- `microsoft/dit-base`"
    )
    # Model selection: repo id (pre-filled) and an optional revision.
    with gr.Row():
        repo_input = gr.Textbox(
            label="Hugging Face Model Repo ID",
            placeholder="e.g. google/vit-base-patch16-224",
            value="google/vit-base-patch16-224"
        )
        revision_input = gr.Textbox(
            label="Revision (optional)",
            placeholder="branch, tag, or commit hash"
        )
    load_btn = gr.Button("Load Model", variant="primary")
    # Read-only box that shows the status string returned by load_model.
    load_status = gr.Textbox(label="Model Status", interactive=False)
    with gr.Row():
        # Left column: upload (delivered to handlers as a PIL image) and its mirror.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output = gr.Image(label="Uploaded Image")
        # Right column: trigger and target for the heatmap overlay.
        with gr.Column():
            compute_btn = gr.Button("Compute Heatmap", variant="primary")
            heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
    # Events
    # Load button -> load_model(repo, revision); result string -> status box.
    load_btn.click(
        fn=load_model,
        inputs=[repo_input, revision_input],
        outputs=load_status
    )
    # Any change to the upload is echoed into the "Uploaded Image" component.
    image_input.change(
        fn=display_image,
        inputs=image_input,
        outputs=image_output
    )
    # Compute button -> heatmap overlay for the current upload.
    compute_btn.click(
        fn=visualize_cosine_heatmap,
        inputs=image_input,
        outputs=heatmap_output
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()