vinayp27 commited on
Commit
dc441f1
·
1 Parent(s): 39ba8d3

Add Saliency maps and classification model

Browse files
Files changed (5) hide show
  1. animal_model.pkl +3 -0
  2. app.py +64 -4
  3. polar_bear.jpg +0 -0
  4. polar_bear_real.jpg +0 -0
  5. requirements.txt +5 -0
animal_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97772ffcc896001cbf04287473f5c2f686bac1e9867dbbe067f48dd57825540f
3
+ size 87473965
app.py CHANGED
@@ -1,9 +1,69 @@
 
1
  import gradio as gr
 
 
 
 
 
2
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
 
 
6
 
 
7
 
8
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.vision.all import *
2
  import gradio as gr
3
+ from captum.attr import Saliency
4
+ from torchvision import transforms
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
 
9
+ learn = load_learner('animal_model.pkl')
10
 
11
+ transform = transforms.Compose([
12
+ transforms.Resize((128,128)),
13
+ transforms.ToTensor(),
14
+ ])
15
 
16
+ categories = learn.dls.vocab
17
 
18
+ def generate_saliency(image):
19
+ # Prepare the image for the model
20
+ img = PILImage.create(image)
21
+
22
+ # Get prediction
23
+ _, pred, probs = learn.predict(img)
24
+
25
+ # Create Captum interpretation object
26
+ interp = Saliency(learn.model)
27
+
28
+ # Transform and prepare image for saliency
29
+ tensor_image = transform(img).unsqueeze(0)
30
+ tensor_image = tensor_image.requires_grad_()
31
+
32
+ # Generate the saliency map
33
+ saliency_map = interp.attribute(tensor_image, target=pred)
34
+
35
+ # Process saliency map for visualization
36
+ saliency_np = saliency_map.squeeze().cpu().detach().numpy()
37
+ saliency_np = np.abs(saliency_np).sum(axis=0)
38
+ #saliency_np = (saliency_np - saliency_np.min()) / (saliency_np.max() - saliency_np.min())
39
+
40
+ # Create heatmap
41
+ plt.figure(figsize=(10, 10))
42
+ plt.imshow(saliency_np, cmap='viridis')
43
+ plt.axis('off')
44
+ plt.tight_layout()
45
+ plt.savefig('saliency_heatmap.png', pad_inches=0)
46
+ plt.close()
47
+
48
+ return (
49
+ dict(zip(categories, map(float, probs))),
50
+ 'saliency_heatmap.png',
51
+ 'saliency_overlay.png'
52
+ )
53
+
54
+ # Gradio interface
55
+ image = gr.Image(type="pil")
56
+ label = gr.Label()
57
+ examples = ['polar_bear_real.jpg', 'polar_bear.jpg']
58
+
59
+ interface = gr.Interface(
60
+ fn=generate_saliency,
61
+ inputs=image,
62
+ outputs=[
63
+ gr.Label(label="Predictions"),
64
+ gr.Image(type="filepath", label="Saliency Heatmap")
65
+ ],
66
+ examples=examples
67
+ )
68
+
69
+ interface.launch()
polar_bear.jpg ADDED
polar_bear_real.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ captum
2
+ torch
3
+ torchvision
4
+ numpy
5
+ matplotlib