Ludovica Schaerf taquynhnga commited on
Commit
fcc16aa
·
0 Parent(s):

Duplicate from taquynhnga/CNNs-interpretation-visualization

Browse files

Co-authored-by: Hanna Ta Quynh Nga <[email protected]>

Files changed (41) hide show
  1. .gitattributes +4 -0
  2. .github/workflows/sync_to_huggingface_hub.yml +19 -0
  3. .gitignore +183 -0
  4. .vscode/settings.json +5 -0
  5. Home.py +43 -0
  6. README.md +19 -0
  7. Visual-Explanation-Methods-PyTorch +1 -0
  8. backend/adversarial_attack.py +100 -0
  9. backend/load_file.py +41 -0
  10. backend/maximally_activating_patches.py +45 -0
  11. backend/smooth_grad.py +235 -0
  12. backend/utils.py +379 -0
  13. data/ImageNet_metadata.csv +3 -0
  14. data/activation/convnext_activation.json +3 -0
  15. data/activation/mobilenet_activation.json +3 -0
  16. data/activation/resnet_activation.json +3 -0
  17. data/dot_architectures/convnext_architecture.dot +3 -0
  18. data/layer_infos/convnext_layer_infos.json +3 -0
  19. data/layer_infos/mobilenet_layer_infos.json +3 -0
  20. data/layer_infos/resnet_layer_infos.json +3 -0
  21. data/preprocessed_image_net/val_data_0.pkl +3 -0
  22. data/preprocessed_image_net/val_data_1.pkl +3 -0
  23. data/preprocessed_image_net/val_data_2.pkl +3 -0
  24. data/preprocessed_image_net/val_data_3.pkl +3 -0
  25. data/preprocessed_image_net/val_data_4.pkl +3 -0
  26. frontend/__init__.py +6 -0
  27. frontend/footer.py +32 -0
  28. frontend/images/equal-sign.png +0 -0
  29. frontend/images/minus-sign-2.png +0 -0
  30. frontend/images/minus-sign-3.png +0 -0
  31. frontend/images/minus-sign-4.png +0 -0
  32. frontend/images/minus-sign-5.png +0 -0
  33. frontend/images/minus-sign.png +0 -0
  34. frontend/images/plus-sign-2.png +0 -0
  35. frontend/images/plus-sign.png +0 -0
  36. frontend/index.html +204 -0
  37. pages/1_Maximally_activating_patches.py +166 -0
  38. pages/2_SmoothGrad.py +124 -0
  39. pages/3_Adversarial_attack.py +215 -0
  40. pages/4_ImageNet1k.py +56 -0
  41. requirements.txt +17 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.pkl filter=lfs diff=lfs merge=lfs -text
2
+ .csv filter=lfs diff=lfs merge=lfs -text
3
+ data/** filter=lfs diff=lfs merge=lfs -text
4
+ Visual-Explanation-Methods-PyTorch/** filter=lfs diff=lfs merge=lfs -text
.github/workflows/sync_to_huggingface_hub.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ - name: Push to hub
17
+ env:
18
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
19
+ run: git push --force https://taquynhnga:[email protected]/spaces/taquynhnga/CNNs-interpretation-visualization main
.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VSCode
2
+ .vscode/*
3
+ !.vscode/settings.json
4
+ !.vscode/tasks.json
5
+ !.vscode/launch.json
6
+ !.vscode/extensions.json
7
+ *.code-workspace
8
+ # Local History for Visual Studio Code
9
+ .history/
10
+
11
+ # Common credential files
12
+ **/credentials.json
13
+ **/client_secrets.json
14
+ **/client_secret.json
15
+ *creds*
16
+ *.dat
17
+ *password*
18
+ *.httr-oauth*
19
+
20
+ # Private Node Modules
21
+ node_modules/
22
+ creds.js
23
+
24
+ # Private Files
25
+ # *.json
26
+ # *.csv
27
+ # *.csv.gz
28
+ # *.tsv
29
+ # *.tsv.gz
30
+ # *.xlsx
31
+ git-large-file
32
+ deta_drive.py
33
+ secret_keys.py
34
+
35
+ # Large files
36
+ # data/preprocessed_image_net/
37
+ # data/activation/*.pkl
38
+ # data/activation/*.json
39
+ # data/layer_infos/*.pkl
40
+ # data/layer_infos/*.json
41
+
42
+ # Mac/OSX
43
+ .DS_Store
44
+
45
+
46
+ # Byte-compiled / optimized / DLL files
47
+ __pycache__/
48
+ *.py[cod]
49
+ *$py.class
50
+
51
+ # C extensions
52
+ *.so
53
+
54
+ # Distribution / packaging
55
+ .Python
56
+ build/
57
+ develop-eggs/
58
+ dist/
59
+ downloads/
60
+ eggs/
61
+ .eggs/
62
+ lib/
63
+ lib64/
64
+ parts/
65
+ sdist/
66
+ var/
67
+ wheels/
68
+ share/python-wheels/
69
+ *.egg-info/
70
+ .installed.cfg
71
+ *.egg
72
+ MANIFEST
73
+
74
+ # PyInstaller
75
+ # Usually these files are written by a python script from a template
76
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
77
+ *.manifest
78
+ *.spec
79
+
80
+ # Installer logs
81
+ pip-log.txt
82
+ pip-delete-this-directory.txt
83
+
84
+ # Unit test / coverage reports
85
+ htmlcov/
86
+ .tox/
87
+ .nox/
88
+ .coverage
89
+ .coverage.*
90
+ .cache
91
+ nosetests.xml
92
+ coverage.xml
93
+ *.cover
94
+ *.py,cover
95
+ .hypothesis/
96
+ .pytest_cache/
97
+ cover/
98
+
99
+ # Translations
100
+ *.mo
101
+ *.pot
102
+
103
+ # Django stuff:
104
+ *.log
105
+ local_settings.py
106
+ db.sqlite3
107
+ db.sqlite3-journal
108
+
109
+ # Flask stuff:
110
+ instance/
111
+ .webassets-cache
112
+
113
+ # Scrapy stuff:
114
+ .scrapy
115
+
116
+ # Sphinx documentation
117
+ docs/_build/
118
+
119
+ # PyBuilder
120
+ .pybuilder/
121
+ target/
122
+
123
+ # Jupyter Notebook
124
+ .ipynb_checkpoints
125
+
126
+ # IPython
127
+ profile_default/
128
+ ipython_config.py
129
+
130
+ # pyenv
131
+ # For a library or package, you might want to ignore these files since the code is
132
+ # intended to run in multiple environments; otherwise, check them in:
133
+ # .python-version
134
+
135
+ # pipenv
136
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
137
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
138
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
139
+ # install all needed dependencies.
140
+ #Pipfile.lock
141
+
142
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
143
+ __pypackages__/
144
+
145
+ # Celery stuff
146
+ celerybeat-schedule
147
+ celerybeat.pid
148
+
149
+ # SageMath parsed files
150
+ *.sage.py
151
+
152
+ # Environments
153
+ .env
154
+ .venv
155
+ env/
156
+ venv/
157
+ ENV/
158
+ env.bak/
159
+ venv.bak/
160
+
161
+ # Spyder project settings
162
+ .spyderproject
163
+ .spyproject
164
+
165
+ # Rope project settings
166
+ .ropeproject
167
+
168
+ # mkdocs documentation
169
+ /site
170
+
171
+ # mypy
172
+ .mypy_cache/
173
+ .dmypy.json
174
+ dmypy.json
175
+
176
+ # Pyre type checker
177
+ .pyre/
178
+
179
+ # pytype static type analyzer
180
+ .pytype/
181
+
182
+ # Cython debug symbols
183
+ cython_debug/
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "python.analysis.extraPaths": [
3
+ "./Visual-Explanation-Methods-PyTorch"
4
+ ]
5
+ }
Home.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from frontend.footer import add_footer
3
+
4
+ st.set_page_config(layout='wide')
5
+ # st.set_page_config(layout='centered')
6
+
7
+ st.title('About')
8
+
9
+ # INTRO
10
+ intro_text = """Convolutional neural networks (ConvNets) have evolved at a rapid speed from the 2010s.
11
+ Some of the representative ConvNets models are VGGNet, Inceptions, ResNe(X)t, DenseNet, MobileNet, EfficientNet and RegNet, which focus on various factors of accuracy, efficiency, and scalability.
12
+ In the year 2020, Vision Transformers (ViT) was introduced as a Transformer model solving the computer vision problems.
13
+ Larger model and dataset sizes allow ViT to perform significantly better than ResNet, however, ViT still encountered challenges in generic computer vision tasks such as object detection and semantic segmentation.
14
+ Swin Transformer’ s success made Transformers be adopted as a generic vision backbone and showed outstanding performance in a wide range of computer vision tasks.
15
+ Nevertheless, rather than the intrinsic inductive biases of convolutions, the success of this approach is still primarily attributed to Transformers’ inherent superiority.
16
+
17
+ In 2022, Zhuang Liu et. al. proposed a pure convolutional model dubbed ConvNeXt, discovered from the modernization of a standard ResNet towards the design of Vision Transformers and claimed to outperform them.
18
+
19
+ The project aims to interpret the ConvNeXt model by several visualization techniques.
20
+ After that, a web interface would be built to demonstrate the interpretations, helping us look inside the deep ConvNeXt model and answer the questions:
21
+ > “What patterns maximally activated this filter (channel) in this layer?”\n
22
+ > “Which features are responsible for the current prediction?”.
23
+
24
+ Due to the limitation in time and resources, the project only used the tiny-sized ConvNeXt model, which was trained on ImageNet-1k at resolution 224x224 and used 50,000 images in validation set of ImageNet-1k for demo purpose.
25
+
26
+ In this web app, two visualization techniques were implemented and demonstrated, they are **Maximally activating patches** and **SmoothGrad**.
27
+ Besides, this web app also helps investigate the effect of **adversarial attacks** on ConvNeXt interpretations.
28
+ Last but not least, there is a last webpage that stores 50,000 images in the **ImageNet-1k** validation set, facilitating the two web pages above in searching and referencing.
29
+ """
30
+ st.write(intro_text)
31
+
32
+ # 4 PAGES
33
+ st.subheader('Features')
34
+ sections_text = """Overall, there are 4 features in this web app:
35
+ 1) Maximally activating patches: The visualization method in this page answers the question “what patterns maximally activated this filter (channel)?”.
36
+ 2) SmoothGrad: This visualization method in this page answers the question “which features are responsible for the current prediction?”.
37
+ 3) Adversarial attack: How adversarial attacks affect ConvNeXt interpretation?
38
+ 4) ImageNet1k: The storage of 50,000 images in validation set.
39
+ """
40
+ st.write(sections_text)
41
+
42
+
43
+ add_footer('Developed with ❤ by ', 'Hanna Ta Quynh Nga', 'https://www.linkedin.com/in/ta-quynh-nga-hanna/')
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CNNs Interpretation Visualization
3
+ emoji: 💡
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.17.0
8
+ app_file: Home.py
9
+ pinned: false
10
+ duplicated_from: taquynhnga/CNNs-interpretation-visualization
11
+ ---
12
+
13
+ # Visualizing Interpretations of CNN models: ConvNeXt, ResNet and MobileNet
14
+
15
+ To be change name: CNNs-interpretation-visualization
16
+
17
+ This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
18
+
19
+ This repo lacks one more folder `data/preprocessed_image_net` which contains 50,000 preprocessed imagenet validation images saved in 5 pickle files.
Visual-Explanation-Methods-PyTorch ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 5cb88902729af1d9d85259879b47cb238b841881
backend/adversarial_attack.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image
3
+ import numpy as np
4
+ from matplotlib import pylab as P
5
+ import cv2
6
+
7
+ import torch
8
+ from torch.utils.data import TensorDataset
9
+ from torchvision import transforms
10
+ import torch.nn.functional as F
11
+
12
+ from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13
+
14
+ from torchvex.base import ExplanationMethod
15
+ from torchvex.utils.normalization import clamp_quantile
16
+
17
+ from backend.utils import load_image, load_model
18
+ from backend.smooth_grad import generate_smoothgrad_mask
19
+
20
+ import streamlit as st
21
+
22
+ IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1])
23
+ IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1])
24
+
25
+ def deprocess_image(image_inputs):
26
+ return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255
27
+
28
+
29
+ def feed_forward(input_image):
30
+ model, feature_extractor = load_model('ConvNeXt')
31
+ inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
32
+ logits = model(inputs).logits
33
+ prediction_prob = F.softmax(logits, dim=-1).max() # prediction probability
34
+ # prediction class id, start from 1 to 1000 so it needs to +1 in the end
35
+ prediction_class = logits.argmax(-1).item()
36
+ prediction_label = model.config.id2label[prediction_class] # prediction class label
37
+ return prediction_prob, prediction_class, prediction_label
38
+
39
+ # FGSM attack code
40
+ def fgsm_attack(image, epsilon, data_grad):
41
+ # Collect the element-wise sign of the data gradient and normalize it
42
+ sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
43
+ perturbed_image = image + epsilon*sign_data_grad
44
+ return perturbed_image
45
+
46
+ # perform attack on the model
47
+ def perform_attack(input_image, target, epsilon):
48
+ model, feature_extractor = load_model("ConvNeXt")
49
+ # preprocess input image
50
+ inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
51
+ inputs.requires_grad = True
52
+
53
+ # predict
54
+ logits = model(inputs).logits
55
+ prediction_prob = F.softmax(logits, dim=-1).max()
56
+ prediction_class = logits.argmax(-1).item()
57
+ prediction_label = model.config.id2label[prediction_class]
58
+
59
+ # Calculate the loss
60
+ loss = F.nll_loss(logits, torch.tensor([target]))
61
+
62
+ # Zero all existing gradients
63
+ model.zero_grad()
64
+
65
+ # Calculate gradients of model in backward pass
66
+ loss.backward()
67
+
68
+ # Collect datagrad
69
+ data_grad = inputs.grad.data
70
+
71
+ # Call FGSM Attack
72
+ perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
73
+
74
+ # Re-classify the perturbed image
75
+ new_prediction = model(perturbed_data).logits
76
+ new_pred_prob = F.softmax(new_prediction, dim=-1).max()
77
+ new_pred_class = new_prediction.argmax(-1).item()
78
+ new_pred_label = model.config.id2label[new_pred_class]
79
+
80
+ return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label
81
+
82
+
83
+ def find_smallest_epsilon(input_image, target):
84
+ epsilons = [i*0.001 for i in range(1000)]
85
+
86
+ for epsilon in epsilons:
87
+ perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon)
88
+ if new_id != target:
89
+ return perturbed_data, new_prob, new_id, new_label, epsilon
90
+ return None
91
+
92
+ # @st.cache_data
93
+ @st.cache(allow_output_mutation=True)
94
+ def generate_images(image_id, epsilon=0):
95
+ model, feature_extractor = load_model("ConvNeXt")
96
+ original_image_dict = load_image(image_id)
97
+ image = original_image_dict['image']
98
+ return generate_smoothgrad_mask(
99
+ image, 'ConvNeXt',
100
+ model, feature_extractor, num_samples=10, return_mask=True)
backend/load_file.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pickle
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+
6
+ def load_pickle(filename):
7
+ with open(filename, 'rb') as file:
8
+ data = pickle.load(file)
9
+ return data
10
+
11
+ def save_pickle_to_json(filename):
12
+ ordered_dict = load_pickle(filename)
13
+ json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
14
+ with open(filename.replace('.pkl', '.json'), 'w') as f:
15
+ f.write(json_obj)
16
+
17
+ def load_json(filename):
18
+ with open(filename, 'r') as read_file:
19
+ loaded_dict = json.loads(read_file.read())
20
+ loaded_dict = OrderedDict(loaded_dict)
21
+ for k, v in loaded_dict.items():
22
+ if type(v) == list:
23
+ loaded_dict[k] = np.asarray(v)
24
+ else:
25
+ for k_, v_ in v.items():
26
+ v[k_] = np.asarray(v_)
27
+ return loaded_dict
28
+
29
+ class NumpyEncoder(json.JSONEncoder):
30
+ def default(self, obj):
31
+ if isinstance(obj, np.ndarray):
32
+ return obj.tolist()
33
+ return json.JSONEncoder.default(self, obj)
34
+
35
+ # save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
36
+ # save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
37
+ # save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
38
+
39
+ # file = load_json('data/layer_infos/convnext_layer_infos.json')
40
+ # print(type(file))
41
+ # print(type(file['embeddings.patch_embeddings']))
backend/maximally_activating_patches.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import streamlit as st
3
+
4
+ from backend.load_file import load_json
5
+
6
+
7
+ @st.cache(allow_output_mutation=True)
8
+ # st.cache_data
9
+ def load_activation(filename):
10
+ activation = load_json(filename)
11
+ return activation
12
+
13
+ @st.cache(allow_output_mutation=True)
14
+ # @st.cache_data
15
+ def load_dataset(data_index):
16
+ with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
17
+ dataset = pickle.load(file)
18
+ return dataset
19
+
20
+ def load_layer_infos(filename):
21
+ layer_infos = load_json(filename)
22
+ return layer_infos
23
+
24
+ def get_receptive_field_coordinates(layer_infos, layer_name, idx_x, idx_y):
25
+ """
26
+ layer_name: as in layer_infos keys (eg: 'encoder.stages[0].layers[0]')
27
+ idx_x: integer coordinate of width axis in feature maps. must < n
28
+ idx_y: integer coordinate of height axis in feature maps. must < n
29
+ """
30
+ layer_name = layer_name.replace('.dwconv', '').replace('.layernorm', '')
31
+ layer_name = layer_name.replace('.pwconv1', '').replace('.pwconv2', '').replace('.drop_path', '')
32
+ n = layer_infos[layer_name]['n']
33
+ j = layer_infos[layer_name]['j']
34
+ r = layer_infos[layer_name]['r']
35
+ start = layer_infos[layer_name]['start']
36
+ assert idx_x < n, f'n={n}'
37
+ assert idx_y < n, f'n={n}'
38
+
39
+ # image tensor (N, H, W, C) or (N, C, H, W) => image_patch=image[y1:y2, x1:x2]
40
+ center = (start + idx_x*j, start + idx_y*j)
41
+ x1, x2 = (max(center[0]-r/2, 0), max(center[0]+r/2, 0))
42
+ y1, y2 = (max(center[1]-r/2, 0), max(center[1]+r/2, 0))
43
+ x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
44
+
45
+ return x1, x2, y1, y2
backend/smooth_grad.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image
3
+ import numpy as np
4
+ from matplotlib import pylab as P
5
+ import cv2
6
+
7
+ import torch
8
+ from torch.utils.data import TensorDataset
9
+ from torchvision import transforms
10
+
11
+ # dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
12
+ # sys.path.append(dirpath_to_modules)
13
+
14
+ from torchvex.base import ExplanationMethod
15
+ from torchvex.utils.normalization import clamp_quantile
16
+
17
+ def ShowImage(im, title='', ax=None):
18
+ image = np.array(im)
19
+ return image
20
+
21
+ def ShowGrayscaleImage(im, title='', ax=None):
22
+ if ax is None:
23
+ P.figure()
24
+ P.axis('off')
25
+ P.imshow(im, cmap=P.cm.gray, vmin=0, vmax=1)
26
+ P.title(title)
27
+ return P
28
+
29
+ def ShowHeatMap(im, title='', ax=None):
30
+ im = im - im.min()
31
+ im = im / im.max()
32
+ im = im.clip(0,1)
33
+ im = np.uint8(im * 255)
34
+
35
+ im = cv2.resize(im, (224,224))
36
+ image = cv2.resize(im, (224, 224))
37
+
38
+ # Apply JET colormap
39
+ color_heatmap = cv2.applyColorMap(image, cv2.COLORMAP_HOT)
40
+ # P.imshow(im, cmap='inferno')
41
+ # P.title(title)
42
+ return color_heatmap
43
+
44
+ def ShowMaskedImage(saliency_map, image, title='', ax=None):
45
+ """
46
+ Save saliency map on image.
47
+
48
+ Args:
49
+ image: Tensor of size (H,W,3)
50
+ saliency_map: Tensor of size (H,W,1)
51
+ """
52
+
53
+ # if ax is None:
54
+ # P.figure()
55
+ # P.axis('off')
56
+
57
+ saliency_map = saliency_map - saliency_map.min()
58
+ saliency_map = saliency_map / saliency_map.max()
59
+ saliency_map = saliency_map.clip(0,1)
60
+ saliency_map = np.uint8(saliency_map * 255)
61
+
62
+ saliency_map = cv2.resize(saliency_map, (224,224))
63
+ image = cv2.resize(image, (224, 224))
64
+
65
+ # Apply JET colormap
66
+ color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT)
67
+
68
+ # Blend image with heatmap
69
+ img_with_heatmap = cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0)
70
+
71
+ # P.imshow(img_with_heatmap)
72
+ # P.title(title)
73
+ return img_with_heatmap
74
+
75
+ def LoadImage(file_path):
76
+ im = PIL.Image.open(file_path)
77
+ im = im.resize((224, 224))
78
+ im = np.asarray(im)
79
+ return im
80
+
81
+
82
+ def visualize_image_grayscale(image_3d, percentile=99):
83
+ r"""Returns a 3D tensor as a grayscale 2D tensor.
84
+ This method sums a 3D tensor across the absolute value of axis=2, and then
85
+ clips values at a given percentile.
86
+ """
87
+ image_2d = np.sum(np.abs(image_3d), axis=2)
88
+
89
+ vmax = np.percentile(image_2d, percentile)
90
+ vmin = np.min(image_2d)
91
+
92
+ return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)
93
+
94
+ def visualize_image_diverging(image_3d, percentile=99):
95
+ r"""Returns a 3D tensor as a 2D tensor with positive and negative values.
96
+ """
97
+ image_2d = np.sum(image_3d, axis=2)
98
+
99
+ span = abs(np.percentile(image_2d, percentile))
100
+ vmin = -span
101
+ vmax = span
102
+
103
+ return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)
104
+
105
+
106
+ class SimpleGradient(ExplanationMethod):
107
+ def __init__(self, model, create_graph=False,
108
+ preprocess=None, postprocess=None):
109
+ super().__init__(model, preprocess, postprocess)
110
+ self.create_graph = create_graph
111
+
112
+ def predict(self, x):
113
+ return self.model(x)
114
+
115
+ @torch.enable_grad()
116
+ def process(self, inputs, target):
117
+ self.model.zero_grad()
118
+ inputs.requires_grad_(True)
119
+
120
+ out = self.model(inputs)
121
+ out = out if type(out) == torch.Tensor else out.logits
122
+
123
+ num_classes = out.size(-1)
124
+ onehot = torch.zeros(inputs.size(0), num_classes, *target.shape[1:])
125
+ onehot = onehot.to(dtype=inputs.dtype, device=inputs.device)
126
+ onehot.scatter_(1, target.unsqueeze(1), 1)
127
+
128
+ grad, = torch.autograd.grad(
129
+ (out*onehot).sum(), inputs, create_graph=self.create_graph
130
+ )
131
+
132
+ return grad
133
+
134
+
135
+ class SmoothGradient(ExplanationMethod):
136
+ def __init__(self, model, stdev_spread=0.15, num_samples=25,
137
+ magnitude=True, batch_size=-1,
138
+ create_graph=False, preprocess=None, postprocess=None):
139
+ super().__init__(model, preprocess, postprocess)
140
+ self.stdev_spread = stdev_spread
141
+ self.nsample = num_samples
142
+ self.create_graph = create_graph
143
+ self.magnitude = magnitude
144
+ self.batch_size = batch_size
145
+ if self.batch_size == -1:
146
+ self.batch_size = self.nsample
147
+
148
+ self._simgrad = SimpleGradient(model, create_graph)
149
+
150
+ def process(self, inputs, target):
151
+ self.model.zero_grad()
152
+
153
+ maxima = inputs.flatten(1).max(-1)[0]
154
+ minima = inputs.flatten(1).min(-1)[0]
155
+
156
+ stdev = self.stdev_spread * (maxima - minima).cpu()
157
+ stdev = stdev.view(inputs.size(0), 1, 1, 1).expand_as(inputs)
158
+ stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1]*4)
159
+ noise = torch.normal(0, stdev)
160
+
161
+ target_expanded = target.unsqueeze(0).cpu()
162
+ target_expanded = target_expanded.expand(noise.size(0), -1)
163
+
164
+ noiseloader = torch.utils.data.DataLoader(
165
+ TensorDataset(noise, target_expanded), batch_size=self.batch_size
166
+ )
167
+
168
+ total_gradients = torch.zeros_like(inputs)
169
+ for noise, t_exp in noiseloader:
170
+ inputs_w_noise = inputs.unsqueeze(0) + noise.to(inputs.device)
171
+ inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:])
172
+ gradients = self._simgrad(inputs_w_noise, t_exp.view(-1))
173
+ gradients = gradients.view(self.batch_size, *inputs.shape)
174
+ if self.magnitude:
175
+ gradients = gradients.pow(2)
176
+ total_gradients = total_gradients + gradients.sum(0)
177
+
178
+ smoothed_gradient = total_gradients / self.nsample
179
+ return smoothed_gradient
180
+
181
+
182
+ def feed_forward(model_name, image, model=None, feature_extractor=None):
183
+ if model_name in ['ConvNeXt', 'ResNet']:
184
+ inputs = feature_extractor(image, return_tensors="pt")
185
+ logits = model(**inputs).logits
186
+ prediction_class = logits.argmax(-1).item()
187
+ else:
188
+ transform_images = transforms.Compose([
189
+ transforms.Resize(224),
190
+ transforms.ToTensor(),
191
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
192
+ input_tensor = transform_images(image)
193
+ inputs = input_tensor.unsqueeze(0)
194
+
195
+ output = model(inputs)
196
+ prediction_class = output.argmax(-1).item()
197
+ #prediction_label = model.config.id2label[prediction_class]
198
+ return inputs, prediction_class
199
+
200
+ def clip_gradient(gradient):
201
+ gradient = gradient.abs().sum(1, keepdim=True)
202
+ return clamp_quantile(gradient, q=0.99)
203
+
204
+ def fig2img(fig):
205
+ """Convert a Matplotlib figure to a PIL Image and return it"""
206
+ import io
207
+ buf = io.BytesIO()
208
+ fig.savefig(buf)
209
+ buf.seek(0)
210
+ img = Image.open(buf)
211
+ return img
212
+
213
+ def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25, return_mask=False):
214
+ inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
215
+
216
+ smoothgrad_gen = SmoothGradient(
217
+ model, num_samples=num_samples, stdev_spread=0.1,
218
+ magnitude=False, postprocess=clip_gradient)
219
+
220
+ if type(inputs) != torch.Tensor:
221
+ inputs = inputs['pixel_values']
222
+
223
+ smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
224
+ smoothgrad_mask = smoothgrad_mask[0].numpy()
225
+ smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))
226
+
227
+ image = np.asarray(image)
228
+ # ori_image = ShowImage(image)
229
+ heat_map_image = ShowHeatMap(smoothgrad_mask)
230
+ masked_image = ShowMaskedImage(smoothgrad_mask, image)
231
+
232
+ if return_mask:
233
+ return heat_map_image, masked_image, smoothgrad_mask
234
+ else:
235
+ return heat_map_image, masked_image
backend/utils.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+
4
+ import io
5
+ from typing import List, Optional
6
+
7
+ import markdown
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ import plotly.graph_objects as go
12
+ import streamlit as st
13
+ from plotly import express as px
14
+ from plotly.subplots import make_subplots
15
+ from tqdm import trange
16
+
17
+ import torch
18
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
19
+
20
+ @st.cache(allow_output_mutation=True)
21
+ # @st.cache_resource
22
+ def load_dataset(data_index):
23
+ with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
24
+ dataset = pickle.load(file)
25
+ return dataset
26
+
27
+ @st.cache(allow_output_mutation=True)
28
+ # @st.cache_resource
29
+ def load_dataset_dict():
30
+ dataset_dict = {}
31
+ progress_empty = st.empty()
32
+ text_empty = st.empty()
33
+ text_empty.write("Loading datasets...")
34
+ progress_bar = progress_empty.progress(0.0)
35
+ for data_index in trange(5):
36
+ dataset_dict[data_index] = load_dataset(data_index)
37
+ progress_bar.progress((data_index+1)/5)
38
+ progress_empty.empty()
39
+ text_empty.empty()
40
+ return dataset_dict
41
+
42
+
43
+ # @st.cache_data
44
+ @st.cache(allow_output_mutation=True)
45
+ def load_image(image_id):
46
+ dataset = load_dataset(image_id//10000)
47
+ image = dataset[image_id%10000]
48
+ return image
49
+
50
+ # @st.cache_data
51
+ @st.cache(allow_output_mutation=True)
52
+ def load_images(image_ids):
53
+ images = []
54
+ for image_id in image_ids:
55
+ image = load_image(image_id)
56
+ images.append(image)
57
+ return images
58
+
59
+
60
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
61
+ # @st.cache_resource
62
+ def load_model(model_name):
63
+ with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
64
+ if model_name == 'ResNet':
65
+ model_file_path = 'microsoft/resnet-50'
66
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
67
+ model = AutoModelForImageClassification.from_pretrained(model_file_path)
68
+ model.eval()
69
+ elif model_name == 'ConvNeXt':
70
+ model_file_path = 'facebook/convnext-tiny-224'
71
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
72
+ model = AutoModelForImageClassification.from_pretrained(model_file_path)
73
+ model.eval()
74
+ else:
75
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
76
+ model.eval()
77
+ feature_extractor = None
78
+ return model, feature_extractor
79
+
80
+
81
+ def make_grid(cols=None,rows=None):
82
+ grid = [0]*rows
83
+ for i in range(rows):
84
+ with st.container():
85
+ grid[i] = st.columns(cols)
86
+ return grid
87
+
88
+
89
+ def use_container_width_percentage(percentage_width:int = 75):
90
+ max_width_str = f"max-width: {percentage_width}%;"
91
+ st.markdown(f"""
92
+ <style>
93
+ .reportview-container .main .block-container{{{max_width_str}}}
94
+ </style>
95
+ """,
96
+ unsafe_allow_html=True,
97
+ )
98
+
99
+ matplotlib.use("Agg")
100
+ COLOR = "#31333f"
101
+ BACKGROUND_COLOR = "#ffffff"
102
+
103
+
104
+ def grid_demo():
105
+ """Main function. Run this to run the app"""
106
+ st.sidebar.title("Layout and Style Experiments")
107
+ st.sidebar.header("Settings")
108
+ st.markdown(
109
+ """
110
+ # Layout and Style Experiments
111
+
112
+ The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
113
+ the [CSS Grid](https://gridbyexample.com/examples)?
114
+
115
+ Can we do it with a nice api?
116
+ Can have a dark theme?
117
+ """
118
+ )
119
+
120
+ select_block_container_style()
121
+ add_resources_section()
122
+
123
+ # My preliminary idea of an API for generating a grid
124
+ with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
125
+ grid.cell(
126
+ class_="a",
127
+ grid_column_start=2,
128
+ grid_column_end=3,
129
+ grid_row_start=1,
130
+ grid_row_end=2,
131
+ ).markdown("# This is A Markdown Cell")
132
+ grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe")
133
+ grid.cell("c", 3, 4, 2, 3).plotly_chart(get_plotly_fig())
134
+ grid.cell("d", 1, 2, 1, 3).dataframe(get_dataframe())
135
+ grid.cell("e", 3, 4, 1, 2).markdown(
136
+ "Try changing the **block container style** in the sidebar!"
137
+ )
138
+ grid.cell("f", 1, 3, 3, 4).text(
139
+ "The cell to the right is a matplotlib svg image"
140
+ )
141
+ grid.cell("g", 3, 4, 3, 4).pyplot(get_matplotlib_plt())
142
+
143
+
144
+ def add_resources_section():
145
+ """Adds a resources section to the sidebar"""
146
+ st.sidebar.header("Add_resources_section")
147
+ st.sidebar.markdown(
148
+ """
149
+ - [gridbyexample.com] (https://gridbyexample.com/examples/)
150
+ """
151
+ )
152
+
153
+
154
+ class Cell:
155
+ """A Cell can hold text, markdown, plots etc."""
156
+
157
+ def __init__(
158
+ self,
159
+ class_: str = None,
160
+ grid_column_start: Optional[int] = None,
161
+ grid_column_end: Optional[int] = None,
162
+ grid_row_start: Optional[int] = None,
163
+ grid_row_end: Optional[int] = None,
164
+ ):
165
+ self.class_ = class_
166
+ self.grid_column_start = grid_column_start
167
+ self.grid_column_end = grid_column_end
168
+ self.grid_row_start = grid_row_start
169
+ self.grid_row_end = grid_row_end
170
+ self.inner_html = ""
171
+
172
+ def _to_style(self) -> str:
173
+ return f"""
174
+ .{self.class_} {{
175
+ grid-column-start: {self.grid_column_start};
176
+ grid-column-end: {self.grid_column_end};
177
+ grid-row-start: {self.grid_row_start};
178
+ grid-row-end: {self.grid_row_end};
179
+ }}
180
+ """
181
+
182
+ def text(self, text: str = ""):
183
+ self.inner_html = text
184
+
185
+ def markdown(self, text):
186
+ self.inner_html = markdown.markdown(text)
187
+
188
+ def dataframe(self, dataframe: pd.DataFrame):
189
+ self.inner_html = dataframe.to_html()
190
+
191
+ def plotly_chart(self, fig):
192
+ self.inner_html = f"""
193
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
194
+ <body>
195
+ <p>This should have been a plotly plot.
196
+ But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work.
197
+ But I could potentially save to svg and insert that.</p>
198
+ <div id='divPlotly'></div>
199
+ <script>
200
+ var plotly_data = {fig.to_json()}
201
+ Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
202
+ </script>
203
+ </body>
204
+ """
205
+
206
+ def pyplot(self, fig=None, **kwargs):
207
+ string_io = io.StringIO()
208
+ plt.savefig(string_io, format="svg", fig=(2, 2))
209
+ svg = string_io.getvalue()[215:]
210
+ plt.close(fig)
211
+ self.inner_html = '<div height="200px">' + svg + "</div>"
212
+
213
+ def _to_html(self):
214
+ return f"""<div class="box {self.class_}">{self.inner_html}</div>"""
215
+
216
+
217
+ class Grid:
218
+ """A (CSS) Grid"""
219
+
220
+ def __init__(
221
+ self,
222
+ template_columns="1 1 1",
223
+ gap="10px",
224
+ background_color=COLOR,
225
+ color=BACKGROUND_COLOR,
226
+ ):
227
+ self.template_columns = template_columns
228
+ self.gap = gap
229
+ self.background_color = background_color
230
+ self.color = color
231
+ self.cells: List[Cell] = []
232
+
233
+ def __enter__(self):
234
+ return self
235
+
236
+ def __exit__(self, type, value, traceback):
237
+ st.markdown(self._get_grid_style(), unsafe_allow_html=True)
238
+ st.markdown(self._get_cells_style(), unsafe_allow_html=True)
239
+ st.markdown(self._get_cells_html(), unsafe_allow_html=True)
240
+
241
+ def _get_grid_style(self):
242
+ return f"""
243
+ <style>
244
+ .wrapper {{
245
+ display: grid;
246
+ grid-template-columns: {self.template_columns};
247
+ grid-gap: {self.gap};
248
+ background-color: {self.color};
249
+ color: {self.background_color};
250
+ }}
251
+ .box {{
252
+ background-color: {self.color};
253
+ color: {self.background_color};
254
+ border-radius: 0px;
255
+ padding: 0px;
256
+ font-size: 100%;
257
+ text-align: center;
258
+ }}
259
+ table {{
260
+ color: {self.color}
261
+ }}
262
+ </style>
263
+ """
264
+
265
+ def _get_cells_style(self):
266
+ return (
267
+ "<style>"
268
+ + "\n".join([cell._to_style() for cell in self.cells])
269
+ + "</style>"
270
+ )
271
+
272
+ def _get_cells_html(self):
273
+ return (
274
+ '<div class="wrapper">'
275
+ + "\n".join([cell._to_html() for cell in self.cells])
276
+ + "</div>"
277
+ )
278
+
279
+ def cell(
280
+ self,
281
+ class_: str = None,
282
+ grid_column_start: Optional[int] = None,
283
+ grid_column_end: Optional[int] = None,
284
+ grid_row_start: Optional[int] = None,
285
+ grid_row_end: Optional[int] = None,
286
+ ):
287
+ cell = Cell(
288
+ class_=class_,
289
+ grid_column_start=grid_column_start,
290
+ grid_column_end=grid_column_end,
291
+ grid_row_start=grid_row_start,
292
+ grid_row_end=grid_row_end,
293
+ )
294
+ self.cells.append(cell)
295
+ return cell
296
+
297
+
298
+ def select_block_container_style():
299
+ """Add selection section for setting setting the max-width and padding
300
+ of the main block container"""
301
+ st.sidebar.header("Block Container Style")
302
+ max_width_100_percent = st.sidebar.checkbox("Max-width: 100%?", False)
303
+ if not max_width_100_percent:
304
+ max_width = st.sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
305
+ else:
306
+ max_width = 1200
307
+ dark_theme = st.sidebar.checkbox("Dark Theme?", False)
308
+ padding_top = st.sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
309
+ padding_right = st.sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
310
+ padding_left = st.sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
311
+ padding_bottom = st.sidebar.number_input(
312
+ "Select padding bottom in rem", 0, 200, 10, 1
313
+ )
314
+ if dark_theme:
315
+ global COLOR
316
+ global BACKGROUND_COLOR
317
+ BACKGROUND_COLOR = "rgb(17,17,17)"
318
+ COLOR = "#fff"
319
+
320
+ _set_block_container_style(
321
+ max_width,
322
+ max_width_100_percent,
323
+ padding_top,
324
+ padding_right,
325
+ padding_left,
326
+ padding_bottom,
327
+ )
328
+
329
+
330
+ def _set_block_container_style(
331
+ max_width: int = 1200,
332
+ max_width_100_percent: bool = False,
333
+ padding_top: int = 5,
334
+ padding_right: int = 1,
335
+ padding_left: int = 1,
336
+ padding_bottom: int = 10,
337
+ ):
338
+ if max_width_100_percent:
339
+ max_width_str = f"max-width: 100%;"
340
+ else:
341
+ max_width_str = f"max-width: {max_width}px;"
342
+ st.markdown(
343
+ f"""
344
+ <style>
345
+ .reportview-container .main .block-container{{
346
+ {max_width_str}
347
+ padding-top: {padding_top}rem;
348
+ padding-right: {padding_right}rem;
349
+ padding-left: {padding_left}rem;
350
+ padding-bottom: {padding_bottom}rem;
351
+ }}
352
+ .reportview-container .main {{
353
+ color: {COLOR};
354
+ background-color: {BACKGROUND_COLOR};
355
+ }}
356
+ </style>
357
+ """,
358
+ unsafe_allow_html=True,
359
+ )
360
+
361
+
362
+ # @st.cache
363
+ # def get_dataframe() -> pd.DataFrame():
364
+ # """Dummy DataFrame"""
365
+ # data = [
366
+ # {"quantity": 1, "price": 2},
367
+ # {"quantity": 3, "price": 5},
368
+ # {"quantity": 4, "price": 8},
369
+ # ]
370
+ # return pd.DataFrame(data)
371
+
372
+
373
+ # def get_plotly_fig():
374
+ # """Dummy Plotly Plot"""
375
+ # return px.line(data_frame=get_dataframe(), x="quantity", y="price")
376
+
377
+
378
+ # def get_matplotlib_plt():
379
+ # get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
data/ImageNet_metadata.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e53b0fc17cd5c8811ca08b7ff908cd2bbd625147686ef8bc020cb85a5a4546e5
3
+ size 3027633
data/activation/convnext_activation.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0354b28bcca4e3673888124740e3d82882cbf38af8cd3007f48a7a5db983f487
3
+ size 33350177
data/activation/mobilenet_activation.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5abc76e9318fadee18f35bb54e90201bf28699cf75140b5d2482d42243fad302
3
+ size 13564581
data/activation/resnet_activation.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:668bea355a5504d74f79d20d02954040ad572f50455361d7d17125c7c8b1561c
3
+ size 23362905
data/dot_architectures/convnext_architecture.dot ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41a258a40a93615638ae504770c14e44836c934badbe48f18148f5a750514ac9
3
+ size 9108
data/layer_infos/convnext_layer_infos.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e82ea48865493107b97f37da58e370f0eead5677bf10f25f237f10970aedb6f
3
+ size 1678
data/layer_infos/mobilenet_layer_infos.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a11df5f0b23040d11ce817658a989c8faf19faa06a8cbad727b635bac824e917
3
+ size 3578
data/layer_infos/resnet_layer_infos.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21e1787382f1e1c206b81d2c4fe207fb6d41f4cf186d5afc32fc056dd21e10d6
3
+ size 5155
data/preprocessed_image_net/val_data_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2698bdc240555e2a46a40936df87275bc5852142d30e921ae0dad9289b0f576f
3
+ size 906108480
data/preprocessed_image_net/val_data_1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21780d77e212695dbee84d6d2ad17a5a520bc1634f68e1c8fd120f069ad76da1
3
+ size 907109023
data/preprocessed_image_net/val_data_2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cfc83b78420baa1b2c3a8da92e7fba1f33443d506f483ecff13cdba2035ab3c
3
+ size 907435149
data/preprocessed_image_net/val_data_3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f5e2c7cb4d6bae17fbd062a0b46f2cee457ad466b725f7bdf0f8426069cafee
3
+ size 906089333
data/preprocessed_image_net/val_data_4.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed53c87ec8b9945db31f910eb44b7e3092324643de25ea53a99fc29137df854
3
+ size 905439763
frontend/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import streamlit.components.v1 as components
2
+
3
+ on_click_graph = components.declare_component(
4
+ "on_click_graph",
5
+ path="./frontend"
6
+ )
frontend/footer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ footer="""<style>
4
+ a:link , a:visited{
5
+ background-color: transparent;
6
+ # text-decoration: underline;
7
+ }
8
+
9
+ a:hover, a:active {
10
+ color: orange;
11
+ background-color: transparent;
12
+ text-decoration: underline;
13
+ }
14
+
15
+ .footer {
16
+ position: fixed;
17
+ left: 0;
18
+ bottom: 0;
19
+ width: 100%;
20
+ background-color: white;
21
+ text-align: center;
22
+ }
23
+ </style>
24
+
25
+ <div class="footer">
26
+ <p>USER_DEFINED_TEXT <a href="LINK" target="_blank" color="blue">LINKED_TEXT</a></p>
27
+ </div>
28
+ """
29
+
30
+ def add_footer(text, linked_text, link):
31
+ custom_footer = footer.replace('USER_DEFINED_TEXT', text).replace('LINKED_TEXT', linked_text).replace('LINK', link)
32
+ st.markdown(custom_footer, unsafe_allow_html=True)
frontend/images/equal-sign.png ADDED
frontend/images/minus-sign-2.png ADDED
frontend/images/minus-sign-3.png ADDED
frontend/images/minus-sign-4.png ADDED
frontend/images/minus-sign-5.png ADDED
frontend/images/minus-sign.png ADDED
frontend/images/plus-sign-2.png ADDED
frontend/images/plus-sign.png ADDED
frontend/index.html ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html>
2
+
3
+ <head>
4
+ <style type="text/css">
5
+ </style>
6
+ </head>
7
+
8
+ <!--
9
+ ----------------------------------------------------
10
+ Your custom static HTML goes in the body:
11
+ -->
12
+
13
+ <body>
14
+ </body>
15
+
16
+ <script type="text/javascript">
17
+ // Helper function to send type and data messages to Streamlit client
18
+
19
+ const SET_COMPONENT_VALUE = "streamlit:setComponentValue"
20
+ const RENDER = "streamlit:render"
21
+ const COMPONENT_READY = "streamlit:componentReady"
22
+ const SET_FRAME_HEIGHT = "streamlit:setFrameHeight"
23
+ var HIGHTLIGHT_COLOR;
24
+ var original_colors;
25
+
26
+ function _sendMessage(type, data) {
27
+ // copy data into object
28
+ var outboundData = Object.assign({
29
+ isStreamlitMessage: true,
30
+ type: type,
31
+ }, data)
32
+
33
+ if (type == SET_COMPONENT_VALUE) {
34
+ console.log("_sendMessage data: ", SET_COMPONENT_VALUE)
35
+ // console.log("_sendMessage data: " + JSON.stringify(data))
36
+ // console.log("_sendMessage outboundData: " + JSON.stringify(outboundData))
37
+ }
38
+
39
+ window.parent.postMessage(outboundData, "*")
40
+ }
41
+
42
+ function initialize(pipeline) {
43
+
44
+ // Hook Streamlit's message events into a simple dispatcher of pipeline handlers
45
+ window.addEventListener("message", (event) => {
46
+ if (event.data.type == RENDER) {
47
+ // The event.data.args dict holds any JSON-serializable value
48
+ // sent from the Streamlit client. It is already deserialized.
49
+ pipeline.forEach(handler => {
50
+ handler(event.data.args)
51
+ })
52
+ }
53
+ })
54
+
55
+ _sendMessage(COMPONENT_READY, { apiVersion: 1 });
56
+
57
+ // Component should be mounted by Streamlit in an iframe, so try to autoset the iframe height.
58
+ window.addEventListener("load", () => {
59
+ window.setTimeout(function () {
60
+ setFrameHeight(document.documentElement.clientHeight)
61
+ }, 0)
62
+ })
63
+
64
+ // Optionally, if auto-height computation fails, you can manually set it
65
+ // (uncomment below)
66
+ setFrameHeight(0)
67
+ }
68
+
69
+ function setFrameHeight(height) {
70
+ _sendMessage(SET_FRAME_HEIGHT, { height: height })
71
+ }
72
+
73
+ // The `data` argument can be any JSON-serializable value.
74
+ function notifyHost(data) {
75
+ _sendMessage(SET_COMPONENT_VALUE, data)
76
+ }
77
+
78
+ function changeButtonColor(button, color) {
79
+ pol = button.querySelectorAll('polygon')[0]
80
+ pol.setAttribute('fill', color)
81
+ pol.setAttribute('stroke', color)
82
+ }
83
+
84
+ function getButtonColor(button) {
85
+ pol = button.querySelectorAll('polygon')[0]
86
+ return pol.getAttribute('fill')
87
+ }
88
+ // ----------------------------------------------------
89
+ // Your custom functionality for the component goes here:
90
+
91
+ function toggle(button) {
92
+ group = 'node'
93
+ let button_color;
94
+ nodes = window.parent.document.getElementsByClassName('node')
95
+ console.log("nodes.length = ", nodes.length)
96
+ // for (let i = 0; i < nodes.length; i++) {
97
+ // console.log(nodes.item(i))
98
+ // }
99
+ console.log("selected button ", button, button.getAttribute('class'), button.id)
100
+
101
+ for (let i = 0; i < nodes.length; i++) {
102
+ polygons = nodes.item(i).querySelectorAll('polygon')
103
+ if (polygons.length == 0) {
104
+ continue
105
+ }
106
+ if (button.id == nodes.item(i).id & button.getAttribute('class').includes("off")) {
107
+ button.setAttribute('class', group + " on")
108
+ button_color = original_colors[i]
109
+
110
+ } else if (button.id == nodes.item(i).id & button.getAttribute('class').includes("on")) {
111
+ button.setAttribute('class', group + " off")
112
+ button_color = original_colors[i]
113
+ } else if (button.id == nodes.item(i).id) {
114
+ button.setAttribute('class', group + " on")
115
+ button_color = original_colors[i]
116
+
117
+ } else if (button.id != nodes.item(i).id & nodes.item(i).getAttribute('class').includes("on")) {
118
+ nodes.item(i).className = group + " off"
119
+ } else {
120
+ nodes.item(i).className = group + " off"
121
+ }
122
+ }
123
+
124
+ nodes = window.parent.document.getElementsByClassName('node')
125
+ actions = []
126
+ for (let i = 0; i < nodes.length; i++) {
127
+ polygons = nodes.item(i).querySelectorAll('polygon')
128
+ if (polygons.length == 0) {
129
+ continue
130
+ }
131
+ btn = nodes.item(i)
132
+ ori_color = original_colors[i]
133
+ color = btn.querySelectorAll('polygon')[0].getAttribute('fill')
134
+ actions.push({ "action": btn.getAttribute("class").includes("on"), "original_color": ori_color, "color": color})
135
+ }
136
+
137
+ states = {}
138
+ states['choice'] = {
139
+ "node_title": button.querySelectorAll("title")[0].innerHTML,
140
+ "node_id": button.id,
141
+ "state": {
142
+ "action": button.getAttribute("class").includes("on"),
143
+ "original_color": button_color,
144
+ "color": button.querySelectorAll('polygon')[0].getAttribute('fill')
145
+ }
146
+ }
147
+ states["options"] = {"states": actions }
148
+
149
+ notifyHost({
150
+ value: states,
151
+ dataType: "json",
152
+ })
153
+ }
154
+
155
+ // ----------------------------------------------------
156
+ // Here you can customize a pipeline of handlers for
157
+ // inbound properties from the Streamlit client app
158
+
159
+ // Set initial value sent from Streamlit!
160
+ function initializeProps_Handler(props) {
161
+ HIGHTLIGHT_COLOR = props['hightlight_color']
162
+ original_colors = []
163
+ // nodes = document.getElementsByClassName('node')
164
+ nodes = window.parent.document.getElementsByClassName('node')
165
+ console.log(nodes)
166
+ for (let i = 0; i < nodes.length; i++) {
167
+ // color = nodes.item(i).getElementsByTagName('POLYGON')[0].getAttribute("fill")
168
+ // nodes.item(i).addEventListener("click", toggle)
169
+ polygons = nodes.item(i).querySelectorAll('polygon')
170
+ if (polygons.length == 0) {
171
+ original_colors.push('none')
172
+ continue
173
+ }
174
+
175
+ color = polygons[0].getAttribute("fill")
176
+ if (!nodes.item(i).hasAttribute('color')) {
177
+ nodes.item(i).setAttribute("color", color)
178
+ original_colors.push(color)
179
+ } else {
180
+ original_colors.push(nodes.item(i).getAttribute("color"))
181
+ }
182
+ nodes.item(i).addEventListener("click", function (event) {toggle(this)})
183
+ }
184
+ // console.log("original colors:", original_colors)
185
+ }
186
+ // Access values sent from Streamlit!
187
+ function dataUpdate_Handler(props) {
188
+ console.log('dataUpdate_Handler...........')
189
+ let msgLabel = document.getElementById("message_label")
190
+ }
191
+ // Simply log received data dictionary
192
+ function log_Handler(props) {
193
+ console.log("Received from Streamlit: " + JSON.stringify(props))
194
+ }
195
+
196
+ let pipeline = [initializeProps_Handler, dataUpdate_Handler, log_Handler]
197
+
198
+ // ----------------------------------------------------
199
+ // Finally, initialize component passing in pipeline
200
+ initialize(pipeline)
201
+
202
+ </script>
203
+
204
+ </html>
pages/1_Maximally_activating_patches.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+
4
+ from plotly.subplots import make_subplots
5
+ import plotly.graph_objects as go
6
+
7
+ import graphviz
8
+
9
+ from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
10
+ from frontend import on_click_graph
11
+ from backend.utils import load_dataset_dict
12
+
13
+ HIGHTLIGHT_COLOR = '#e7bcc5'
14
+ st.set_page_config(layout='wide')
15
+
16
+
17
+ st.title('Maximally activating image patches')
18
+ st.write('> **What patterns maximally activate this channel in ConvNeXt model?**')
19
+ st.write("""The maximally activating image patches method is a technique used in visualizing the interpretation of convolutional neural networks.
20
+ It works by identifying the regions of the input image that activate a particular neuron in the convolutional layer,
21
+ thus revealing the features that the neuron is detecting. To achieve this, the method generates image patches then feeds into the model while monitoring the neuron's activation.
22
+ The algorithm then selects the patch that produces the highest activation and overlays it on the original image to visualize the features that the neuron is responding to.
23
+ """)
24
+
25
+ # -------------------------- LOAD DATASET ---------------------------------
26
+ dataset_dict = load_dataset_dict()
27
+
28
+ # -------------------------- LOAD GRAPH -----------------------------------
29
+
30
+ def load_dot_to_graph(filename):
31
+ dot = graphviz.Source.from_file(filename)
32
+ source_lines = str(dot).splitlines()
33
+ source_lines.pop(0)
34
+ source_lines.pop(-1)
35
+ graph = graphviz.Digraph()
36
+ graph.body += source_lines
37
+ return graph, dot
38
+
39
+
40
+ # st.header('ConvNeXt')
41
+ convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
42
+ convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
43
+
44
+ convnext_graph.graph_attr['size'] = '4,40'
45
+
46
+ # -------------------------- DISPLAY GRAPH -----------------------------------
47
+
48
+ def chosen_node_text(clicked_node_title):
49
+ clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
50
+ stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None
51
+ block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None
52
+ layer_id = clicked_node_title.split()[-1]
53
+
54
+ if 'embeddings' in layer_id:
55
+ display_text = 'Patchify layer'
56
+ activation_key = 'embeddings.patch_embeddings'
57
+ elif 'downsampling' in layer_id:
58
+ display_text = f'Stage {stage_id} > Downsampling layer'
59
+ activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
60
+ else:
61
+ display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
62
+ activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}'
63
+ return display_text, activation_key
64
+
65
+
66
+ props = {
67
+ 'hightlight_color': HIGHTLIGHT_COLOR,
68
+ 'initial_state': {
69
+ 'group_1_header': 'Choose an option from group 1',
70
+ 'group_2_header': 'Choose an option from group 2'
71
+ }
72
+ }
73
+
74
+
75
+ col1, col2 = st.columns((2,5))
76
+ col1.markdown("#### Architecture")
77
+ col1.write('')
78
+ col1.write('Click on a layer below to generate top-k maximally activating image patches')
79
+ col1.graphviz_chart(convnext_graph)
80
+
81
+ with col2:
82
+ st.markdown("#### Output")
83
+ nodes = on_click_graph(key='toggle_buttons', **props)
84
+
85
+ # -------------------------- DISPLAY OUTPUT -----------------------------------
86
+
87
+ if nodes != None:
88
+ clicked_node_title = nodes["choice"]["node_title"]
89
+ clicked_node_id = nodes["choice"]["node_id"]
90
+ display_text, activation_key = chosen_node_text(clicked_node_title)
91
+ col2.write(f'**Chosen layer:** {display_text}')
92
+ # col2.write(f'**Activation key:** {activation_key}')
93
+
94
+ hightlight_syle = f'''
95
+ <style>
96
+ div[data-stale]:has(iframe) {{
97
+ height: 0;
98
+ }}
99
+ #{clicked_node_id}>polygon {{
100
+ fill: {HIGHTLIGHT_COLOR};
101
+ stroke: {HIGHTLIGHT_COLOR};
102
+ }}
103
+ </style>
104
+ '''
105
+ col2.markdown(hightlight_syle, unsafe_allow_html=True)
106
+
107
+ with col2:
108
+ layer_infos = None
109
+ with st.form('top_k_form'):
110
+ activation_path = './data/activation/convnext_activation.json'
111
+ activation = load_activation(activation_path)
112
+ num_channels = activation[activation_key].shape[1]
113
+
114
+ top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
115
+ channel_start, channel_end = st.slider(
116
+ 'Choose channel range of this layer (recommend to choose small range less than 30)',
117
+ 1, num_channels, value=(1, 30))
118
+ summit_button = st.form_submit_button('Generate image patches')
119
+ if summit_button:
120
+
121
+ activation = activation[activation_key][:top_k,:,:]
122
+ layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
123
+ # st.write(channel_start, channel_end)
124
+ # st.write(activation.shape, activation.shape[1])
125
+
126
+ if layer_infos != None:
127
+ num_cols, num_rows = top_k, channel_end - channel_start + 1
128
+ # num_rows = activation.shape[1]
129
+ top_k_coor_max_ = activation
130
+ st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
131
+
132
+ for row in range(channel_start, channel_end+1):
133
+ if row == channel_start:
134
+ top_margin = 50
135
+ fig = make_subplots(
136
+ rows=1, cols=num_cols,
137
+ subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
138
+ else:
139
+ top_margin = 0
140
+ fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)
141
+ for col in range(1, num_cols+1):
142
+ k, c = col-1, row-1
143
+ img_index = int(top_k_coor_max_[k, c, 3])
144
+ activation_value = top_k_coor_max_[k, c, 0]
145
+ img = dataset_dict[img_index//10_000][img_index%10_000]['image']
146
+ class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
147
+ class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
148
+
149
+ idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
150
+ x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
151
+ img = np.array(img)[y1:y2, x1:x2, :]
152
+
153
+ hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
154
+ fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
155
+ fig.update_xaxes(showticklabels=False, showgrid=False)
156
+ fig.update_yaxes(showticklabels=False, showgrid=False)
157
+ fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
158
+ fig.update_layout(showlegend=False, yaxis_title=row)
159
+ fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
160
+ fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
161
+ st.plotly_chart(fig, use_container_width=True)
162
+
163
+
164
+ else:
165
+ col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
166
+ col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0;}</style>""", unsafe_allow_html=True)
pages/2_SmoothGrad.py ADDED
@@ -0,0 +1,124 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import random
5
+ from backend.utils import make_grid, load_dataset, load_model, load_images
6
+
7
+ from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
8
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
9
+ import torch
10
+
11
+ from matplotlib.backends.backend_agg import RendererAgg
12
+ _lock = RendererAgg.lock
13
+
14
+ st.set_page_config(layout='wide')
15
+ BACKGROUND_COLOR = '#bcd0e7'
16
+
17
+
18
+ st.title('Feature attribution visualization with SmoothGrad')
19
+ st.write("""> **Which features are responsible for the current prediction of ConvNeXt?**
20
+
21
+ In machine learning, it is helpful to identify the significant features of the input (e.g., pixels for images) that affect the model's prediction.
22
+ If the model makes an incorrect prediction, we might want to determine which features contributed to the mistake.
23
+ To do this, we can generate a feature importance mask, which is a grayscale image with the same size as the original image.
24
+ The brightness of each pixel in the mask represents the importance of that feature to the model's prediction.
25
+
26
+ There are various methods to calculate an image sensitivity mask for a specific prediction.
27
+ One simple way is to use the gradient of the class prediction neuron with respect to the input pixels, which indicates how the prediction is affected by small changes to each pixel.
28
+ However, this method usually produces a noisy mask.
29
+ To reduce the noise, we use the SmoothGrad technique described in [SmoothGrad: Removing noise by adding noise](https://arxiv.org/abs/1706.03825) by Smilkov _et al._,
30
+ which adds Gaussian noise to multiple copies of the image and averages the resulting gradients.
31
+ """)
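# --- Editor's sketch (illustrative only, not part of this commit) ---
# A minimal SmoothGrad mask, assuming a generic torch classifier that maps a
# preprocessed image tensor `x` of shape (1, C, H, W) to class logits. The app's
# actual implementation is backend.smooth_grad.generate_smoothgrad_mask; the
# function and parameter names below are hypothetical.
import torch

def smoothgrad_mask(model, x, target_class, num_samples=10, sigma=0.15):
    """Average the absolute input gradient over several noisy copies of the image."""
    model.eval()
    total_grad = torch.zeros_like(x)
    noise_std = sigma * (x.max() - x.min())          # noise scale relative to the image range
    for _ in range(num_samples):
        noisy = (x.detach() + noise_std * torch.randn_like(x)).requires_grad_(True)
        logits = model(noisy)
        logits[0, target_class].backward()           # gradient of the class score w.r.t. pixels
        total_grad += noisy.grad
    avg_grad = total_grad / num_samples
    return avg_grad.abs().max(dim=1)[0]              # collapse channels -> (1, H, W) saliency mask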
32
+
33
+ instruction_text = """To use this functionality, users need to specify the model(s), the type of image set and the image set setting.
34
+ 1. Choose model: Users can choose one or more models for comparison.
35
+ There are 3 models supported: [ConvNeXt](https://huggingface.co/facebook/convnext-tiny-224),
36
+ [ResNet](https://huggingface.co/microsoft/resnet-50) and [MobileNet](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/).
37
+ These 3 models have a similar number of parameters.
38
+ \n2. Choose type of image set: there are 2 types of image set, _User-defined set_ and _Random set_.
39
+ \n3. Image set setting: If users choose _User-defined set_ in Image set,
40
+ users need to enter a list of image IDs separated by commas (,). For example, `0,1,4,7` is a valid input.
41
+ Check the page [ImageNet1k](/ImageNet1k) to see all the Image IDs.
42
+ If users choose _Random set_ in Image set, they just need to choose the number of random images to display here.
43
+ """
44
+ with st.expander("See more instructions", expanded=False):
45
+ st.write(instruction_text)
46
+
47
+
48
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
49
+
50
+ # --------------------------- LOAD function -----------------------------
51
+
52
+
53
+ images = []
54
+ image_ids = []
55
+ # INPUT ------------------------------
56
+ st.header('Input')
57
+ with st.form('smooth_grad_form'):
58
+ st.markdown('**Model and Input Setting**')
59
+ selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
60
+ selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
61
+
62
+ submit_button = st.form_submit_button('Set')
63
+ if submit_button:
64
+ setting_container = st.container()
65
+ # for id in image_ids:
66
+ # images = load_images(image_ids)
67
+
68
+ with st.form('2nd_form'):
69
+ st.markdown('**Image set setting**')
70
+ if selected_image_set == 'Random set':
71
+ no_images = st.slider('Number of images', 1, 50, value=10)
72
+ image_ids = random.sample(list(range(50_000)), k=no_images)
73
+ else:
74
+ text = st.text_area('Specific Image IDs', value='0')
75
+ image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
76
+
77
+ run_button = st.form_submit_button('Display output')
78
+ if run_button:
79
+ # load all selected images at once (no need to loop over individual ids)
80
+ images = load_images(image_ids)
81
+
82
+ st.header('Output')
83
+
84
+ models = {}
85
+ feature_extractors = {}
86
+
87
+ for i, model_name in enumerate(selected_models):
88
+ models[model_name], feature_extractors[model_name] = load_model(model_name)
89
+
90
+
91
+ # DISPLAY ----------------------------------
92
+ if run_button:
93
+ header_cols = st.columns([1, 1] + [2]*len(selected_models))
94
+ header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
95
+ header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
96
+ for i, model_name in enumerate(selected_models):
97
+ header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
98
+
99
+ grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
100
+
101
+
102
+ @st.cache(allow_output_mutation=True)
103
+ # @st.cache_data
104
+ def generate_images(image_id, model_name):
105
+ j = image_ids.index(image_id)
106
+ image = images[j]['image']
107
+ return generate_smoothgrad_mask(
108
+ image, model_name,
109
+ models[model_name], feature_extractors[model_name], num_samples=10)
110
+
111
+ with _lock:
112
+ for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
113
+ grids[j][0].write(f'{image_id}. {image_dict["label"]}')
114
+ image = image_dict['image']
115
+ ori_image = ShowImage(np.asarray(image))
116
+ grids[j][1].image(ori_image)
117
+
118
+ for i, model_name in enumerate(selected_models):
119
+ # ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
120
+ # model_name, models[model_name], feature_extractors[model_name], num_samples=10)
121
+ heatmap_image, masked_image = generate_images(image_id, model_name)
122
+ # grids[j][1].image(ori_image)
123
+ grids[j][i*2+2].image(heatmap_image)
124
+ grids[j][i*2+3].image(masked_image)
pages/3_Adversarial_attack.py ADDED
@@ -0,0 +1,215 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import random
5
+ from backend.utils import make_grid, load_dataset, load_model, load_image
6
+
7
+ from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img, LoadImage, ShowHeatMap, ShowMaskedImage
8
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
9
+ import torch
10
+
11
+ from matplotlib.backends.backend_agg import RendererAgg
12
+
13
+ from backend.adversarial_attack import *
14
+
15
+ _lock = RendererAgg.lock
16
+
17
+ st.set_page_config(layout='wide')
18
+ BACKGROUND_COLOR = '#bcd0e7'
19
+ SECONDARY_COLOR = '#bce7db'
20
+
21
+
22
+ st.title('Adversarial Attack')
23
+ st.write('> **How do adversarial attacks affect the interpretation of ConvNeXt?**')
24
+ st.write("""Adversarial examples are inputs crafted to confuse neural networks, causing them to misclassify a given input.
25
+ To humans these examples look virtually identical to the original inputs, yet they cause the network to fail to recognize the image content.
26
+ One such attack is the fast gradient sign method (FGSM), a white-box attack that aims to cause misclassification.
27
+ In a white-box attack, the attacker has full access to the model being attacked.
28
+
29
+ The FGSM attack is one of the earliest and most popular adversarial attacks.
30
+ It is described by Goodfellow _et al._ in their work on [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572).
31
+ The attack is simple yet powerful, using the gradients that neural networks use to learn.
32
+ Instead of adjusting the weights based on the backpropagated gradients to minimize loss, the attack adjusts the input data to maximize the loss using the gradient of the loss with respect to the input data.
33
+ """)
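# --- Editor's sketch (illustrative only, not part of this commit) ---
# The single FGSM step described above, assuming `x` is a preprocessed image tensor
# of shape (1, C, H, W) and `model` returns class logits. The attack actually used by
# this page lives in backend.adversarial_attack; the names below are hypothetical.
import torch
import torch.nn.functional as F

def fgsm_attack(model, x, true_class, epsilon):
    """Nudge the input in the direction that maximizes the classification loss."""
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), torch.tensor([true_class]))
    loss.backward()
    x_adv = x + epsilon * x.grad.sign()              # x_adv = x + eps * sign(dL/dx)
    return x_adv.detach()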
34
+
35
+ instruction_text = """Instructions for input:
36
+ 1. Choosing an image: users can choose a specific image by entering an **Image ID** and hitting the _Choose the defined image_ button, or generate an image randomly by hitting the _Generate a random image_ button.
37
+ 2. Choosing epsilon: **Epsilon** is the amount of perturbation applied to the original image under attack. The higher the epsilon, the more perturbed the image is and the more the model is confused.
38
+ Users can choose a specific epsilon by entering **Epsilon** and hitting the _Choose the defined epsilon_ button.
39
+ Users can also let the algorithm find the smallest epsilon automatically by hitting the _Find the smallest epsilon automatically_ button.
40
+ The underlying algorithm iterates through a set of epsilon values in ascending order until it reaches the **maximum value of epsilon**.
41
+ After each iteration, epsilon increases by an amount equal to the **step** variable.
42
+ Users can optionally change the default values of these two variables.
43
+ """
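# --- Editor's sketch (illustrative only, not part of this commit) ---
# The automatic search described in the instructions above is a linear scan over
# epsilon values, mirroring the loop implemented further down this page. It reuses
# perform_attack from backend.adversarial_attack (imported above via *), which
# returns the perturbed data together with the new prediction.
def find_smallest_epsilon(image, true_class_id, step_epsilon=0.001, max_epsilon=0.5):
    """Return the smallest epsilon that flips the prediction, or None if the attack fails."""
    epsilons = [i * step_epsilon for i in range(1, 1001) if i * step_epsilon <= max_epsilon]
    for e in epsilons:
        _, _, new_id, _ = perform_attack(image, true_class_id, e)
        if new_id != true_class_id:                  # prediction flipped -> attack succeeded
            return e
    return None                                      # no epsilon within the budget fooled the model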
44
+ st.write("To use the functionality below, users need to input the **image** and the **epsilon**.")
45
+ with st.expander("See more instructions", expanded=False):
46
+ st.write(instruction_text)
47
+
48
+
49
+
50
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
51
+ image_id = None
52
+
53
+ if 'image_id' not in st.session_state:
54
+ st.session_state.image_id = 0
55
+
56
+ # def on_change_random_input():
57
+ # st.session_state.image_id = st.session_state.image_id
58
+
59
+ # ----------------------------- INPUT ----------------------------------
60
+ st.header('Input')
61
+ input_col_1, input_col_2, input_col_3 = st.columns(3)
62
+ # --------------------------- INPUT column 1 ---------------------------
63
+ with input_col_1:
64
+ with st.form('image_form'):
65
+
66
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
67
+ st.write('**Choose or generate a random image**')
68
+ chosen_image_id_input = st.empty()
69
+ image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
70
+
71
+ choose_image_button = st.form_submit_button('Choose the defined image')
72
+ random_id = st.form_submit_button('Generate a random image')
73
+
74
+ if random_id:
75
+ image_id = random.randint(0, 49999)
76
+ st.session_state.image_id = image_id
77
+ chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
78
+
79
+ if choose_image_button:
80
+ image_id = int(image_id)
81
+ st.session_state.image_id = int(image_id)
82
+ # st.write(image_id, st.session_state.image_id)
83
+
84
+ # ---------------------------- SET UP OUTPUT ------------------------------
85
+ epsilon_container = st.empty()
86
+ st.header('Output')
87
+ st.subheader('Perform attack')
88
+
89
+ # perform attack container
90
+ header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
91
+ output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
92
+
93
+ # prediction error container
94
+ error_container = st.empty()
95
+ smoothgrad_header_container = st.empty()
96
+
97
+ # smoothgrad container
98
+ smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
99
+ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
100
+
101
+ original_image_dict = load_image(st.session_state.image_id)
102
+ input_image = original_image_dict['image']
103
+ input_label = original_image_dict['label']
104
+ input_id = original_image_dict['id']
105
+
106
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
107
+ with output_col_1:
108
+ pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
109
+ # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
110
+ st.image(input_image)
111
+ header_col_1.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.1f}% confidence')
112
+
113
+
114
+
115
+ if pred_class_id != (input_id-1):
116
+ with error_container.container():
117
+ st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {pred_prob*100:.1f}% confidence')
118
+ st.error('ConvNeXt misclassified the chosen image. Please choose or generate another image.',
119
+ icon = "🚫")
120
+
121
+ # ----------------------------- INPUT column 2 & 3 ----------------------------
122
+ with input_col_2:
123
+ with st.form('epsilon_form'):
124
+ st.write('**Set epsilon or find the smallest epsilon automatically**')
125
+ chosen_epsilon_input = st.empty()
126
+ epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
127
+
128
+ epsilon_button = st.form_submit_button('Choose the defined epsilon')
129
+ find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')
130
+
131
+
132
+ with input_col_3:
133
+ with st.form('iterate_epsilon_form'):
134
+ max_epsilon = st.number_input('Maximum value of epsilon (Optional setting)', value=0.500, format='%.3f')
135
+ step_epsilon = st.number_input('Step (Optional setting)', value=0.001, format='%.3f')
136
+ setting_button = st.form_submit_button('Set iterating mode')
137
+
138
+
139
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
140
+ if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
141
+ with output_col_3:
142
+ if epsilon_button:
143
+ perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
144
+ else:
145
+ epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
146
+ with epsilon_container.container():
147
+ epsilon_container_text = 'Checking epsilon'
148
+ st.write(epsilon_container_text)
149
+ st.progress(0)
150
+
151
+ for i, e in enumerate(epsilons):
152
+
153
+ perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
154
+ with epsilon_container.container():
155
+ epsilon_container_text = f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%'
156
+ st.write(epsilon_container_text)
157
+ st.progress(i/len(epsilons))
158
+
159
+ epsilon = e
160
+
161
+ if new_id != input_id - 1:
162
+ epsilon_container.empty()
163
+ st.balloons()
164
+ break
165
+ if i == len(epsilons)-1:
166
+ epsilon_container.error(f'The FGSM attack failed on this image at epsilon={e:.3f}. Set a higher maximum value of epsilon or choose another image.',
167
+ icon = "🚫")
168
+
169
+ perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
170
+ perturbed_amount = perturbed_image - input_image
171
+ header_col_3.write(f'Perturbed amount - epsilon={epsilon:.3f}')
172
+ st.image(ShowImage(perturbed_amount))
173
+
174
+ with output_col_2:
175
+ # st.write('plus sign')
176
+ st.image(LoadImage('frontend/images/plus-sign.png'))
177
+
178
+ with output_col_4:
179
+ # st.write('equal sign')
180
+ st.image(LoadImage('frontend/images/equal-sign.png'))
181
+
182
+ # ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
183
+ with output_col_5:
184
+ # st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
185
+ st.image(ShowImage(perturbed_image))
186
+ header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
187
+
188
+ # -------------------------- DISPLAY SMOOTHGRAD ---------------------------
189
+ smoothgrad_header_container.subheader('SmoothGrad visualization')
190
+
191
+ with smoothgrad_col_1:
192
+ smooth_head_1.write('SmoothGrad before the attack')
193
+ heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
194
+ st.image(heatmap_image)
195
+ st.image(masked_image)
196
+ with smoothgrad_col_3:
197
+ smooth_head_3.write('SmoothGrad after the attack')
198
+ heatmap_image_attacked, masked_image_attacked, attacked_mask = generate_images(st.session_state.image_id, epsilon=epsilon)
199
+ st.image(heatmap_image_attacked)
200
+ st.image(masked_image_attacked)
201
+
202
+ with smoothgrad_col_2:
203
+ st.image(LoadImage('frontend/images/minus-sign-5.png'))
204
+
205
+ with smoothgrad_col_5:
206
+ smooth_head_5.write('SmoothGrad difference')
207
+ difference_mask = abs(attacked_mask-mask)
208
+ st.image(ShowHeatMap(difference_mask))
209
+ masked_image = ShowMaskedImage(difference_mask, perturbed_image)
210
+ st.image(masked_image)
211
+
212
+ with smoothgrad_col_4:
213
+ st.image(LoadImage('frontend/images/equal-sign.png'))
214
+
215
+
pages/4_ImageNet1k.py ADDED
@@ -0,0 +1,56 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ from backend.utils import load_dataset, use_container_width_percentage
5
+
6
+ st.set_page_config(layout='wide')
7
+
8
+ st.title('ImageNet-1k')
9
+ st.markdown('This page shows the summary of 50,000 images in the validation set of [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)')
10
+
11
+ # SCREEN_WIDTH, SCREEN_HEIGHT = 2560, 1664
12
+
13
+ with st.spinner("Loading dataset..."):
14
+ dataset_dict = {}
15
+ for data_index in range(5):
16
+ dataset_dict[data_index] = load_dataset(data_index)
17
+
18
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
19
+
20
+ class_labels = imagenet_df.ClassLabel.unique().tolist()
21
+ class_labels.sort()
22
+ selected_classes = st.multiselect('Class filter: ', options=['All'] + class_labels)
23
+ if not ('All' in selected_classes or len(selected_classes) == 0):
24
+ imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
25
+ # st.write(class_labels)
26
+
27
+ col1, col2 = st.columns([2, 1])
28
+ with col1:
29
+ st.dataframe(imagenet_df)
30
+ use_container_width_percentage(100)
31
+
32
+ with col2:
33
+ st.text_area('Type anything here to copy later :)')
34
+ image = None
35
+ with st.form("display image"):
36
+ img_index = st.text_input('Image ID to display')
37
+
38
+ submitted = st.form_submit_button('Display this image')
39
+ error_container = st.empty()
40
+
41
+ if submitted:
42
+ try:
43
+ img_index = int(img_index)
44
+ if img_index > 50000-1 or img_index < 0:
45
+ error_container.error('The Image ID must be in range from 0 to 49999', icon="🚫")
46
+ else:
47
+ image = dataset_dict[img_index//10_000][img_index%10_000]['image']
48
+ class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
49
+ class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
50
+ except ValueError:
51
+ error_container.error('Please enter an integer number for Image ID', icon = "🚫")
52
+
53
+ if image is not None:
54
+ st.image(image)
55
+ st.write('**Class label:** ', class_label)
56
+ st.write('\n**Class id:** ', str(class_id))
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ captum==0.5.0
2
+ graphviz==0.20.1
3
+ Markdown==3.4.1
4
+ matplotlib==3.6.2
5
+ numpy==1.22.3
6
+ opencv_python_headless==4.6.0.66
7
+ pandas==1.5.2
8
+ Pillow==9.4.0
9
+ plotly==5.11.0
10
+ scipy==1.10.1
11
+ setuptools==65.5.0
12
+ streamlit==1.19.0
13
+ torch==1.10.1
14
+ torchvision==0.11.2
15
+ tqdm==4.64.1
16
+ transformers==4.25.1
17
+ git+https://github.com/vlue-c/Visual-Explanation-Methods-PyTorch.git