Felix Konrad committed on
Commit
57c8491
·
1 Parent(s): 4bd55a6

Quick test.

Browse files
Files changed (2) hide show
  1. README.md +10 -0
  2. app.py +28 -1
README.md CHANGED
@@ -12,3 +12,13 @@ short_description: Visualizes cosine-similarity of CLS-token with image patches
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+
17
+ # Dynamic ViT Visualizer
18
+ Enter a Hugging Face ViT model repo ID and upload an image.
19
+ Visualize the cosine similarity between the CLS token and image patches.
20
+
21
+ **How to use:**
22
+ 1. Enter the model repo ID.
23
+ 2. Upload an image.
24
+ 3. Get visualization of cosine-similarity between CLS-token and tokens/patches of the uploaded image.
app.py CHANGED
@@ -1,4 +1,5 @@
1
- # app.py
 
2
  import gradio as gr
3
  from transformers import AutoModel, AutoImageProcessor
4
  from PIL import Image
@@ -11,6 +12,26 @@ state = {
11
  "repo_id": None,
12
  }
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def load_model(repo_id: str, revision: str = None):
15
  """
16
  Load a Hugging Face model and processor from a repo ID.
@@ -50,6 +71,12 @@ with gr.Blocks() as demo:
50
  image_input = gr.Image(type="pil", label="Upload Image")
51
  image_output = gr.Image(label="Displayed Image")
52
 
 
 
 
 
 
 
53
  # Button clicks / image upload handlers
54
  load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
55
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
  import gradio as gr
4
  from transformers import AutoModel, AutoImageProcessor
5
  from PIL import Image
 
12
  "repo_id": None,
13
  }
14
 
15
+
16
def plot_similarity_heatmap(sim_array: np.ndarray) -> np.ndarray:
    """
    Render a 2D similarity map as a heatmap image with a colorbar.

    Args:
        sim_array: 2D numpy array of shape (h, w).

    Returns:
        The rendered figure as a uint8 numpy array of shape
        (height, width, 3) in RGB order, which can be displayed in Gradio.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        cax = ax.imshow(sim_array, cmap='viridis')
        ax.set_xticks([])
        ax.set_yticks([])

        fig.colorbar(cax)
        fig.canvas.draw()

        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. The buffer already carries the
        # (height, width, 4) shape, so no manual reshape is needed.
        # NOTE(review): like tostring_rgb(), this requires an Agg-based canvas.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer
        # once the figure is closed.
        img = rgba[..., :3].copy()
    finally:
        # Always release the figure, even if rendering fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    return img
33
+
34
+
35
  def load_model(repo_id: str, revision: str = None):
36
  """
37
  Load a Hugging Face model and processor from a repo ID.
 
71
  image_input = gr.Image(type="pil", label="Upload Image")
72
  image_output = gr.Image(label="Displayed Image")
73
 
74
# cos-sim visualization:
# sim_array is your (h, w) numpy array; random values are used here.
# The shape must be passed as `size=`: a bare tuple is taken as the mean
# (`loc`) and would yield an array of shape (2,) instead of (128, 128).
sim_array = np.random.normal(size=(128, 128))
heatmap_img = plot_similarity_heatmap(sim_array)
gr.Image(value=heatmap_img, label="Cosine Similarity Heatmap")
79
+
80
  # Button clicks / image upload handlers
81
  load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
82
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)