Spaces:
Runtime error
Runtime error
Commit
·
fcc16aa
0
Parent(s):
Duplicate from taquynhnga/CNNs-interpretation-visualization
Browse filesCo-authored-by: Hanna Ta Quynh Nga <[email protected]>
- .gitattributes +4 -0
- .github/workflows/sync_to_huggingface_hub.yml +19 -0
- .gitignore +183 -0
- .vscode/settings.json +5 -0
- Home.py +43 -0
- README.md +19 -0
- Visual-Explanation-Methods-PyTorch +1 -0
- backend/adversarial_attack.py +100 -0
- backend/load_file.py +41 -0
- backend/maximally_activating_patches.py +45 -0
- backend/smooth_grad.py +235 -0
- backend/utils.py +379 -0
- data/ImageNet_metadata.csv +3 -0
- data/activation/convnext_activation.json +3 -0
- data/activation/mobilenet_activation.json +3 -0
- data/activation/resnet_activation.json +3 -0
- data/dot_architectures/convnext_architecture.dot +3 -0
- data/layer_infos/convnext_layer_infos.json +3 -0
- data/layer_infos/mobilenet_layer_infos.json +3 -0
- data/layer_infos/resnet_layer_infos.json +3 -0
- data/preprocessed_image_net/val_data_0.pkl +3 -0
- data/preprocessed_image_net/val_data_1.pkl +3 -0
- data/preprocessed_image_net/val_data_2.pkl +3 -0
- data/preprocessed_image_net/val_data_3.pkl +3 -0
- data/preprocessed_image_net/val_data_4.pkl +3 -0
- frontend/__init__.py +6 -0
- frontend/footer.py +32 -0
- frontend/images/equal-sign.png +0 -0
- frontend/images/minus-sign-2.png +0 -0
- frontend/images/minus-sign-3.png +0 -0
- frontend/images/minus-sign-4.png +0 -0
- frontend/images/minus-sign-5.png +0 -0
- frontend/images/minus-sign.png +0 -0
- frontend/images/plus-sign-2.png +0 -0
- frontend/images/plus-sign.png +0 -0
- frontend/index.html +204 -0
- pages/1_Maximally_activating_patches.py +166 -0
- pages/2_SmoothGrad.py +124 -0
- pages/3_Adversarial_attack.py +215 -0
- pages/4_ImageNet1k.py +56 -0
- requirements.txt +17 -0
.gitattributes
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
2 |
+
.csv filter=lfs diff=lfs merge=lfs -text
|
3 |
+
data/** filter=lfs diff=lfs merge=lfs -text
|
4 |
+
Visual-Explanation-Methods-PyTorch/** filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/sync_to_huggingface_hub.yml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
# to run this workflow manually from the Actions tab
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
sync-to-hub:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
fetch-depth: 0
|
16 |
+
- name: Push to hub
|
17 |
+
env:
|
18 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
19 |
+
run: git push --force https://taquynhnga:[email protected]/spaces/taquynhnga/CNNs-interpretation-visualization main
|
.gitignore
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VSCode
|
2 |
+
.vscode/*
|
3 |
+
!.vscode/settings.json
|
4 |
+
!.vscode/tasks.json
|
5 |
+
!.vscode/launch.json
|
6 |
+
!.vscode/extensions.json
|
7 |
+
*.code-workspace
|
8 |
+
# Local History for Visual Studio Code
|
9 |
+
.history/
|
10 |
+
|
11 |
+
# Common credential files
|
12 |
+
**/credentials.json
|
13 |
+
**/client_secrets.json
|
14 |
+
**/client_secret.json
|
15 |
+
*creds*
|
16 |
+
*.dat
|
17 |
+
*password*
|
18 |
+
*.httr-oauth*
|
19 |
+
|
20 |
+
# Private Node Modules
|
21 |
+
node_modules/
|
22 |
+
creds.js
|
23 |
+
|
24 |
+
# Private Files
|
25 |
+
# *.json
|
26 |
+
# *.csv
|
27 |
+
# *.csv.gz
|
28 |
+
# *.tsv
|
29 |
+
# *.tsv.gz
|
30 |
+
# *.xlsx
|
31 |
+
git-large-file
|
32 |
+
deta_drive.py
|
33 |
+
secret_keys.py
|
34 |
+
|
35 |
+
# Large files
|
36 |
+
# data/preprocessed_image_net/
|
37 |
+
# data/activation/*.pkl
|
38 |
+
# data/activation/*.json
|
39 |
+
# data/layer_infos/*.pkl
|
40 |
+
# data/layer_infos/*.json
|
41 |
+
|
42 |
+
# Mac/OSX
|
43 |
+
.DS_Store
|
44 |
+
|
45 |
+
|
46 |
+
# Byte-compiled / optimized / DLL files
|
47 |
+
__pycache__/
|
48 |
+
*.py[cod]
|
49 |
+
*$py.class
|
50 |
+
|
51 |
+
# C extensions
|
52 |
+
*.so
|
53 |
+
|
54 |
+
# Distribution / packaging
|
55 |
+
.Python
|
56 |
+
build/
|
57 |
+
develop-eggs/
|
58 |
+
dist/
|
59 |
+
downloads/
|
60 |
+
eggs/
|
61 |
+
.eggs/
|
62 |
+
lib/
|
63 |
+
lib64/
|
64 |
+
parts/
|
65 |
+
sdist/
|
66 |
+
var/
|
67 |
+
wheels/
|
68 |
+
share/python-wheels/
|
69 |
+
*.egg-info/
|
70 |
+
.installed.cfg
|
71 |
+
*.egg
|
72 |
+
MANIFEST
|
73 |
+
|
74 |
+
# PyInstaller
|
75 |
+
# Usually these files are written by a python script from a template
|
76 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
77 |
+
*.manifest
|
78 |
+
*.spec
|
79 |
+
|
80 |
+
# Installer logs
|
81 |
+
pip-log.txt
|
82 |
+
pip-delete-this-directory.txt
|
83 |
+
|
84 |
+
# Unit test / coverage reports
|
85 |
+
htmlcov/
|
86 |
+
.tox/
|
87 |
+
.nox/
|
88 |
+
.coverage
|
89 |
+
.coverage.*
|
90 |
+
.cache
|
91 |
+
nosetests.xml
|
92 |
+
coverage.xml
|
93 |
+
*.cover
|
94 |
+
*.py,cover
|
95 |
+
.hypothesis/
|
96 |
+
.pytest_cache/
|
97 |
+
cover/
|
98 |
+
|
99 |
+
# Translations
|
100 |
+
*.mo
|
101 |
+
*.pot
|
102 |
+
|
103 |
+
# Django stuff:
|
104 |
+
*.log
|
105 |
+
local_settings.py
|
106 |
+
db.sqlite3
|
107 |
+
db.sqlite3-journal
|
108 |
+
|
109 |
+
# Flask stuff:
|
110 |
+
instance/
|
111 |
+
.webassets-cache
|
112 |
+
|
113 |
+
# Scrapy stuff:
|
114 |
+
.scrapy
|
115 |
+
|
116 |
+
# Sphinx documentation
|
117 |
+
docs/_build/
|
118 |
+
|
119 |
+
# PyBuilder
|
120 |
+
.pybuilder/
|
121 |
+
target/
|
122 |
+
|
123 |
+
# Jupyter Notebook
|
124 |
+
.ipynb_checkpoints
|
125 |
+
|
126 |
+
# IPython
|
127 |
+
profile_default/
|
128 |
+
ipython_config.py
|
129 |
+
|
130 |
+
# pyenv
|
131 |
+
# For a library or package, you might want to ignore these files since the code is
|
132 |
+
# intended to run in multiple environments; otherwise, check them in:
|
133 |
+
# .python-version
|
134 |
+
|
135 |
+
# pipenv
|
136 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
137 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
138 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
139 |
+
# install all needed dependencies.
|
140 |
+
#Pipfile.lock
|
141 |
+
|
142 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
143 |
+
__pypackages__/
|
144 |
+
|
145 |
+
# Celery stuff
|
146 |
+
celerybeat-schedule
|
147 |
+
celerybeat.pid
|
148 |
+
|
149 |
+
# SageMath parsed files
|
150 |
+
*.sage.py
|
151 |
+
|
152 |
+
# Environments
|
153 |
+
.env
|
154 |
+
.venv
|
155 |
+
env/
|
156 |
+
venv/
|
157 |
+
ENV/
|
158 |
+
env.bak/
|
159 |
+
venv.bak/
|
160 |
+
|
161 |
+
# Spyder project settings
|
162 |
+
.spyderproject
|
163 |
+
.spyproject
|
164 |
+
|
165 |
+
# Rope project settings
|
166 |
+
.ropeproject
|
167 |
+
|
168 |
+
# mkdocs documentation
|
169 |
+
/site
|
170 |
+
|
171 |
+
# mypy
|
172 |
+
.mypy_cache/
|
173 |
+
.dmypy.json
|
174 |
+
dmypy.json
|
175 |
+
|
176 |
+
# Pyre type checker
|
177 |
+
.pyre/
|
178 |
+
|
179 |
+
# pytype static type analyzer
|
180 |
+
.pytype/
|
181 |
+
|
182 |
+
# Cython debug symbols
|
183 |
+
cython_debug/
|
.vscode/settings.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"python.analysis.extraPaths": [
|
3 |
+
"./Visual-Explanation-Methods-PyTorch"
|
4 |
+
]
|
5 |
+
}
|
Home.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from frontend.footer import add_footer
|
3 |
+
|
4 |
+
st.set_page_config(layout='wide')
|
5 |
+
# st.set_page_config(layout='centered')
|
6 |
+
|
7 |
+
st.title('About')
|
8 |
+
|
9 |
+
# INTRO
|
10 |
+
intro_text = """Convolutional neural networks (ConvNets) have evolved at a rapid speed from the 2010s.
|
11 |
+
Some of the representative ConvNets models are VGGNet, Inceptions, ResNe(X)t, DenseNet, MobileNet, EfficientNet and RegNet, which focus on various factors of accuracy, efficiency, and scalability.
|
12 |
+
In the year 2020, Vision Transformers (ViT) was introduced as a Transformer model solving the computer vision problems.
|
13 |
+
Larger model and dataset sizes allow ViT to perform significantly better than ResNet, however, ViT still encountered challenges in generic computer vision tasks such as object detection and semantic segmentation.
|
14 |
+
Swin Transformer’ s success made Transformers be adopted as a generic vision backbone and showed outstanding performance in a wide range of computer vision tasks.
|
15 |
+
Nevertheless, rather than the intrinsic inductive biases of convolutions, the success of this approach is still primarily attributed to Transformers’ inherent superiority.
|
16 |
+
|
17 |
+
In 2022, Zhuang Liu et. al. proposed a pure convolutional model dubbed ConvNeXt, discovered from the modernization of a standard ResNet towards the design of Vision Transformers and claimed to outperform them.
|
18 |
+
|
19 |
+
The project aims to interpret the ConvNeXt model by several visualization techniques.
|
20 |
+
After that, a web interface would be built to demonstrate the interpretations, helping us look inside the deep ConvNeXt model and answer the questions:
|
21 |
+
> “What patterns maximally activated this filter (channel) in this layer?”\n
|
22 |
+
> “Which features are responsible for the current prediction?”.
|
23 |
+
|
24 |
+
Due to the limitation in time and resources, the project only used the tiny-sized ConvNeXt model, which was trained on ImageNet-1k at resolution 224x224 and used 50,000 images in validation set of ImageNet-1k for demo purpose.
|
25 |
+
|
26 |
+
In this web app, two visualization techniques were implemented and demonstrated, they are **Maximally activating patches** and **SmoothGrad**.
|
27 |
+
Besides, this web app also helps investigate the effect of **adversarial attacks** on ConvNeXt interpretations.
|
28 |
+
Last but not least, there is a last webpage that stores 50,000 images in the **ImageNet-1k** validation set, facilitating the two web pages above in searching and referencing.
|
29 |
+
"""
|
30 |
+
st.write(intro_text)
|
31 |
+
|
32 |
+
# 4 PAGES
|
33 |
+
st.subheader('Features')
|
34 |
+
sections_text = """Overall, there are 4 features in this web app:
|
35 |
+
1) Maximally activating patches: The visualization method in this page answers the question “what patterns maximally activated this filter (channel)?”.
|
36 |
+
2) SmoothGrad: This visualization method in this page answers the question “which features are responsible for the current prediction?”.
|
37 |
+
3) Adversarial attack: How adversarial attacks affect ConvNeXt interpretation?
|
38 |
+
4) ImageNet1k: The storage of 50,000 images in validation set.
|
39 |
+
"""
|
40 |
+
st.write(sections_text)
|
41 |
+
|
42 |
+
|
43 |
+
add_footer('Developed with ❤ by ', 'Hanna Ta Quynh Nga', 'https://www.linkedin.com/in/ta-quynh-nga-hanna/')
|
README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: CNNs Interpretation Visualization
|
3 |
+
emoji: 💡
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.17.0
|
8 |
+
app_file: Home.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: taquynhnga/CNNs-interpretation-visualization
|
11 |
+
---
|
12 |
+
|
13 |
+
# Visualizing Interpretations of CNN models: ConvNeXt, ResNet and MobileNet
|
14 |
+
|
15 |
+
To be change name: CNNs-interpretation-visualization
|
16 |
+
|
17 |
+
This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
|
18 |
+
|
19 |
+
This repo lacks one more folder `data/preprocessed_image_net` which contains 50,000 preprocessed imagenet validation images saved in 5 pickle files.
|
Visual-Explanation-Methods-PyTorch
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 5cb88902729af1d9d85259879b47cb238b841881
|
backend/adversarial_attack.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from matplotlib import pylab as P
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import TensorDataset
|
9 |
+
from torchvision import transforms
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
13 |
+
|
14 |
+
from torchvex.base import ExplanationMethod
|
15 |
+
from torchvex.utils.normalization import clamp_quantile
|
16 |
+
|
17 |
+
from backend.utils import load_image, load_model
|
18 |
+
from backend.smooth_grad import generate_smoothgrad_mask
|
19 |
+
|
20 |
+
import streamlit as st
|
21 |
+
|
22 |
+
IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1])
|
23 |
+
IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1])
|
24 |
+
|
25 |
+
def deprocess_image(image_inputs):
|
26 |
+
return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255
|
27 |
+
|
28 |
+
|
29 |
+
def feed_forward(input_image):
|
30 |
+
model, feature_extractor = load_model('ConvNeXt')
|
31 |
+
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
|
32 |
+
logits = model(inputs).logits
|
33 |
+
prediction_prob = F.softmax(logits, dim=-1).max() # prediction probability
|
34 |
+
# prediction class id, start from 1 to 1000 so it needs to +1 in the end
|
35 |
+
prediction_class = logits.argmax(-1).item()
|
36 |
+
prediction_label = model.config.id2label[prediction_class] # prediction class label
|
37 |
+
return prediction_prob, prediction_class, prediction_label
|
38 |
+
|
39 |
+
# FGSM attack code
|
40 |
+
def fgsm_attack(image, epsilon, data_grad):
|
41 |
+
# Collect the element-wise sign of the data gradient and normalize it
|
42 |
+
sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
|
43 |
+
perturbed_image = image + epsilon*sign_data_grad
|
44 |
+
return perturbed_image
|
45 |
+
|
46 |
+
# perform attack on the model
|
47 |
+
def perform_attack(input_image, target, epsilon):
|
48 |
+
model, feature_extractor = load_model("ConvNeXt")
|
49 |
+
# preprocess input image
|
50 |
+
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
|
51 |
+
inputs.requires_grad = True
|
52 |
+
|
53 |
+
# predict
|
54 |
+
logits = model(inputs).logits
|
55 |
+
prediction_prob = F.softmax(logits, dim=-1).max()
|
56 |
+
prediction_class = logits.argmax(-1).item()
|
57 |
+
prediction_label = model.config.id2label[prediction_class]
|
58 |
+
|
59 |
+
# Calculate the loss
|
60 |
+
loss = F.nll_loss(logits, torch.tensor([target]))
|
61 |
+
|
62 |
+
# Zero all existing gradients
|
63 |
+
model.zero_grad()
|
64 |
+
|
65 |
+
# Calculate gradients of model in backward pass
|
66 |
+
loss.backward()
|
67 |
+
|
68 |
+
# Collect datagrad
|
69 |
+
data_grad = inputs.grad.data
|
70 |
+
|
71 |
+
# Call FGSM Attack
|
72 |
+
perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
|
73 |
+
|
74 |
+
# Re-classify the perturbed image
|
75 |
+
new_prediction = model(perturbed_data).logits
|
76 |
+
new_pred_prob = F.softmax(new_prediction, dim=-1).max()
|
77 |
+
new_pred_class = new_prediction.argmax(-1).item()
|
78 |
+
new_pred_label = model.config.id2label[new_pred_class]
|
79 |
+
|
80 |
+
return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label
|
81 |
+
|
82 |
+
|
83 |
+
def find_smallest_epsilon(input_image, target):
|
84 |
+
epsilons = [i*0.001 for i in range(1000)]
|
85 |
+
|
86 |
+
for epsilon in epsilons:
|
87 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon)
|
88 |
+
if new_id != target:
|
89 |
+
return perturbed_data, new_prob, new_id, new_label, epsilon
|
90 |
+
return None
|
91 |
+
|
92 |
+
# @st.cache_data
|
93 |
+
@st.cache(allow_output_mutation=True)
|
94 |
+
def generate_images(image_id, epsilon=0):
|
95 |
+
model, feature_extractor = load_model("ConvNeXt")
|
96 |
+
original_image_dict = load_image(image_id)
|
97 |
+
image = original_image_dict['image']
|
98 |
+
return generate_smoothgrad_mask(
|
99 |
+
image, 'ConvNeXt',
|
100 |
+
model, feature_extractor, num_samples=10, return_mask=True)
|
backend/load_file.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
def load_pickle(filename):
|
7 |
+
with open(filename, 'rb') as file:
|
8 |
+
data = pickle.load(file)
|
9 |
+
return data
|
10 |
+
|
11 |
+
def save_pickle_to_json(filename):
|
12 |
+
ordered_dict = load_pickle(filename)
|
13 |
+
json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
|
14 |
+
with open(filename.replace('.pkl', '.json'), 'w') as f:
|
15 |
+
f.write(json_obj)
|
16 |
+
|
17 |
+
def load_json(filename):
|
18 |
+
with open(filename, 'r') as read_file:
|
19 |
+
loaded_dict = json.loads(read_file.read())
|
20 |
+
loaded_dict = OrderedDict(loaded_dict)
|
21 |
+
for k, v in loaded_dict.items():
|
22 |
+
if type(v) == list:
|
23 |
+
loaded_dict[k] = np.asarray(v)
|
24 |
+
else:
|
25 |
+
for k_, v_ in v.items():
|
26 |
+
v[k_] = np.asarray(v_)
|
27 |
+
return loaded_dict
|
28 |
+
|
29 |
+
class NumpyEncoder(json.JSONEncoder):
|
30 |
+
def default(self, obj):
|
31 |
+
if isinstance(obj, np.ndarray):
|
32 |
+
return obj.tolist()
|
33 |
+
return json.JSONEncoder.default(self, obj)
|
34 |
+
|
35 |
+
# save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
|
36 |
+
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
37 |
+
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
38 |
+
|
39 |
+
# file = load_json('data/layer_infos/convnext_layer_infos.json')
|
40 |
+
# print(type(file))
|
41 |
+
# print(type(file['embeddings.patch_embeddings']))
|
backend/maximally_activating_patches.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from backend.load_file import load_json
|
5 |
+
|
6 |
+
|
7 |
+
@st.cache(allow_output_mutation=True)
|
8 |
+
# st.cache_data
|
9 |
+
def load_activation(filename):
|
10 |
+
activation = load_json(filename)
|
11 |
+
return activation
|
12 |
+
|
13 |
+
@st.cache(allow_output_mutation=True)
|
14 |
+
# @st.cache_data
|
15 |
+
def load_dataset(data_index):
|
16 |
+
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
17 |
+
dataset = pickle.load(file)
|
18 |
+
return dataset
|
19 |
+
|
20 |
+
def load_layer_infos(filename):
|
21 |
+
layer_infos = load_json(filename)
|
22 |
+
return layer_infos
|
23 |
+
|
24 |
+
def get_receptive_field_coordinates(layer_infos, layer_name, idx_x, idx_y):
|
25 |
+
"""
|
26 |
+
layer_name: as in layer_infos keys (eg: 'encoder.stages[0].layers[0]')
|
27 |
+
idx_x: integer coordinate of width axis in feature maps. must < n
|
28 |
+
idx_y: integer coordinate of height axis in feature maps. must < n
|
29 |
+
"""
|
30 |
+
layer_name = layer_name.replace('.dwconv', '').replace('.layernorm', '')
|
31 |
+
layer_name = layer_name.replace('.pwconv1', '').replace('.pwconv2', '').replace('.drop_path', '')
|
32 |
+
n = layer_infos[layer_name]['n']
|
33 |
+
j = layer_infos[layer_name]['j']
|
34 |
+
r = layer_infos[layer_name]['r']
|
35 |
+
start = layer_infos[layer_name]['start']
|
36 |
+
assert idx_x < n, f'n={n}'
|
37 |
+
assert idx_y < n, f'n={n}'
|
38 |
+
|
39 |
+
# image tensor (N, H, W, C) or (N, C, H, W) => image_patch=image[y1:y2, x1:x2]
|
40 |
+
center = (start + idx_x*j, start + idx_y*j)
|
41 |
+
x1, x2 = (max(center[0]-r/2, 0), max(center[0]+r/2, 0))
|
42 |
+
y1, y2 = (max(center[1]-r/2, 0), max(center[1]+r/2, 0))
|
43 |
+
x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
|
44 |
+
|
45 |
+
return x1, x2, y1, y2
|
backend/smooth_grad.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from matplotlib import pylab as P
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import TensorDataset
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
# dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
|
12 |
+
# sys.path.append(dirpath_to_modules)
|
13 |
+
|
14 |
+
from torchvex.base import ExplanationMethod
|
15 |
+
from torchvex.utils.normalization import clamp_quantile
|
16 |
+
|
17 |
+
def ShowImage(im, title='', ax=None):
|
18 |
+
image = np.array(im)
|
19 |
+
return image
|
20 |
+
|
21 |
+
def ShowGrayscaleImage(im, title='', ax=None):
|
22 |
+
if ax is None:
|
23 |
+
P.figure()
|
24 |
+
P.axis('off')
|
25 |
+
P.imshow(im, cmap=P.cm.gray, vmin=0, vmax=1)
|
26 |
+
P.title(title)
|
27 |
+
return P
|
28 |
+
|
29 |
+
def ShowHeatMap(im, title='', ax=None):
|
30 |
+
im = im - im.min()
|
31 |
+
im = im / im.max()
|
32 |
+
im = im.clip(0,1)
|
33 |
+
im = np.uint8(im * 255)
|
34 |
+
|
35 |
+
im = cv2.resize(im, (224,224))
|
36 |
+
image = cv2.resize(im, (224, 224))
|
37 |
+
|
38 |
+
# Apply JET colormap
|
39 |
+
color_heatmap = cv2.applyColorMap(image, cv2.COLORMAP_HOT)
|
40 |
+
# P.imshow(im, cmap='inferno')
|
41 |
+
# P.title(title)
|
42 |
+
return color_heatmap
|
43 |
+
|
44 |
+
def ShowMaskedImage(saliency_map, image, title='', ax=None):
|
45 |
+
"""
|
46 |
+
Save saliency map on image.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
image: Tensor of size (H,W,3)
|
50 |
+
saliency_map: Tensor of size (H,W,1)
|
51 |
+
"""
|
52 |
+
|
53 |
+
# if ax is None:
|
54 |
+
# P.figure()
|
55 |
+
# P.axis('off')
|
56 |
+
|
57 |
+
saliency_map = saliency_map - saliency_map.min()
|
58 |
+
saliency_map = saliency_map / saliency_map.max()
|
59 |
+
saliency_map = saliency_map.clip(0,1)
|
60 |
+
saliency_map = np.uint8(saliency_map * 255)
|
61 |
+
|
62 |
+
saliency_map = cv2.resize(saliency_map, (224,224))
|
63 |
+
image = cv2.resize(image, (224, 224))
|
64 |
+
|
65 |
+
# Apply JET colormap
|
66 |
+
color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT)
|
67 |
+
|
68 |
+
# Blend image with heatmap
|
69 |
+
img_with_heatmap = cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0)
|
70 |
+
|
71 |
+
# P.imshow(img_with_heatmap)
|
72 |
+
# P.title(title)
|
73 |
+
return img_with_heatmap
|
74 |
+
|
75 |
+
def LoadImage(file_path):
|
76 |
+
im = PIL.Image.open(file_path)
|
77 |
+
im = im.resize((224, 224))
|
78 |
+
im = np.asarray(im)
|
79 |
+
return im
|
80 |
+
|
81 |
+
|
82 |
+
def visualize_image_grayscale(image_3d, percentile=99):
|
83 |
+
r"""Returns a 3D tensor as a grayscale 2D tensor.
|
84 |
+
This method sums a 3D tensor across the absolute value of axis=2, and then
|
85 |
+
clips values at a given percentile.
|
86 |
+
"""
|
87 |
+
image_2d = np.sum(np.abs(image_3d), axis=2)
|
88 |
+
|
89 |
+
vmax = np.percentile(image_2d, percentile)
|
90 |
+
vmin = np.min(image_2d)
|
91 |
+
|
92 |
+
return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)
|
93 |
+
|
94 |
+
def visualize_image_diverging(image_3d, percentile=99):
|
95 |
+
r"""Returns a 3D tensor as a 2D tensor with positive and negative values.
|
96 |
+
"""
|
97 |
+
image_2d = np.sum(image_3d, axis=2)
|
98 |
+
|
99 |
+
span = abs(np.percentile(image_2d, percentile))
|
100 |
+
vmin = -span
|
101 |
+
vmax = span
|
102 |
+
|
103 |
+
return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)
|
104 |
+
|
105 |
+
|
106 |
+
class SimpleGradient(ExplanationMethod):
|
107 |
+
def __init__(self, model, create_graph=False,
|
108 |
+
preprocess=None, postprocess=None):
|
109 |
+
super().__init__(model, preprocess, postprocess)
|
110 |
+
self.create_graph = create_graph
|
111 |
+
|
112 |
+
def predict(self, x):
|
113 |
+
return self.model(x)
|
114 |
+
|
115 |
+
@torch.enable_grad()
|
116 |
+
def process(self, inputs, target):
|
117 |
+
self.model.zero_grad()
|
118 |
+
inputs.requires_grad_(True)
|
119 |
+
|
120 |
+
out = self.model(inputs)
|
121 |
+
out = out if type(out) == torch.Tensor else out.logits
|
122 |
+
|
123 |
+
num_classes = out.size(-1)
|
124 |
+
onehot = torch.zeros(inputs.size(0), num_classes, *target.shape[1:])
|
125 |
+
onehot = onehot.to(dtype=inputs.dtype, device=inputs.device)
|
126 |
+
onehot.scatter_(1, target.unsqueeze(1), 1)
|
127 |
+
|
128 |
+
grad, = torch.autograd.grad(
|
129 |
+
(out*onehot).sum(), inputs, create_graph=self.create_graph
|
130 |
+
)
|
131 |
+
|
132 |
+
return grad
|
133 |
+
|
134 |
+
|
135 |
+
class SmoothGradient(ExplanationMethod):
|
136 |
+
def __init__(self, model, stdev_spread=0.15, num_samples=25,
|
137 |
+
magnitude=True, batch_size=-1,
|
138 |
+
create_graph=False, preprocess=None, postprocess=None):
|
139 |
+
super().__init__(model, preprocess, postprocess)
|
140 |
+
self.stdev_spread = stdev_spread
|
141 |
+
self.nsample = num_samples
|
142 |
+
self.create_graph = create_graph
|
143 |
+
self.magnitude = magnitude
|
144 |
+
self.batch_size = batch_size
|
145 |
+
if self.batch_size == -1:
|
146 |
+
self.batch_size = self.nsample
|
147 |
+
|
148 |
+
self._simgrad = SimpleGradient(model, create_graph)
|
149 |
+
|
150 |
+
def process(self, inputs, target):
|
151 |
+
self.model.zero_grad()
|
152 |
+
|
153 |
+
maxima = inputs.flatten(1).max(-1)[0]
|
154 |
+
minima = inputs.flatten(1).min(-1)[0]
|
155 |
+
|
156 |
+
stdev = self.stdev_spread * (maxima - minima).cpu()
|
157 |
+
stdev = stdev.view(inputs.size(0), 1, 1, 1).expand_as(inputs)
|
158 |
+
stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1]*4)
|
159 |
+
noise = torch.normal(0, stdev)
|
160 |
+
|
161 |
+
target_expanded = target.unsqueeze(0).cpu()
|
162 |
+
target_expanded = target_expanded.expand(noise.size(0), -1)
|
163 |
+
|
164 |
+
noiseloader = torch.utils.data.DataLoader(
|
165 |
+
TensorDataset(noise, target_expanded), batch_size=self.batch_size
|
166 |
+
)
|
167 |
+
|
168 |
+
total_gradients = torch.zeros_like(inputs)
|
169 |
+
for noise, t_exp in noiseloader:
|
170 |
+
inputs_w_noise = inputs.unsqueeze(0) + noise.to(inputs.device)
|
171 |
+
inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:])
|
172 |
+
gradients = self._simgrad(inputs_w_noise, t_exp.view(-1))
|
173 |
+
gradients = gradients.view(self.batch_size, *inputs.shape)
|
174 |
+
if self.magnitude:
|
175 |
+
gradients = gradients.pow(2)
|
176 |
+
total_gradients = total_gradients + gradients.sum(0)
|
177 |
+
|
178 |
+
smoothed_gradient = total_gradients / self.nsample
|
179 |
+
return smoothed_gradient
|
180 |
+
|
181 |
+
|
182 |
+
def feed_forward(model_name, image, model=None, feature_extractor=None):
|
183 |
+
if model_name in ['ConvNeXt', 'ResNet']:
|
184 |
+
inputs = feature_extractor(image, return_tensors="pt")
|
185 |
+
logits = model(**inputs).logits
|
186 |
+
prediction_class = logits.argmax(-1).item()
|
187 |
+
else:
|
188 |
+
transform_images = transforms.Compose([
|
189 |
+
transforms.Resize(224),
|
190 |
+
transforms.ToTensor(),
|
191 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
192 |
+
input_tensor = transform_images(image)
|
193 |
+
inputs = input_tensor.unsqueeze(0)
|
194 |
+
|
195 |
+
output = model(inputs)
|
196 |
+
prediction_class = output.argmax(-1).item()
|
197 |
+
#prediction_label = model.config.id2label[prediction_class]
|
198 |
+
return inputs, prediction_class
|
199 |
+
|
200 |
+
def clip_gradient(gradient):
|
201 |
+
gradient = gradient.abs().sum(1, keepdim=True)
|
202 |
+
return clamp_quantile(gradient, q=0.99)
|
203 |
+
|
204 |
+
def fig2img(fig):
|
205 |
+
"""Convert a Matplotlib figure to a PIL Image and return it"""
|
206 |
+
import io
|
207 |
+
buf = io.BytesIO()
|
208 |
+
fig.savefig(buf)
|
209 |
+
buf.seek(0)
|
210 |
+
img = Image.open(buf)
|
211 |
+
return img
|
212 |
+
|
213 |
+
def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25, return_mask=False):
|
214 |
+
inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
|
215 |
+
|
216 |
+
smoothgrad_gen = SmoothGradient(
|
217 |
+
model, num_samples=num_samples, stdev_spread=0.1,
|
218 |
+
magnitude=False, postprocess=clip_gradient)
|
219 |
+
|
220 |
+
if type(inputs) != torch.Tensor:
|
221 |
+
inputs = inputs['pixel_values']
|
222 |
+
|
223 |
+
smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
|
224 |
+
smoothgrad_mask = smoothgrad_mask[0].numpy()
|
225 |
+
smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))
|
226 |
+
|
227 |
+
image = np.asarray(image)
|
228 |
+
# ori_image = ShowImage(image)
|
229 |
+
heat_map_image = ShowHeatMap(smoothgrad_mask)
|
230 |
+
masked_image = ShowMaskedImage(smoothgrad_mask, image)
|
231 |
+
|
232 |
+
if return_mask:
|
233 |
+
return heat_map_image, masked_image, smoothgrad_mask
|
234 |
+
else:
|
235 |
+
return heat_map_image, masked_image
|
backend/utils.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import io
|
5 |
+
from typing import List, Optional
|
6 |
+
|
7 |
+
import markdown
|
8 |
+
import matplotlib
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
import streamlit as st
|
13 |
+
from plotly import express as px
|
14 |
+
from plotly.subplots import make_subplots
|
15 |
+
from tqdm import trange
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
19 |
+
|
20 |
+
@st.cache(allow_output_mutation=True)
|
21 |
+
# @st.cache_resource
|
22 |
+
def load_dataset(data_index):
|
23 |
+
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
24 |
+
dataset = pickle.load(file)
|
25 |
+
return dataset
|
26 |
+
|
27 |
+
@st.cache(allow_output_mutation=True)
|
28 |
+
# @st.cache_resource
|
29 |
+
def load_dataset_dict():
|
30 |
+
dataset_dict = {}
|
31 |
+
progress_empty = st.empty()
|
32 |
+
text_empty = st.empty()
|
33 |
+
text_empty.write("Loading datasets...")
|
34 |
+
progress_bar = progress_empty.progress(0.0)
|
35 |
+
for data_index in trange(5):
|
36 |
+
dataset_dict[data_index] = load_dataset(data_index)
|
37 |
+
progress_bar.progress((data_index+1)/5)
|
38 |
+
progress_empty.empty()
|
39 |
+
text_empty.empty()
|
40 |
+
return dataset_dict
|
41 |
+
|
42 |
+
|
43 |
+
# @st.cache_data
|
44 |
+
@st.cache(allow_output_mutation=True)
|
45 |
+
def load_image(image_id):
|
46 |
+
dataset = load_dataset(image_id//10000)
|
47 |
+
image = dataset[image_id%10000]
|
48 |
+
return image
|
49 |
+
|
50 |
+
# @st.cache_data
|
51 |
+
@st.cache(allow_output_mutation=True)
|
52 |
+
def load_images(image_ids):
|
53 |
+
images = []
|
54 |
+
for image_id in image_ids:
|
55 |
+
image = load_image(image_id)
|
56 |
+
images.append(image)
|
57 |
+
return images
|
58 |
+
|
59 |
+
|
60 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
61 |
+
# @st.cache_resource
|
62 |
+
def load_model(model_name):
|
63 |
+
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
64 |
+
if model_name == 'ResNet':
|
65 |
+
model_file_path = 'microsoft/resnet-50'
|
66 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
67 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
68 |
+
model.eval()
|
69 |
+
elif model_name == 'ConvNeXt':
|
70 |
+
model_file_path = 'facebook/convnext-tiny-224'
|
71 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
72 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
73 |
+
model.eval()
|
74 |
+
else:
|
75 |
+
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
|
76 |
+
model.eval()
|
77 |
+
feature_extractor = None
|
78 |
+
return model, feature_extractor
|
79 |
+
|
80 |
+
|
81 |
+
def make_grid(cols=None,rows=None):
|
82 |
+
grid = [0]*rows
|
83 |
+
for i in range(rows):
|
84 |
+
with st.container():
|
85 |
+
grid[i] = st.columns(cols)
|
86 |
+
return grid
|
87 |
+
|
88 |
+
|
89 |
+
def use_container_width_percentage(percentage_width:int = 75):
|
90 |
+
max_width_str = f"max-width: {percentage_width}%;"
|
91 |
+
st.markdown(f"""
|
92 |
+
<style>
|
93 |
+
.reportview-container .main .block-container{{{max_width_str}}}
|
94 |
+
</style>
|
95 |
+
""",
|
96 |
+
unsafe_allow_html=True,
|
97 |
+
)
|
98 |
+
|
99 |
+
matplotlib.use("Agg")
|
100 |
+
COLOR = "#31333f"
|
101 |
+
BACKGROUND_COLOR = "#ffffff"
|
102 |
+
|
103 |
+
|
104 |
+
def grid_demo():
|
105 |
+
"""Main function. Run this to run the app"""
|
106 |
+
st.sidebar.title("Layout and Style Experiments")
|
107 |
+
st.sidebar.header("Settings")
|
108 |
+
st.markdown(
|
109 |
+
"""
|
110 |
+
# Layout and Style Experiments
|
111 |
+
|
112 |
+
The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
|
113 |
+
the [CSS Grid](https://gridbyexample.com/examples)?
|
114 |
+
|
115 |
+
Can we do it with a nice api?
|
116 |
+
Can have a dark theme?
|
117 |
+
"""
|
118 |
+
)
|
119 |
+
|
120 |
+
select_block_container_style()
|
121 |
+
add_resources_section()
|
122 |
+
|
123 |
+
# My preliminary idea of an API for generating a grid
|
124 |
+
with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
|
125 |
+
grid.cell(
|
126 |
+
class_="a",
|
127 |
+
grid_column_start=2,
|
128 |
+
grid_column_end=3,
|
129 |
+
grid_row_start=1,
|
130 |
+
grid_row_end=2,
|
131 |
+
).markdown("# This is A Markdown Cell")
|
132 |
+
grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe")
|
133 |
+
grid.cell("c", 3, 4, 2, 3).plotly_chart(get_plotly_fig())
|
134 |
+
grid.cell("d", 1, 2, 1, 3).dataframe(get_dataframe())
|
135 |
+
grid.cell("e", 3, 4, 1, 2).markdown(
|
136 |
+
"Try changing the **block container style** in the sidebar!"
|
137 |
+
)
|
138 |
+
grid.cell("f", 1, 3, 3, 4).text(
|
139 |
+
"The cell to the right is a matplotlib svg image"
|
140 |
+
)
|
141 |
+
grid.cell("g", 3, 4, 3, 4).pyplot(get_matplotlib_plt())
|
142 |
+
|
143 |
+
|
144 |
+
def add_resources_section():
|
145 |
+
"""Adds a resources section to the sidebar"""
|
146 |
+
st.sidebar.header("Add_resources_section")
|
147 |
+
st.sidebar.markdown(
|
148 |
+
"""
|
149 |
+
- [gridbyexample.com] (https://gridbyexample.com/examples/)
|
150 |
+
"""
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
class Cell:
|
155 |
+
"""A Cell can hold text, markdown, plots etc."""
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
class_: str = None,
|
160 |
+
grid_column_start: Optional[int] = None,
|
161 |
+
grid_column_end: Optional[int] = None,
|
162 |
+
grid_row_start: Optional[int] = None,
|
163 |
+
grid_row_end: Optional[int] = None,
|
164 |
+
):
|
165 |
+
self.class_ = class_
|
166 |
+
self.grid_column_start = grid_column_start
|
167 |
+
self.grid_column_end = grid_column_end
|
168 |
+
self.grid_row_start = grid_row_start
|
169 |
+
self.grid_row_end = grid_row_end
|
170 |
+
self.inner_html = ""
|
171 |
+
|
172 |
+
def _to_style(self) -> str:
|
173 |
+
return f"""
|
174 |
+
.{self.class_} {{
|
175 |
+
grid-column-start: {self.grid_column_start};
|
176 |
+
grid-column-end: {self.grid_column_end};
|
177 |
+
grid-row-start: {self.grid_row_start};
|
178 |
+
grid-row-end: {self.grid_row_end};
|
179 |
+
}}
|
180 |
+
"""
|
181 |
+
|
182 |
+
def text(self, text: str = ""):
|
183 |
+
self.inner_html = text
|
184 |
+
|
185 |
+
def markdown(self, text):
|
186 |
+
self.inner_html = markdown.markdown(text)
|
187 |
+
|
188 |
+
def dataframe(self, dataframe: pd.DataFrame):
|
189 |
+
self.inner_html = dataframe.to_html()
|
190 |
+
|
191 |
+
def plotly_chart(self, fig):
|
192 |
+
self.inner_html = f"""
|
193 |
+
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
194 |
+
<body>
|
195 |
+
<p>This should have been a plotly plot.
|
196 |
+
But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work.
|
197 |
+
But I could potentially save to svg and insert that.</p>
|
198 |
+
<div id='divPlotly'></div>
|
199 |
+
<script>
|
200 |
+
var plotly_data = {fig.to_json()}
|
201 |
+
Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
|
202 |
+
</script>
|
203 |
+
</body>
|
204 |
+
"""
|
205 |
+
|
206 |
+
def pyplot(self, fig=None, **kwargs):
|
207 |
+
string_io = io.StringIO()
|
208 |
+
plt.savefig(string_io, format="svg", fig=(2, 2))
|
209 |
+
svg = string_io.getvalue()[215:]
|
210 |
+
plt.close(fig)
|
211 |
+
self.inner_html = '<div height="200px">' + svg + "</div>"
|
212 |
+
|
213 |
+
def _to_html(self):
|
214 |
+
return f"""<div class="box {self.class_}">{self.inner_html}</div>"""
|
215 |
+
|
216 |
+
|
217 |
+
class Grid:
|
218 |
+
"""A (CSS) Grid"""
|
219 |
+
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
template_columns="1 1 1",
|
223 |
+
gap="10px",
|
224 |
+
background_color=COLOR,
|
225 |
+
color=BACKGROUND_COLOR,
|
226 |
+
):
|
227 |
+
self.template_columns = template_columns
|
228 |
+
self.gap = gap
|
229 |
+
self.background_color = background_color
|
230 |
+
self.color = color
|
231 |
+
self.cells: List[Cell] = []
|
232 |
+
|
233 |
+
def __enter__(self):
|
234 |
+
return self
|
235 |
+
|
236 |
+
def __exit__(self, type, value, traceback):
|
237 |
+
st.markdown(self._get_grid_style(), unsafe_allow_html=True)
|
238 |
+
st.markdown(self._get_cells_style(), unsafe_allow_html=True)
|
239 |
+
st.markdown(self._get_cells_html(), unsafe_allow_html=True)
|
240 |
+
|
241 |
+
def _get_grid_style(self):
|
242 |
+
return f"""
|
243 |
+
<style>
|
244 |
+
.wrapper {{
|
245 |
+
display: grid;
|
246 |
+
grid-template-columns: {self.template_columns};
|
247 |
+
grid-gap: {self.gap};
|
248 |
+
background-color: {self.color};
|
249 |
+
color: {self.background_color};
|
250 |
+
}}
|
251 |
+
.box {{
|
252 |
+
background-color: {self.color};
|
253 |
+
color: {self.background_color};
|
254 |
+
border-radius: 0px;
|
255 |
+
padding: 0px;
|
256 |
+
font-size: 100%;
|
257 |
+
text-align: center;
|
258 |
+
}}
|
259 |
+
table {{
|
260 |
+
color: {self.color}
|
261 |
+
}}
|
262 |
+
</style>
|
263 |
+
"""
|
264 |
+
|
265 |
+
def _get_cells_style(self):
|
266 |
+
return (
|
267 |
+
"<style>"
|
268 |
+
+ "\n".join([cell._to_style() for cell in self.cells])
|
269 |
+
+ "</style>"
|
270 |
+
)
|
271 |
+
|
272 |
+
def _get_cells_html(self):
|
273 |
+
return (
|
274 |
+
'<div class="wrapper">'
|
275 |
+
+ "\n".join([cell._to_html() for cell in self.cells])
|
276 |
+
+ "</div>"
|
277 |
+
)
|
278 |
+
|
279 |
+
def cell(
|
280 |
+
self,
|
281 |
+
class_: str = None,
|
282 |
+
grid_column_start: Optional[int] = None,
|
283 |
+
grid_column_end: Optional[int] = None,
|
284 |
+
grid_row_start: Optional[int] = None,
|
285 |
+
grid_row_end: Optional[int] = None,
|
286 |
+
):
|
287 |
+
cell = Cell(
|
288 |
+
class_=class_,
|
289 |
+
grid_column_start=grid_column_start,
|
290 |
+
grid_column_end=grid_column_end,
|
291 |
+
grid_row_start=grid_row_start,
|
292 |
+
grid_row_end=grid_row_end,
|
293 |
+
)
|
294 |
+
self.cells.append(cell)
|
295 |
+
return cell
|
296 |
+
|
297 |
+
|
298 |
+
def select_block_container_style():
|
299 |
+
"""Add selection section for setting setting the max-width and padding
|
300 |
+
of the main block container"""
|
301 |
+
st.sidebar.header("Block Container Style")
|
302 |
+
max_width_100_percent = st.sidebar.checkbox("Max-width: 100%?", False)
|
303 |
+
if not max_width_100_percent:
|
304 |
+
max_width = st.sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
|
305 |
+
else:
|
306 |
+
max_width = 1200
|
307 |
+
dark_theme = st.sidebar.checkbox("Dark Theme?", False)
|
308 |
+
padding_top = st.sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
|
309 |
+
padding_right = st.sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
|
310 |
+
padding_left = st.sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
|
311 |
+
padding_bottom = st.sidebar.number_input(
|
312 |
+
"Select padding bottom in rem", 0, 200, 10, 1
|
313 |
+
)
|
314 |
+
if dark_theme:
|
315 |
+
global COLOR
|
316 |
+
global BACKGROUND_COLOR
|
317 |
+
BACKGROUND_COLOR = "rgb(17,17,17)"
|
318 |
+
COLOR = "#fff"
|
319 |
+
|
320 |
+
_set_block_container_style(
|
321 |
+
max_width,
|
322 |
+
max_width_100_percent,
|
323 |
+
padding_top,
|
324 |
+
padding_right,
|
325 |
+
padding_left,
|
326 |
+
padding_bottom,
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
def _set_block_container_style(
|
331 |
+
max_width: int = 1200,
|
332 |
+
max_width_100_percent: bool = False,
|
333 |
+
padding_top: int = 5,
|
334 |
+
padding_right: int = 1,
|
335 |
+
padding_left: int = 1,
|
336 |
+
padding_bottom: int = 10,
|
337 |
+
):
|
338 |
+
if max_width_100_percent:
|
339 |
+
max_width_str = f"max-width: 100%;"
|
340 |
+
else:
|
341 |
+
max_width_str = f"max-width: {max_width}px;"
|
342 |
+
st.markdown(
|
343 |
+
f"""
|
344 |
+
<style>
|
345 |
+
.reportview-container .main .block-container{{
|
346 |
+
{max_width_str}
|
347 |
+
padding-top: {padding_top}rem;
|
348 |
+
padding-right: {padding_right}rem;
|
349 |
+
padding-left: {padding_left}rem;
|
350 |
+
padding-bottom: {padding_bottom}rem;
|
351 |
+
}}
|
352 |
+
.reportview-container .main {{
|
353 |
+
color: {COLOR};
|
354 |
+
background-color: {BACKGROUND_COLOR};
|
355 |
+
}}
|
356 |
+
</style>
|
357 |
+
""",
|
358 |
+
unsafe_allow_html=True,
|
359 |
+
)
|
360 |
+
|
361 |
+
|
362 |
+
# @st.cache
|
363 |
+
# def get_dataframe() -> pd.DataFrame():
|
364 |
+
# """Dummy DataFrame"""
|
365 |
+
# data = [
|
366 |
+
# {"quantity": 1, "price": 2},
|
367 |
+
# {"quantity": 3, "price": 5},
|
368 |
+
# {"quantity": 4, "price": 8},
|
369 |
+
# ]
|
370 |
+
# return pd.DataFrame(data)
|
371 |
+
|
372 |
+
|
373 |
+
# def get_plotly_fig():
|
374 |
+
# """Dummy Plotly Plot"""
|
375 |
+
# return px.line(data_frame=get_dataframe(), x="quantity", y="price")
|
376 |
+
|
377 |
+
|
378 |
+
# def get_matplotlib_plt():
|
379 |
+
# get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
|
data/ImageNet_metadata.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e53b0fc17cd5c8811ca08b7ff908cd2bbd625147686ef8bc020cb85a5a4546e5
|
3 |
+
size 3027633
|
data/activation/convnext_activation.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0354b28bcca4e3673888124740e3d82882cbf38af8cd3007f48a7a5db983f487
|
3 |
+
size 33350177
|
data/activation/mobilenet_activation.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5abc76e9318fadee18f35bb54e90201bf28699cf75140b5d2482d42243fad302
|
3 |
+
size 13564581
|
data/activation/resnet_activation.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:668bea355a5504d74f79d20d02954040ad572f50455361d7d17125c7c8b1561c
|
3 |
+
size 23362905
|
data/dot_architectures/convnext_architecture.dot
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:41a258a40a93615638ae504770c14e44836c934badbe48f18148f5a750514ac9
|
3 |
+
size 9108
|
data/layer_infos/convnext_layer_infos.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e82ea48865493107b97f37da58e370f0eead5677bf10f25f237f10970aedb6f
|
3 |
+
size 1678
|
data/layer_infos/mobilenet_layer_infos.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a11df5f0b23040d11ce817658a989c8faf19faa06a8cbad727b635bac824e917
|
3 |
+
size 3578
|
data/layer_infos/resnet_layer_infos.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21e1787382f1e1c206b81d2c4fe207fb6d41f4cf186d5afc32fc056dd21e10d6
|
3 |
+
size 5155
|
data/preprocessed_image_net/val_data_0.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2698bdc240555e2a46a40936df87275bc5852142d30e921ae0dad9289b0f576f
|
3 |
+
size 906108480
|
data/preprocessed_image_net/val_data_1.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21780d77e212695dbee84d6d2ad17a5a520bc1634f68e1c8fd120f069ad76da1
|
3 |
+
size 907109023
|
data/preprocessed_image_net/val_data_2.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2cfc83b78420baa1b2c3a8da92e7fba1f33443d506f483ecff13cdba2035ab3c
|
3 |
+
size 907435149
|
data/preprocessed_image_net/val_data_3.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f5e2c7cb4d6bae17fbd062a0b46f2cee457ad466b725f7bdf0f8426069cafee
|
3 |
+
size 906089333
|
data/preprocessed_image_net/val_data_4.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ed53c87ec8b9945db31f910eb44b7e3092324643de25ea53a99fc29137df854
|
3 |
+
size 905439763
|
frontend/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit.components.v1 as components
|
2 |
+
|
3 |
+
on_click_graph = components.declare_component(
|
4 |
+
"on_click_graph",
|
5 |
+
path="./frontend"
|
6 |
+
)
|
frontend/footer.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
footer="""<style>
|
4 |
+
a:link , a:visited{
|
5 |
+
background-color: transparent;
|
6 |
+
# text-decoration: underline;
|
7 |
+
}
|
8 |
+
|
9 |
+
a:hover, a:active {
|
10 |
+
color: orange;
|
11 |
+
background-color: transparent;
|
12 |
+
text-decoration: underline;
|
13 |
+
}
|
14 |
+
|
15 |
+
.footer {
|
16 |
+
position: fixed;
|
17 |
+
left: 0;
|
18 |
+
bottom: 0;
|
19 |
+
width: 100%;
|
20 |
+
background-color: white;
|
21 |
+
text-align: center;
|
22 |
+
}
|
23 |
+
</style>
|
24 |
+
|
25 |
+
<div class="footer">
|
26 |
+
<p>USER_DEFINED_TEXT <a href="LINK" target="_blank" color="blue">LINKED_TEXT</a></p>
|
27 |
+
</div>
|
28 |
+
"""
|
29 |
+
|
30 |
+
def add_footer(text, linked_text, link):
|
31 |
+
custom_footer = footer.replace('USER_DEFINED_TEXT', text).replace('LINKED_TEXT', linked_text).replace('LINK', link)
|
32 |
+
st.markdown(custom_footer, unsafe_allow_html=True)
|
frontend/images/equal-sign.png
ADDED
![]() |
frontend/images/minus-sign-2.png
ADDED
![]() |
frontend/images/minus-sign-3.png
ADDED
![]() |
frontend/images/minus-sign-4.png
ADDED
![]() |
frontend/images/minus-sign-5.png
ADDED
![]() |
frontend/images/minus-sign.png
ADDED
![]() |
frontend/images/plus-sign-2.png
ADDED
![]() |
frontend/images/plus-sign.png
ADDED
![]() |
frontend/index.html
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<html>
|
2 |
+
|
3 |
+
<head>
|
4 |
+
<style type="text/css">
|
5 |
+
</style>
|
6 |
+
</head>
|
7 |
+
|
8 |
+
<!--
|
9 |
+
----------------------------------------------------
|
10 |
+
Your custom static HTML goes in the body:
|
11 |
+
-->
|
12 |
+
|
13 |
+
<body>
|
14 |
+
</body>
|
15 |
+
|
16 |
+
<script type="text/javascript">
|
17 |
+
// Helper function to send type and data messages to Streamlit client
|
18 |
+
|
19 |
+
const SET_COMPONENT_VALUE = "streamlit:setComponentValue"
|
20 |
+
const RENDER = "streamlit:render"
|
21 |
+
const COMPONENT_READY = "streamlit:componentReady"
|
22 |
+
const SET_FRAME_HEIGHT = "streamlit:setFrameHeight"
|
23 |
+
var HIGHTLIGHT_COLOR;
|
24 |
+
var original_colors;
|
25 |
+
|
26 |
+
function _sendMessage(type, data) {
|
27 |
+
// copy data into object
|
28 |
+
var outboundData = Object.assign({
|
29 |
+
isStreamlitMessage: true,
|
30 |
+
type: type,
|
31 |
+
}, data)
|
32 |
+
|
33 |
+
if (type == SET_COMPONENT_VALUE) {
|
34 |
+
console.log("_sendMessage data: ", SET_COMPONENT_VALUE)
|
35 |
+
// console.log("_sendMessage data: " + JSON.stringify(data))
|
36 |
+
// console.log("_sendMessage outboundData: " + JSON.stringify(outboundData))
|
37 |
+
}
|
38 |
+
|
39 |
+
window.parent.postMessage(outboundData, "*")
|
40 |
+
}
|
41 |
+
|
42 |
+
function initialize(pipeline) {
|
43 |
+
|
44 |
+
// Hook Streamlit's message events into a simple dispatcher of pipeline handlers
|
45 |
+
window.addEventListener("message", (event) => {
|
46 |
+
if (event.data.type == RENDER) {
|
47 |
+
// The event.data.args dict holds any JSON-serializable value
|
48 |
+
// sent from the Streamlit client. It is already deserialized.
|
49 |
+
pipeline.forEach(handler => {
|
50 |
+
handler(event.data.args)
|
51 |
+
})
|
52 |
+
}
|
53 |
+
})
|
54 |
+
|
55 |
+
_sendMessage(COMPONENT_READY, { apiVersion: 1 });
|
56 |
+
|
57 |
+
// Component should be mounted by Streamlit in an iframe, so try to autoset the iframe height.
|
58 |
+
window.addEventListener("load", () => {
|
59 |
+
window.setTimeout(function () {
|
60 |
+
setFrameHeight(document.documentElement.clientHeight)
|
61 |
+
}, 0)
|
62 |
+
})
|
63 |
+
|
64 |
+
// Optionally, if auto-height computation fails, you can manually set it
|
65 |
+
// (uncomment below)
|
66 |
+
setFrameHeight(0)
|
67 |
+
}
|
68 |
+
|
69 |
+
function setFrameHeight(height) {
|
70 |
+
_sendMessage(SET_FRAME_HEIGHT, { height: height })
|
71 |
+
}
|
72 |
+
|
73 |
+
// The `data` argument can be any JSON-serializable value.
|
74 |
+
function notifyHost(data) {
|
75 |
+
_sendMessage(SET_COMPONENT_VALUE, data)
|
76 |
+
}
|
77 |
+
|
78 |
+
function changeButtonColor(button, color) {
|
79 |
+
pol = button.querySelectorAll('polygon')[0]
|
80 |
+
pol.setAttribute('fill', color)
|
81 |
+
pol.setAttribute('stroke', color)
|
82 |
+
}
|
83 |
+
|
84 |
+
function getButtonColor(button) {
|
85 |
+
pol = button.querySelectorAll('polygon')[0]
|
86 |
+
return pol.getAttribute('fill')
|
87 |
+
}
|
88 |
+
// ----------------------------------------------------
|
89 |
+
// Your custom functionality for the component goes here:
|
90 |
+
|
91 |
+
function toggle(button) {
|
92 |
+
group = 'node'
|
93 |
+
let button_color;
|
94 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
95 |
+
console.log("nodes.length = ", nodes.length)
|
96 |
+
// for (let i = 0; i < nodes.length; i++) {
|
97 |
+
// console.log(nodes.item(i))
|
98 |
+
// }
|
99 |
+
console.log("selected button ", button, button.getAttribute('class'), button.id)
|
100 |
+
|
101 |
+
for (let i = 0; i < nodes.length; i++) {
|
102 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
103 |
+
if (polygons.length == 0) {
|
104 |
+
continue
|
105 |
+
}
|
106 |
+
if (button.id == nodes.item(i).id & button.getAttribute('class').includes("off")) {
|
107 |
+
button.setAttribute('class', group + " on")
|
108 |
+
button_color = original_colors[i]
|
109 |
+
|
110 |
+
} else if (button.id == nodes.item(i).id & button.getAttribute('class').includes("on")) {
|
111 |
+
button.setAttribute('class', group + " off")
|
112 |
+
button_color = original_colors[i]
|
113 |
+
} else if (button.id == nodes.item(i).id) {
|
114 |
+
button.setAttribute('class', group + " on")
|
115 |
+
button_color = original_colors[i]
|
116 |
+
|
117 |
+
} else if (button.id != nodes.item(i).id & nodes.item(i).getAttribute('class').includes("on")) {
|
118 |
+
nodes.item(i).className = group + " off"
|
119 |
+
} else {
|
120 |
+
nodes.item(i).className = group + " off"
|
121 |
+
}
|
122 |
+
}
|
123 |
+
|
124 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
125 |
+
actions = []
|
126 |
+
for (let i = 0; i < nodes.length; i++) {
|
127 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
128 |
+
if (polygons.length == 0) {
|
129 |
+
continue
|
130 |
+
}
|
131 |
+
btn = nodes.item(i)
|
132 |
+
ori_color = original_colors[i]
|
133 |
+
color = btn.querySelectorAll('polygon')[0].getAttribute('fill')
|
134 |
+
actions.push({ "action": btn.getAttribute("class").includes("on"), "original_color": ori_color, "color": color})
|
135 |
+
}
|
136 |
+
|
137 |
+
states = {}
|
138 |
+
states['choice'] = {
|
139 |
+
"node_title": button.querySelectorAll("title")[0].innerHTML,
|
140 |
+
"node_id": button.id,
|
141 |
+
"state": {
|
142 |
+
"action": button.getAttribute("class").includes("on"),
|
143 |
+
"original_color": button_color,
|
144 |
+
"color": button.querySelectorAll('polygon')[0].getAttribute('fill')
|
145 |
+
}
|
146 |
+
}
|
147 |
+
states["options"] = {"states": actions }
|
148 |
+
|
149 |
+
notifyHost({
|
150 |
+
value: states,
|
151 |
+
dataType: "json",
|
152 |
+
})
|
153 |
+
}
|
154 |
+
|
155 |
+
// ----------------------------------------------------
|
156 |
+
// Here you can customize a pipeline of handlers for
|
157 |
+
// inbound properties from the Streamlit client app
|
158 |
+
|
159 |
+
// Set initial value sent from Streamlit!
|
160 |
+
function initializeProps_Handler(props) {
|
161 |
+
HIGHTLIGHT_COLOR = props['hightlight_color']
|
162 |
+
original_colors = []
|
163 |
+
// nodes = document.getElementsByClassName('node')
|
164 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
165 |
+
console.log(nodes)
|
166 |
+
for (let i = 0; i < nodes.length; i++) {
|
167 |
+
// color = nodes.item(i).getElementsByTagName('POLYGON')[0].getAttribute("fill")
|
168 |
+
// nodes.item(i).addEventListener("click", toggle)
|
169 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
170 |
+
if (polygons.length == 0) {
|
171 |
+
original_colors.push('none')
|
172 |
+
continue
|
173 |
+
}
|
174 |
+
|
175 |
+
color = polygons[0].getAttribute("fill")
|
176 |
+
if (!nodes.item(i).hasAttribute('color')) {
|
177 |
+
nodes.item(i).setAttribute("color", color)
|
178 |
+
original_colors.push(color)
|
179 |
+
} else {
|
180 |
+
original_colors.push(nodes.item(i).getAttribute("color"))
|
181 |
+
}
|
182 |
+
nodes.item(i).addEventListener("click", function (event) {toggle(this)})
|
183 |
+
}
|
184 |
+
// console.log("original colors:", original_colors)
|
185 |
+
}
|
186 |
+
// Access values sent from Streamlit!
|
187 |
+
function dataUpdate_Handler(props) {
|
188 |
+
console.log('dataUpdate_Handler...........')
|
189 |
+
let msgLabel = document.getElementById("message_label")
|
190 |
+
}
|
191 |
+
// Simply log received data dictionary
|
192 |
+
function log_Handler(props) {
|
193 |
+
console.log("Received from Streamlit: " + JSON.stringify(props))
|
194 |
+
}
|
195 |
+
|
196 |
+
let pipeline = [initializeProps_Handler, dataUpdate_Handler, log_Handler]
|
197 |
+
|
198 |
+
// ----------------------------------------------------
|
199 |
+
// Finally, initialize component passing in pipeline
|
200 |
+
initialize(pipeline)
|
201 |
+
|
202 |
+
</script>
|
203 |
+
|
204 |
+
</html>
|
pages/1_Maximally_activating_patches.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from plotly.subplots import make_subplots
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
|
7 |
+
import graphviz
|
8 |
+
|
9 |
+
from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
|
10 |
+
from frontend import on_click_graph
|
11 |
+
from backend.utils import load_dataset_dict
|
12 |
+
|
13 |
+
HIGHTLIGHT_COLOR = '#e7bcc5'
|
14 |
+
st.set_page_config(layout='wide')
|
15 |
+
|
16 |
+
|
17 |
+
st.title('Maximally activating image patches')
|
18 |
+
st.write('> **What patterns maximally activate this channel in ConvNeXt model?**')
|
19 |
+
st.write("""The maximally activating image patches method is a technique used in visualizing the interpretation of convolutional neural networks.
|
20 |
+
It works by identifying the regions of the input image that activate a particular neuron in the convolutional layer,
|
21 |
+
thus revealing the features that the neuron is detecting. To achieve this, the method generates image patches then feeds into the model while monitoring the neuron's activation.
|
22 |
+
The algorithm then selects the patch that produces the highest activation and overlays it on the original image to visualize the features that the neuron is responding to.
|
23 |
+
""")
|
24 |
+
|
25 |
+
# -------------------------- LOAD DATASET ---------------------------------
|
26 |
+
dataset_dict = load_dataset_dict()
|
27 |
+
|
28 |
+
# -------------------------- LOAD GRAPH -----------------------------------
|
29 |
+
|
30 |
+
def load_dot_to_graph(filename):
|
31 |
+
dot = graphviz.Source.from_file(filename)
|
32 |
+
source_lines = str(dot).splitlines()
|
33 |
+
source_lines.pop(0)
|
34 |
+
source_lines.pop(-1)
|
35 |
+
graph = graphviz.Digraph()
|
36 |
+
graph.body += source_lines
|
37 |
+
return graph, dot
|
38 |
+
|
39 |
+
|
40 |
+
# st.header('ConvNeXt')
|
41 |
+
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
|
42 |
+
convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
|
43 |
+
|
44 |
+
convnext_graph.graph_attr['size'] = '4,40'
|
45 |
+
|
46 |
+
# -------------------------- DISPLAY GRAPH -----------------------------------
|
47 |
+
|
48 |
+
def chosen_node_text(clicked_node_title):
|
49 |
+
clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
|
50 |
+
stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None
|
51 |
+
block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None
|
52 |
+
layer_id = clicked_node_title.split()[-1]
|
53 |
+
|
54 |
+
if 'embeddings' in layer_id:
|
55 |
+
display_text = 'Patchify layer'
|
56 |
+
activation_key = 'embeddings.patch_embeddings'
|
57 |
+
elif 'downsampling' in layer_id:
|
58 |
+
display_text = f'Stage {stage_id} > Downsampling layer'
|
59 |
+
activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
|
60 |
+
else:
|
61 |
+
display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
|
62 |
+
activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}'
|
63 |
+
return display_text, activation_key
|
64 |
+
|
65 |
+
|
66 |
+
props = {
|
67 |
+
'hightlight_color': HIGHTLIGHT_COLOR,
|
68 |
+
'initial_state': {
|
69 |
+
'group_1_header': 'Choose an option from group 1',
|
70 |
+
'group_2_header': 'Choose an option from group 2'
|
71 |
+
}
|
72 |
+
}
|
73 |
+
|
74 |
+
|
75 |
+
col1, col2 = st.columns((2,5))
|
76 |
+
col1.markdown("#### Architecture")
|
77 |
+
col1.write('')
|
78 |
+
col1.write('Click on a layer below to generate top-k maximally activating image patches')
|
79 |
+
col1.graphviz_chart(convnext_graph)
|
80 |
+
|
81 |
+
with col2:
|
82 |
+
st.markdown("#### Output")
|
83 |
+
nodes = on_click_graph(key='toggle_buttons', **props)
|
84 |
+
|
85 |
+
# -------------------------- DISPLAY OUTPUT -----------------------------------
|
86 |
+
|
87 |
+
if nodes != None:
|
88 |
+
clicked_node_title = nodes["choice"]["node_title"]
|
89 |
+
clicked_node_id = nodes["choice"]["node_id"]
|
90 |
+
display_text, activation_key = chosen_node_text(clicked_node_title)
|
91 |
+
col2.write(f'**Chosen layer:** {display_text}')
|
92 |
+
# col2.write(f'**Activation key:** {activation_key}')
|
93 |
+
|
94 |
+
hightlight_syle = f'''
|
95 |
+
<style>
|
96 |
+
div[data-stale]:has(iframe) {{
|
97 |
+
height: 0;
|
98 |
+
}}
|
99 |
+
#{clicked_node_id}>polygon {{
|
100 |
+
fill: {HIGHTLIGHT_COLOR};
|
101 |
+
stroke: {HIGHTLIGHT_COLOR};
|
102 |
+
}}
|
103 |
+
</style>
|
104 |
+
'''
|
105 |
+
col2.markdown(hightlight_syle, unsafe_allow_html=True)
|
106 |
+
|
107 |
+
with col2:
|
108 |
+
layer_infos = None
|
109 |
+
with st.form('top_k_form'):
|
110 |
+
activation_path = './data/activation/convnext_activation.json'
|
111 |
+
activation = load_activation(activation_path)
|
112 |
+
num_channels = activation[activation_key].shape[1]
|
113 |
+
|
114 |
+
top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
|
115 |
+
channel_start, channel_end = st.slider(
|
116 |
+
'Choose channel range of this layer (recommend to choose small range less than 30)',
|
117 |
+
1, num_channels, value=(1, 30))
|
118 |
+
summit_button = st.form_submit_button('Generate image patches')
|
119 |
+
if summit_button:
|
120 |
+
|
121 |
+
activation = activation[activation_key][:top_k,:,:]
|
122 |
+
layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
|
123 |
+
# st.write(channel_start, channel_end)
|
124 |
+
# st.write(activation.shape, activation.shape[1])
|
125 |
+
|
126 |
+
if layer_infos != None:
|
127 |
+
num_cols, num_rows = top_k, channel_end - channel_start + 1
|
128 |
+
# num_rows = activation.shape[1]
|
129 |
+
top_k_coor_max_ = activation
|
130 |
+
st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
|
131 |
+
|
132 |
+
for row in range(channel_start, channel_end+1):
|
133 |
+
if row == channel_start:
|
134 |
+
top_margin = 50
|
135 |
+
fig = make_subplots(
|
136 |
+
rows=1, cols=num_cols,
|
137 |
+
subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
|
138 |
+
else:
|
139 |
+
top_margin = 0
|
140 |
+
fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)
|
141 |
+
for col in range(1, num_cols+1):
|
142 |
+
k, c = col-1, row-1
|
143 |
+
img_index = int(top_k_coor_max_[k, c, 3])
|
144 |
+
activation_value = top_k_coor_max_[k, c, 0]
|
145 |
+
img = dataset_dict[img_index//10_000][img_index%10_000]['image']
|
146 |
+
class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
|
147 |
+
class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
|
148 |
+
|
149 |
+
idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
|
150 |
+
x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
|
151 |
+
img = np.array(img)[y1:y2, x1:x2, :]
|
152 |
+
|
153 |
+
hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
|
154 |
+
fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
|
155 |
+
fig.update_xaxes(showticklabels=False, showgrid=False)
|
156 |
+
fig.update_yaxes(showticklabels=False, showgrid=False)
|
157 |
+
fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
|
158 |
+
fig.update_layout(showlegend=False, yaxis_title=row)
|
159 |
+
fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
|
160 |
+
fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
|
161 |
+
st.plotly_chart(fig, use_container_width=True)
|
162 |
+
|
163 |
+
|
164 |
+
else:
|
165 |
+
col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
|
166 |
+
col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
|
pages/2_SmoothGrad.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
from backend.utils import make_grid, load_dataset, load_model, load_images
|
6 |
+
|
7 |
+
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
|
8 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from matplotlib.backends.backend_agg import RendererAgg
|
12 |
+
_lock = RendererAgg.lock
|
13 |
+
|
14 |
+
st.set_page_config(layout='wide')
|
15 |
+
BACKGROUND_COLOR = '#bcd0e7'
|
16 |
+
|
17 |
+
|
18 |
+
st.title('Feature attribution visualization with SmoothGrad')
|
19 |
+
st.write("""> **Which features are responsible for the current prediction of ConvNeXt?**
|
20 |
+
|
21 |
+
In machine learning, it is helpful to identify the significant features of the input (e.g., pixels for images) that affect the model's prediction.
|
22 |
+
If the model makes an incorrect prediction, we might want to determine which features contributed to the mistake.
|
23 |
+
To do this, we can generate a feature importance mask, which is a grayscale image with the same size as the original image.
|
24 |
+
The brightness of each pixel in the mask represents the importance of that feature to the model's prediction.
|
25 |
+
|
26 |
+
There are various methods to calculate an image sensitivity mask for a specific prediction.
|
27 |
+
One simple way is to use the gradient of a class prediction neuron concerning the input pixels, indicating how the prediction is affected by small pixel changes.
|
28 |
+
However, this method usually produces a noisy mask.
|
29 |
+
To reduce the noise, the SmoothGrad technique as described in [SmoothGrad: Removing noise by adding noise](https://arxiv.org/abs/1706.03825) by Daniel _et al_ is used,
|
30 |
+
which adds Gaussian noise to multiple copies of the image and averages the resulting gradients.
|
31 |
+
""")
|
32 |
+
|
33 |
+
instruction_text = """Users need to input the model(s), type of image set and image set setting to use this functionality.
|
34 |
+
1. Choose model: Users can choose one or more models for comparison.
|
35 |
+
There are 3 models supported: [ConvNeXt](https://huggingface.co/facebook/convnext-tiny-224),
|
36 |
+
[ResNet](https://huggingface.co/microsoft/resnet-50) and [MobileNet](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/).
|
37 |
+
These 3 models have similar number of parameters.
|
38 |
+
\n2. Choose type of Image set: There are 2 types of Image set. They are _User-defined set_ and _Random set_.
|
39 |
+
\n3. Image set setting: If users choose _User-defined set_ in Image set,
|
40 |
+
users need to enter a list of image IDs separated by commas (,). For example, `0,1,4,7` is a valid input.
|
41 |
+
Check the page [ImageNet1k](/ImageNet1k) to see all the Image IDs.
|
42 |
+
If users choose _Random set_ in Image set, users just need to choose the number of random images to display here.
|
43 |
+
"""
|
44 |
+
with st.expander("See more instruction", expanded=False):
|
45 |
+
st.write(instruction_text)
|
46 |
+
|
47 |
+
|
48 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
49 |
+
|
50 |
+
# --------------------------- LOAD function -----------------------------
|
51 |
+
|
52 |
+
|
53 |
+
images = []
|
54 |
+
image_ids = []
|
55 |
+
# INPUT ------------------------------
|
56 |
+
st.header('Input')
|
57 |
+
with st.form('smooth_grad_form'):
|
58 |
+
st.markdown('**Model and Input Setting**')
|
59 |
+
selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
|
60 |
+
selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
|
61 |
+
|
62 |
+
summit_button = st.form_submit_button('Set')
|
63 |
+
if summit_button:
|
64 |
+
setting_container = st.container()
|
65 |
+
# for id in image_ids:
|
66 |
+
# images = load_images(image_ids)
|
67 |
+
|
68 |
+
with st.form('2nd_form'):
|
69 |
+
st.markdown('**Image set setting**')
|
70 |
+
if selected_image_set == 'Random set':
|
71 |
+
no_images = st.slider('Number of images', 1, 50, value=10)
|
72 |
+
image_ids = random.sample(list(range(50_000)), k=no_images)
|
73 |
+
else:
|
74 |
+
text = st.text_area('Specific Image IDs', value='0')
|
75 |
+
image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
|
76 |
+
|
77 |
+
run_button = st.form_submit_button('Display output')
|
78 |
+
if run_button:
|
79 |
+
for id in image_ids:
|
80 |
+
images = load_images(image_ids)
|
81 |
+
|
82 |
+
st.header('Output')
|
83 |
+
|
84 |
+
models = {}
|
85 |
+
feature_extractors = {}
|
86 |
+
|
87 |
+
for i, model_name in enumerate(selected_models):
|
88 |
+
models[model_name], feature_extractors[model_name] = load_model(model_name)
|
89 |
+
|
90 |
+
|
91 |
+
# DISPLAY ----------------------------------
|
92 |
+
if run_button:
|
93 |
+
header_cols = st.columns([1, 1] + [2]*len(selected_models))
|
94 |
+
header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
|
95 |
+
header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
|
96 |
+
for i, model_name in enumerate(selected_models):
|
97 |
+
header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
|
98 |
+
|
99 |
+
grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
|
100 |
+
|
101 |
+
|
102 |
+
@st.cache(allow_output_mutation=True)
|
103 |
+
# @st.cache_data
|
104 |
+
def generate_images(image_id, model_name):
|
105 |
+
j = image_ids.index(image_id)
|
106 |
+
image = images[j]['image']
|
107 |
+
return generate_smoothgrad_mask(
|
108 |
+
image, model_name,
|
109 |
+
models[model_name], feature_extractors[model_name], num_samples=10)
|
110 |
+
|
111 |
+
with _lock:
|
112 |
+
for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
|
113 |
+
grids[j][0].write(f'{image_id}. {image_dict["label"]}')
|
114 |
+
image = image_dict['image']
|
115 |
+
ori_image = ShowImage(np.asarray(image))
|
116 |
+
grids[j][1].image(ori_image)
|
117 |
+
|
118 |
+
for i, model_name in enumerate(selected_models):
|
119 |
+
# ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
|
120 |
+
# model_name, models[model_name], feature_extractors[model_name], num_samples=10)
|
121 |
+
heatmap_image, masked_image = generate_images(image_id, model_name)
|
122 |
+
# grids[j][1].image(ori_image)
|
123 |
+
grids[j][i*2+2].image(heatmap_image)
|
124 |
+
grids[j][i*2+3].image(masked_image)
|
pages/3_Adversarial_attack.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
from backend.utils import make_grid, load_dataset, load_model, load_image
|
6 |
+
|
7 |
+
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img, LoadImage, ShowHeatMap, ShowMaskedImage
|
8 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from matplotlib.backends.backend_agg import RendererAgg
|
12 |
+
|
13 |
+
from backend.adversarial_attack import *
|
14 |
+
|
15 |
+
_lock = RendererAgg.lock
|
16 |
+
|
17 |
+
st.set_page_config(layout='wide')
|
18 |
+
BACKGROUND_COLOR = '#bcd0e7'
|
19 |
+
SECONDARY_COLOR = '#bce7db'
|
20 |
+
|
21 |
+
|
22 |
+
st.title('Adversarial Attack')
|
23 |
+
st.write('> **How adversarial attacks affect ConvNeXt interpretation?**')
|
24 |
+
st.write("""Adversarial examples are inputs crafted to confuse neural networks, causing them to misclassify a given input.
|
25 |
+
These examples are not distinguishable by humans but cause the network to fail to recognize the image content.
|
26 |
+
One type of such attack is the fast gradient sign method (FGSM) attack, which is a white box attack that aims to ensure misclassification.
|
27 |
+
A white box attack is where the attacker has full access to the model being attacked.
|
28 |
+
|
29 |
+
The FGSM attack is one of the earliest and most popular adversarial attacks.
|
30 |
+
It is described by Goodfellow _et al_ in their work on [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572).
|
31 |
+
The attack is simple yet powerful, using the gradients that neural networks use to learn.
|
32 |
+
Instead of adjusting the weights based on the backpropagated gradients to minimize loss, the attack adjusts the input data to maximize the loss using the gradient of the loss with respect to the input data.
|
33 |
+
""")
|
34 |
+
|
35 |
+
instruction_text = """Instruction to input:
|
36 |
+
1. Choosing image: Users can choose a specific image by entering **Image ID** and hit the _Choose the defined image_ button or can generate an image randomly by hitting the _Generate a random image_ button.
|
37 |
+
2. Choosing epsilon: **Epsilon** is the amount of perturbation on the original image under attack. The higher the epsilon is, the more pertubed the image is, the more confusion made to the model.
|
38 |
+
Users can choose a specific epsilon by engtering **Epsilon** and hit the _Choose the defined epsilon_ button.
|
39 |
+
Users can also let the algorithm find the smallest epsilon automatically by hitting the _Find the smallest epsilon automatically_ button.
|
40 |
+
The underlying algorithm will iterate through a set of epsilon in ascending order until reaching the **maximum value of epsilon**.
|
41 |
+
After each iteration, the epsilon will increase by an amount equal to **step** variable.
|
42 |
+
Users can change the default values of the two variable value optionally.
|
43 |
+
"""
|
44 |
+
st.write("To use the functionality below, users need to input the **image** and the **epsilon**.")
|
45 |
+
with st.expander("See more instruction", expanded=False):
|
46 |
+
st.write(instruction_text)
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
51 |
+
image_id = None
|
52 |
+
|
53 |
+
if 'image_id' not in st.session_state:
|
54 |
+
st.session_state.image_id = 0
|
55 |
+
|
56 |
+
# def on_change_random_input():
|
57 |
+
# st.session_state.image_id = st.session_state.image_id
|
58 |
+
|
59 |
+
# ----------------------------- INPUT ----------------------------------
|
60 |
+
st.header('Input')
|
61 |
+
input_col_1, input_col_2, input_col_3 = st.columns(3)
|
62 |
+
# --------------------------- INPUT column 1 ---------------------------
|
63 |
+
with input_col_1:
|
64 |
+
with st.form('image_form'):
|
65 |
+
|
66 |
+
# image_id = st.number_input('Image ID: ', format='%d', step=1)
|
67 |
+
st.write('**Choose or generate a random image**')
|
68 |
+
chosen_image_id_input = st.empty()
|
69 |
+
image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
70 |
+
|
71 |
+
choose_image_button = st.form_submit_button('Choose the defined image')
|
72 |
+
random_id = st.form_submit_button('Generate a random image')
|
73 |
+
|
74 |
+
if random_id:
|
75 |
+
image_id = random.randint(0, 50000)
|
76 |
+
st.session_state.image_id = image_id
|
77 |
+
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
78 |
+
|
79 |
+
if choose_image_button:
|
80 |
+
image_id = int(image_id)
|
81 |
+
st.session_state.image_id = int(image_id)
|
82 |
+
# st.write(image_id, st.session_state.image_id)
|
83 |
+
|
84 |
+
# ---------------------------- SET UP OUTPUT ------------------------------
|
85 |
+
epsilon_container = st.empty()
|
86 |
+
st.header('Output')
|
87 |
+
st.subheader('Perform attack')
|
88 |
+
|
89 |
+
# perform attack container
|
90 |
+
header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
|
91 |
+
output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
|
92 |
+
|
93 |
+
# prediction error container
|
94 |
+
error_container = st.empty()
|
95 |
+
smoothgrad_header_container = st.empty()
|
96 |
+
|
97 |
+
# smoothgrad container
|
98 |
+
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
|
99 |
+
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
|
100 |
+
|
101 |
+
original_image_dict = load_image(st.session_state.image_id)
|
102 |
+
input_image = original_image_dict['image']
|
103 |
+
input_label = original_image_dict['label']
|
104 |
+
input_id = original_image_dict['id']
|
105 |
+
|
106 |
+
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
107 |
+
with output_col_1:
|
108 |
+
pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
|
109 |
+
# st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
|
110 |
+
st.image(input_image)
|
111 |
+
header_col_1.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.1f}% confidence')
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
if pred_class_id != (input_id-1):
|
116 |
+
with error_container.container():
|
117 |
+
st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {pred_prob*100:.1f}% confidence')
|
118 |
+
st.error('ConvNeXt misclassified the chosen image. Please choose or generate another image.',
|
119 |
+
icon = "🚫")
|
120 |
+
|
121 |
+
# ----------------------------- INPUT column 2 & 3 ----------------------------
|
122 |
+
with input_col_2:
|
123 |
+
with st.form('epsilon_form'):
|
124 |
+
st.write('**Set epsilon or find the smallest epsilon automatically**')
|
125 |
+
chosen_epsilon_input = st.empty()
|
126 |
+
epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
|
127 |
+
|
128 |
+
epsilon_button = st.form_submit_button('Choose the defined epsilon')
|
129 |
+
find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')
|
130 |
+
|
131 |
+
|
132 |
+
with input_col_3:
|
133 |
+
with st.form('iterate_epsilon_form'):
|
134 |
+
max_epsilon = st.number_input('Maximum value of epsilon (Optional setting)', value=0.500, format='%.3f')
|
135 |
+
step_epsilon = st.number_input('Step (Optional setting)', value=0.001, format='%.3f')
|
136 |
+
setting_button = st.form_submit_button('Set iterating mode')
|
137 |
+
|
138 |
+
|
139 |
+
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
140 |
+
if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
|
141 |
+
with output_col_3:
|
142 |
+
if epsilon_button:
|
143 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
|
144 |
+
else:
|
145 |
+
epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
|
146 |
+
with epsilon_container.container():
|
147 |
+
epsilon_container_text = 'Checking epsilon'
|
148 |
+
st.write(epsilon_container_text)
|
149 |
+
st.progress(0)
|
150 |
+
|
151 |
+
for i, e in enumerate(epsilons):
|
152 |
+
|
153 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
|
154 |
+
with epsilon_container.container():
|
155 |
+
epsilon_container_text = f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%'
|
156 |
+
st.write(epsilon_container_text)
|
157 |
+
st.progress(i/len(epsilons))
|
158 |
+
|
159 |
+
epsilon = e
|
160 |
+
|
161 |
+
if new_id != input_id - 1:
|
162 |
+
epsilon_container.empty()
|
163 |
+
st.balloons()
|
164 |
+
break
|
165 |
+
if i == len(epsilons)-1:
|
166 |
+
epsilon_container.error(f'FGSM failed to attack on this image at epsilon={e:.3f}. Set higher maximum value of epsilon or choose another image',
|
167 |
+
icon = "🚫")
|
168 |
+
|
169 |
+
perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
|
170 |
+
perturbed_amount = perturbed_image - input_image
|
171 |
+
header_col_3.write(f'Pertubed amount - epsilon={epsilon:.3f}')
|
172 |
+
st.image(ShowImage(perturbed_amount))
|
173 |
+
|
174 |
+
with output_col_2:
|
175 |
+
# st.write('plus sign')
|
176 |
+
st.image(LoadImage('frontend/images/plus-sign.png'))
|
177 |
+
|
178 |
+
with output_col_4:
|
179 |
+
# st.write('equal sign')
|
180 |
+
st.image(LoadImage('frontend/images/equal-sign.png'))
|
181 |
+
|
182 |
+
# ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
|
183 |
+
with output_col_5:
|
184 |
+
# st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
|
185 |
+
st.image(ShowImage(perturbed_image))
|
186 |
+
header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
|
187 |
+
|
188 |
+
# -------------------------- DISPLAY SMOOTHGRAD ---------------------------
|
189 |
+
smoothgrad_header_container.subheader('SmoothGrad visualization')
|
190 |
+
|
191 |
+
with smoothgrad_col_1:
|
192 |
+
smooth_head_1.write(f'SmoothGrad before attacked')
|
193 |
+
heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
|
194 |
+
st.image(heatmap_image)
|
195 |
+
st.image(masked_image)
|
196 |
+
with smoothgrad_col_3:
|
197 |
+
smooth_head_3.write('SmoothGrad after attacked')
|
198 |
+
heatmap_image_attacked, masked_image_attacked, attacked_mask= generate_images(st.session_state.image_id, epsilon=epsilon)
|
199 |
+
st.image(heatmap_image_attacked)
|
200 |
+
st.image(masked_image_attacked)
|
201 |
+
|
202 |
+
with smoothgrad_col_2:
|
203 |
+
st.image(LoadImage('frontend/images/minus-sign-5.png'))
|
204 |
+
|
205 |
+
with smoothgrad_col_5:
|
206 |
+
smooth_head_5.write('SmoothGrad difference')
|
207 |
+
difference_mask = abs(attacked_mask-mask)
|
208 |
+
st.image(ShowHeatMap(difference_mask))
|
209 |
+
masked_image = ShowMaskedImage(difference_mask, perturbed_image)
|
210 |
+
st.image(masked_image)
|
211 |
+
|
212 |
+
with smoothgrad_col_4:
|
213 |
+
st.image(LoadImage('frontend/images/equal-sign.png'))
|
214 |
+
|
215 |
+
|
pages/4_ImageNet1k.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
from backend.utils import load_dataset, use_container_width_percentage
|
5 |
+
|
6 |
+
st.set_page_config(layout='wide')
|
7 |
+
|
8 |
+
st.title('ImageNet-1k')
|
9 |
+
st.markdown('This page shows the summary of 50,000 images in the validation set of [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)')
|
10 |
+
|
11 |
+
# SCREEN_WIDTH, SCREEN_HEIGHT = 2560, 1664
|
12 |
+
|
13 |
+
with st.spinner("Loading dataset..."):
|
14 |
+
dataset_dict = {}
|
15 |
+
for data_index in range(5):
|
16 |
+
dataset_dict[data_index] = load_dataset(data_index)
|
17 |
+
|
18 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
19 |
+
|
20 |
+
class_labels = imagenet_df.ClassLabel.unique().tolist()
|
21 |
+
class_labels.sort()
|
22 |
+
selected_classes = st.multiselect('Class filter: ', options=['All'] + class_labels)
|
23 |
+
if not ('All' in selected_classes or len(selected_classes) == 0):
|
24 |
+
imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
|
25 |
+
# st.write(class_labels)
|
26 |
+
|
27 |
+
col1, col2 = st.columns([2, 1])
|
28 |
+
with col1:
|
29 |
+
st.dataframe(imagenet_df)
|
30 |
+
use_container_width_percentage(100)
|
31 |
+
|
32 |
+
with col2:
|
33 |
+
st.text_area('Type anything here to copy later :)')
|
34 |
+
image = None
|
35 |
+
with st.form("display image"):
|
36 |
+
img_index = st.text_input('Image ID to display')
|
37 |
+
|
38 |
+
submitted = st.form_submit_button('Display this image')
|
39 |
+
error_container = st.empty()
|
40 |
+
|
41 |
+
if submitted:
|
42 |
+
try:
|
43 |
+
img_index = int(img_index)
|
44 |
+
if img_index > 50000-1 or img_index < 0:
|
45 |
+
error_container.error('The Image ID must be in range from 0 to 49999', icon="🚫")
|
46 |
+
else:
|
47 |
+
image = dataset_dict[img_index//10_000][img_index%10_000]['image']
|
48 |
+
class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
|
49 |
+
class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
|
50 |
+
except ValueError:
|
51 |
+
error_container.error('Please enter an integer number for Image ID', icon = "🚫")
|
52 |
+
|
53 |
+
if image != None:
|
54 |
+
st.image(image)
|
55 |
+
st.write('**Class label:** ', class_label)
|
56 |
+
st.write('\n**Class id:** ', str(class_id))
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
captum==0.5.0
|
2 |
+
graphviz==0.20.1
|
3 |
+
Markdown==3.4.1
|
4 |
+
matplotlib==3.6.2
|
5 |
+
numpy==1.22.3
|
6 |
+
opencv_python_headless==4.6.0.66
|
7 |
+
pandas==1.5.2
|
8 |
+
Pillow==9.4.0
|
9 |
+
plotly==5.11.0
|
10 |
+
scipy==1.10.1
|
11 |
+
setuptools==65.5.0
|
12 |
+
streamlit==1.19.0
|
13 |
+
torch==1.10.1
|
14 |
+
torchvision==0.11.2
|
15 |
+
tqdm==4.64.1
|
16 |
+
transformers==4.25.1
|
17 |
+
git+https://github.com/vlue-c/Visual-Explanation-Methods-PyTorch.git
|