Spaces:
Sleeping
Sleeping
Add Saliency maps and classification model
Browse files- animal_model.pkl +3 -0
- app.py +64 -4
- polar_bear.jpg +0 -0
- polar_bear_real.jpg +0 -0
- requirements.txt +5 -0
animal_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:97772ffcc896001cbf04287473f5c2f686bac1e9867dbbe067f48dd57825540f
|
3 |
+
size 87473965
|
app.py
CHANGED
@@ -1,9 +1,69 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
6 |
|
|
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.vision.all import *
|
2 |
import gradio as gr
|
3 |
+
from captum.attr import Saliency
|
4 |
+
from torchvision import transforms
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
|
9 |
+
learn = load_learner('animal_model.pkl')
|
10 |
|
11 |
+
transform = transforms.Compose([
|
12 |
+
transforms.Resize((128,128)),
|
13 |
+
transforms.ToTensor(),
|
14 |
+
])
|
15 |
|
16 |
+
categories = learn.dls.vocab
|
17 |
|
18 |
+
def generate_saliency(image):
|
19 |
+
# Prepare the image for the model
|
20 |
+
img = PILImage.create(image)
|
21 |
+
|
22 |
+
# Get prediction
|
23 |
+
_, pred, probs = learn.predict(img)
|
24 |
+
|
25 |
+
# Create Captum interpretation object
|
26 |
+
interp = Saliency(learn.model)
|
27 |
+
|
28 |
+
# Transform and prepare image for saliency
|
29 |
+
tensor_image = transform(img).unsqueeze(0)
|
30 |
+
tensor_image = tensor_image.requires_grad_()
|
31 |
+
|
32 |
+
# Generate the saliency map
|
33 |
+
saliency_map = interp.attribute(tensor_image, target=pred)
|
34 |
+
|
35 |
+
# Process saliency map for visualization
|
36 |
+
saliency_np = saliency_map.squeeze().cpu().detach().numpy()
|
37 |
+
saliency_np = np.abs(saliency_np).sum(axis=0)
|
38 |
+
#saliency_np = (saliency_np - saliency_np.min()) / (saliency_np.max() - saliency_np.min())
|
39 |
+
|
40 |
+
# Create heatmap
|
41 |
+
plt.figure(figsize=(10, 10))
|
42 |
+
plt.imshow(saliency_np, cmap='viridis')
|
43 |
+
plt.axis('off')
|
44 |
+
plt.tight_layout()
|
45 |
+
plt.savefig('saliency_heatmap.png', pad_inches=0)
|
46 |
+
plt.close()
|
47 |
+
|
48 |
+
return (
|
49 |
+
dict(zip(categories, map(float, probs))),
|
50 |
+
'saliency_heatmap.png',
|
51 |
+
'saliency_overlay.png'
|
52 |
+
)
|
53 |
+
|
54 |
+
# Gradio interface
|
55 |
+
image = gr.Image(type="pil")
|
56 |
+
label = gr.Label()
|
57 |
+
examples = ['polar_bear_real.jpg', 'polar_bear.jpg']
|
58 |
+
|
59 |
+
interface = gr.Interface(
|
60 |
+
fn=generate_saliency,
|
61 |
+
inputs=image,
|
62 |
+
outputs=[
|
63 |
+
gr.Label(label="Predictions"),
|
64 |
+
gr.Image(type="filepath", label="Saliency Heatmap")
|
65 |
+
],
|
66 |
+
examples=examples
|
67 |
+
)
|
68 |
+
|
69 |
+
interface.launch()
|
polar_bear.jpg
ADDED
![]() |
polar_bear_real.jpg
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
captum
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
numpy
|
5 |
+
matplotlib
|