taesiri commited on
Commit
8c20033
·
1 Parent(s): 8a6eccd

Initial Commit

Browse files
Files changed (6) hide show
  1. .gitignore +176 -0
  2. SessionState.py +117 -0
  3. app.py +406 -0
  4. download_utils.py +55 -0
  5. helper.py +23 -0
  6. image_utils.py +139 -0
.gitignore ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Created by https://www.toptal.com/developers/gitignore/api/python
3
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
4
+
5
+ ### Python ###
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110
+ __pypackages__/
111
+
112
+ # Celery stuff
113
+ celerybeat-schedule
114
+ celerybeat.pid
115
+
116
+ # SageMath parsed files
117
+ *.sage.py
118
+
119
+ # Environments
120
+ .env
121
+ .venv
122
+ env/
123
+ venv/
124
+ ENV/
125
+ env.bak/
126
+ venv.bak/
127
+
128
+ # Spyder project settings
129
+ .spyderproject
130
+ .spyproject
131
+
132
+ # Rope project settings
133
+ .ropeproject
134
+
135
+ # mkdocs documentation
136
+ /site
137
+
138
+ # mypy
139
+ .mypy_cache/
140
+ .dmypy.json
141
+ dmypy.json
142
+
143
+ # Pyre type checker
144
+ .pyre/
145
+
146
+ # pytype static type analyzer
147
+ .pytype/
148
+
149
+ # Cython debug symbols
150
+ cython_debug/
151
+
152
+ # PyCharm
153
+ # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
154
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
155
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
156
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
157
+ #.idea/
158
+
159
+ # End of https://www.toptal.com/developers/gitignore/api/python
160
+ #
161
+ #
162
+
163
+ ImageNet-HARD-EMD-Real.zip
164
+ demonstrations.zip
165
+ ImageNet-HARD-Normal.zip
166
+ demonstrations/
167
+ visualizations_feb2022/
168
+ ImageNet-HARD-EMD-5-Patches-Real.zip
169
+ visualizations
170
+ imagenet1k-pilot.tar.gz
171
+ predictions/
172
+ imagenet1k-pilot.zip
173
+ Final.zip
174
+ imagenet1k-val-50k-emd_results_rosy-brook-184.pickle
175
+ CUB-Final.zip
176
+ CUB-Demonstrations/
SessionState.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hack to add per-session state to Streamlit.
2
+
3
+ Usage
4
+ -----
5
+
6
+ >>> import SessionState
7
+ >>>
8
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
9
+ >>> session_state.user_name
10
+ ''
11
+ >>> session_state.user_name = 'Mary'
12
+ >>> session_state.favorite_color
13
+ 'black'
14
+
15
+ Since you set user_name above, next time your script runs this will be the
16
+ result:
17
+ >>> session_state = get(user_name='', favorite_color='black')
18
+ >>> session_state.user_name
19
+ 'Mary'
20
+
21
+ """
22
+ try:
23
+ import streamlit.ReportThread as ReportThread
24
+ from streamlit.server.Server import Server
25
+ except Exception:
26
+ # Streamlit >= 0.65.0
27
+ import streamlit.report_thread as ReportThread
28
+ from streamlit.server.server import Server
29
+
30
+
31
+ class SessionState(object):
32
+ def __init__(self, **kwargs):
33
+ """A new SessionState object.
34
+
35
+ Parameters
36
+ ----------
37
+ **kwargs : any
38
+ Default values for the session state.
39
+
40
+ Example
41
+ -------
42
+ >>> session_state = SessionState(user_name='', favorite_color='black')
43
+ >>> session_state.user_name = 'Mary'
44
+ ''
45
+ >>> session_state.favorite_color
46
+ 'black'
47
+
48
+ """
49
+ for key, val in kwargs.items():
50
+ setattr(self, key, val)
51
+
52
+
53
+ def get(**kwargs):
54
+ """Gets a SessionState object for the current session.
55
+
56
+ Creates a new object if necessary.
57
+
58
+ Parameters
59
+ ----------
60
+ **kwargs : any
61
+ Default values you want to add to the session state, if we're creating a
62
+ new one.
63
+
64
+ Example
65
+ -------
66
+ >>> session_state = get(user_name='', favorite_color='black')
67
+ >>> session_state.user_name
68
+ ''
69
+ >>> session_state.user_name = 'Mary'
70
+ >>> session_state.favorite_color
71
+ 'black'
72
+
73
+ Since you set user_name above, next time your script runs this will be the
74
+ result:
75
+ >>> session_state = get(user_name='', favorite_color='black')
76
+ >>> session_state.user_name
77
+ 'Mary'
78
+
79
+ """
80
+ # Hack to get the session object from Streamlit.
81
+
82
+ ctx = ReportThread.get_report_ctx()
83
+
84
+ this_session = None
85
+
86
+ current_server = Server.get_current()
87
+ if hasattr(current_server, '_session_infos'):
88
+ # Streamlit < 0.56
89
+ session_infos = Server.get_current()._session_infos.values()
90
+ else:
91
+ session_infos = Server.get_current()._session_info_by_id.values()
92
+
93
+ for session_info in session_infos:
94
+ s = session_info.session
95
+ if (
96
+ # Streamlit < 0.54.0
97
+ (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
98
+ or
99
+ # Streamlit >= 0.54.0
100
+ (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
101
+ or
102
+ # Streamlit >= 0.65.2
103
+ (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
104
+ ):
105
+ this_session = s
106
+
107
+ if this_session is None:
108
+ raise RuntimeError(
109
+ "Oh noes. Couldn't get your Streamlit Session object. "
110
+ 'Are you doing something fancy with threads?')
111
+
112
+ # Got the session object! Now let's attach some state into it.
113
+
114
+ if not hasattr(this_session, '_custom_session_state'):
115
+ this_session._custom_session_state = SessionState(**kwargs)
116
+
117
+ return this_session._custom_session_state
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ import random
5
+ import time
6
+ from collections import Counter
7
+ from datetime import datetime
8
+ from glob import glob
9
+
10
+ import gdown
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import streamlit as st
16
+ from PIL import Image
17
+
18
+ import SessionState
19
+ from download_utils import *
20
+ from image_utils import *
21
+
22
+ random.seed(datetime.now())
23
+ np.random.seed(int(time.time()))
24
+
25
+ NUMBER_OF_TRIALS = 20
26
+ CLASSIFIER_TAG = ""
27
+ selected_xai_tool = None
28
+
29
+ # Config
30
+ folder_to_name = {}
31
+ # class_descriptions = {}
32
+ classifier_predictions = {}
33
+ selected_dataset = "Task-1-CUB-iNat-HumanStudy"
34
+
35
+ root_visualization_dir = "./visualizations/"
36
+ viz_url = "https://static.taesiri.com/xai/CUB-Task1.zip"
37
+ viz_archivefile = "CUB-Final.zip"
38
+
39
+ demonstration_url = "https://static.taesiri.com/xai/cub-demonstrations.zip"
40
+ demonst_zipfile = "demonstrations.zip"
41
+
42
+ picklefile_url = "https://static.taesiri.com/xai/Task1-CUB-CHMOnly.pickle"
43
+ prediction_root = "./predictions/"
44
+ prediction_pickle = f"{prediction_root}predictions.pickle"
45
+
46
+
47
+ # Get the Data
48
+ download_files(
49
+ root_visualization_dir,
50
+ viz_url,
51
+ viz_archivefile,
52
+ demonstration_url,
53
+ demonst_zipfile,
54
+ picklefile_url,
55
+ prediction_root,
56
+ prediction_pickle,
57
+ )
58
+ ################################################
59
+ # GLOBAL VARIABLES
60
+ app_mode = ""
61
+
62
+ # Shared/Global Information
63
+ birds_list = list(
64
+ sorted([x.replace(".jpg", "") for x in os.listdir("./CUB-Demonstrations")])
65
+ )
66
+ id_to_bird = {i: x for i, x in enumerate(birds_list)}
67
+ folder_to_name = {x: x for x in birds_list} #
68
+ ################################################
69
+
70
+ with open(prediction_pickle, "rb") as f:
71
+ classifier_predictions = pickle.load(f)
72
+
73
+ # SESSION STATE
74
+ session_state = SessionState.get(
75
+ page=1,
76
+ first_run=1,
77
+ user_feedback={},
78
+ queries=[],
79
+ is_classifier_correct={},
80
+ XAI_tool="Unselected",
81
+ )
82
+ ################################################
83
+
84
+
85
+ def resmaple_queries():
86
+ if session_state.first_run == 1:
87
+ both_correct = glob(
88
+ root_visualization_dir + selected_dataset + "/Both_correct/*.jpg"
89
+ )
90
+ both_wrong = glob(
91
+ root_visualization_dir + selected_dataset + "/Both_wrong/*.jpg"
92
+ )
93
+
94
+ correct_samples = list(
95
+ np.random.choice(a=both_correct, size=NUMBER_OF_TRIALS // 2, replace=False)
96
+ )
97
+ wrong_samples = list(
98
+ np.random.choice(a=both_wrong, size=NUMBER_OF_TRIALS // 2, replace=False)
99
+ )
100
+
101
+ all_images = correct_samples + wrong_samples
102
+ random.shuffle(all_images)
103
+ session_state.queries = all_images
104
+ session_state.first_run = -1
105
+ # RESET INTERACTIONS
106
+ session_state.user_feedback = {}
107
+ session_state.is_classifier_correct = {}
108
+
109
+
110
+ def render_experiment(query):
111
+ current_query = session_state.queries[query]
112
+ query_id = os.path.basename(current_query)
113
+
114
+ predicted_wnid = classifier_predictions[query_id][f"{CLASSIFIER_TAG}-predictions"]
115
+ prediction_confidence = classifier_predictions[query_id][
116
+ f"{CLASSIFIER_TAG}-confidence"
117
+ ]
118
+ prediction_label = folder_to_name[predicted_wnid]
119
+ # class_def = class_descriptions[predicted_wnid]
120
+
121
+ session_state.is_classifier_correct[query_id] = classifier_predictions[query_id][
122
+ f"{CLASSIFIER_TAG.upper()}-Output"
123
+ ]
124
+
125
+ # SHOW QUERY and PREDICTION
126
+
127
+ col1, col2 = st.columns(2)
128
+ with col1:
129
+ st.image(load_query(current_query), caption=f"Query ID: {query_id}")
130
+ with col2:
131
+ # SHOW DESCRIPTION OF CLASS
132
+ with st.expander("Show Class Description"):
133
+ st.write(f"**Name**: {prediction_label}")
134
+ st.write("**Class Definition**:")
135
+ # st.markdown("`" + class_def + "`")
136
+ st.image(
137
+ Image.open(f"CUB-Demonstrations/{predicted_wnid}.jpg"),
138
+ caption=f"Class Explanation",
139
+ use_column_width=True,
140
+ )
141
+
142
+ default_value = 0
143
+ if query_id in session_state.user_feedback.keys():
144
+ if session_state.user_feedback[query_id] == "Correct":
145
+ default_value = 1
146
+ elif session_state.user_feedback[query_id] == "Wrong":
147
+ default_value = 2
148
+
149
+ session_state.user_feedback[query_id] = st.radio(
150
+ "What do you think about model's prediction?",
151
+ ("-", "Correct", "Wrong"),
152
+ key=query_id,
153
+ index=default_value,
154
+ )
155
+ st.write(f"**Model Prediction**: {prediction_label}")
156
+ st.write(f"**Model Confidence**: {prediction_confidence}")
157
+
158
+ # SHOW Model Explanation
159
+ if selected_xai_tool is not None:
160
+ st.image(
161
+ selected_xai_tool(current_query),
162
+ caption=f"Explaination",
163
+ use_column_width=True,
164
+ )
165
+
166
+ # SHOW DEBUG INFO
167
+
168
+ if st.button("Debug: Show Everything"):
169
+ st.image(Image.open(current_query))
170
+
171
+
172
+ def render_results():
173
+ user_correct_guess = 0
174
+ # st.write(session_state.user_feedback)
175
+ # st.write(session_state.is_classifier_correct)
176
+ for q in session_state.user_feedback.keys():
177
+ if session_state.user_feedback[q] != "-":
178
+ uf = True if session_state.user_feedback[q] == "Correct" else False
179
+ if session_state.is_classifier_correct[q] == uf:
180
+ user_correct_guess += 1
181
+
182
+ st.write(
183
+ f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of {len( session_state.user_feedback)} Correct"
184
+ )
185
+ st.markdown("## User Performance Breakdown")
186
+
187
+ categories = [
188
+ "Correct",
189
+ "Wrong",
190
+ ] # set(session_state.is_classifier_correct.values())
191
+ breakdown_stats_correct = {c: 0 for c in categories}
192
+ breakdown_stats_wrong = {c: 0 for c in categories}
193
+
194
+ experiment_summary = []
195
+
196
+ for q in session_state.user_feedback.keys():
197
+ category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
198
+ is_user_correct = category == session_state.user_feedback[q]
199
+
200
+ if is_user_correct:
201
+ breakdown_stats_correct[category] += 1
202
+ else:
203
+ breakdown_stats_wrong[category] += 1
204
+
205
+ experiment_summary.append(
206
+ [
207
+ q,
208
+ classifier_predictions[q]["gt_wnid"],
209
+ folder_to_name[
210
+ classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]
211
+ ],
212
+ category,
213
+ session_state.user_feedback[q],
214
+ is_user_correct,
215
+ ]
216
+ )
217
+ # Summary Table
218
+ experiment_summary_df = pd.DataFrame.from_records(
219
+ experiment_summary,
220
+ columns=[
221
+ "Query",
222
+ "GT Labels",
223
+ f"{CLASSIFIER_TAG} Prediction",
224
+ "Category",
225
+ "User Prediction",
226
+ "Is User Prediction Correct",
227
+ ],
228
+ )
229
+ st.write("Summary", experiment_summary_df)
230
+
231
+ csv = convert_df(experiment_summary_df)
232
+ st.download_button(
233
+ "Press to Download", csv, "summary.csv", "text/csv", key="download-records"
234
+ )
235
+ # SHOW BREAKDOWN
236
+ user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
237
+ {"Is User Prediction Correct": ["count", "sum", "mean"]}
238
+ )
239
+ # rename columns
240
+ user_pf_by_model_pred.columns = user_pf_by_model_pred.columns.droplevel(0)
241
+ user_pf_by_model_pred.columns = [
242
+ "Count",
243
+ "Correct User Guess",
244
+ "Mean User Performance",
245
+ ]
246
+ user_pf_by_model_pred.index.name = "Model Prediction"
247
+ st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
248
+ csv = convert_df(user_pf_by_model_pred)
249
+ st.download_button(
250
+ "Press to Download",
251
+ csv,
252
+ "user-performance-by-model-prediction.csv",
253
+ "text/csv",
254
+ key="download-performance-by-model-prediction",
255
+ )
256
+ # CONFUSION MATRIX
257
+
258
+ confusion_matrix = pd.crosstab(
259
+ experiment_summary_df["Category"],
260
+ experiment_summary_df["User Prediction"],
261
+ rownames=["Actual"],
262
+ colnames=["Predicted"],
263
+ )
264
+ st.write("Confusion Matrix", confusion_matrix)
265
+ csv = convert_df(confusion_matrix)
266
+ st.download_button(
267
+ "Press to Download",
268
+ csv,
269
+ "confusion-matrix.csv",
270
+ "text/csv",
271
+ key="download-confusiion-matrix",
272
+ )
273
+
274
+
275
+ def render_menu():
276
+ # Render the readme as markdown using st.markdown.
277
+ readme_text = st.markdown(
278
+ """
279
+ # Instructions
280
+ ```
281
+ When testing this study, you should first see the class definition, then hide the expander and see the query.
282
+ ```
283
+ """
284
+ )
285
+
286
+ app_mode = st.selectbox(
287
+ "Choose the page to show:",
288
+ ["Experiment Instruction", "Start Experiment", "See the Results"],
289
+ )
290
+
291
+ if app_mode == "Experiment Instruction":
292
+ st.success("To continue select an option in the dropdown menu.")
293
+ elif app_mode == "Start Experiment":
294
+ # Clear Canvas
295
+ readme_text.empty()
296
+
297
+ page_id = session_state.page
298
+ col1, col4, col2, col3 = st.columns(4)
299
+ prev_page = col1.button("Previous Image")
300
+
301
+ if prev_page:
302
+ page_id -= 1
303
+ if page_id < 1:
304
+ page_id = 1
305
+
306
+ next_page = col2.button("Next Image")
307
+
308
+ if next_page:
309
+ page_id += 1
310
+ if page_id > NUMBER_OF_TRIALS:
311
+ page_id = NUMBER_OF_TRIALS
312
+
313
+ if page_id == NUMBER_OF_TRIALS:
314
+ st.success(
315
+ 'You have reached the last image. Please go to the "Results" page to see your performance.'
316
+ )
317
+ if st.button("View"):
318
+ app_mode = "See the Results"
319
+
320
+ if col3.button("Resample"):
321
+ st.write("Restarting ...")
322
+ page_id = 1
323
+ session_state.first_run = 1
324
+ resmaple_queries()
325
+
326
+ session_state.page = page_id
327
+ st.write(f"Render Experiment: {session_state.page}")
328
+ render_experiment(session_state.page - 1)
329
+ elif app_mode == "See the Results":
330
+ readme_text.empty()
331
+ st.write("Results Summary")
332
+ render_results()
333
+
334
+
335
+ def main():
336
+ global app_mode
337
+ global session_state
338
+ global selected_xai_tool
339
+ global CLASSIFIER_TAG
340
+
341
+ # Set the session state
342
+ # State Management and General Setup
343
+ st.set_page_config(layout="wide")
344
+ st.title("TASK - 1 - CUB")
345
+
346
+ # st.write(classifier_predictions.keys())
347
+ # st.write(classifier_predictions["ILSVRC2012_val_00024646.JPEG"])
348
+
349
+ options = [
350
+ "Unselected",
351
+ "NOXAI",
352
+ "KNN",
353
+ # "EMD Nearest Neighbors",
354
+ # "EMD Correspondence",
355
+ "CHM Nearest Neighbors",
356
+ "CHM Correspondence",
357
+ ]
358
+
359
+ st.markdown(
360
+ """ <style>
361
+ div[role="radiogroup"] > :first-child{
362
+ display: none !important;
363
+ }
364
+ </style>
365
+ """,
366
+ unsafe_allow_html=True,
367
+ )
368
+
369
+ if session_state.XAI_tool == "Unselected":
370
+ default = options.index(session_state.XAI_tool)
371
+ session_state.XAI_tool = st.radio(
372
+ "What explaination tool do you want to evaluate?",
373
+ options,
374
+ key="which_xai",
375
+ index=default,
376
+ )
377
+ # print(session_state.XAI_tool)
378
+
379
+ if session_state.XAI_tool != "Unselected":
380
+ st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
381
+
382
+ if session_state.XAI_tool == "NOXAI":
383
+ CLASSIFIER_TAG = "knn"
384
+ selected_xai_tool = None
385
+ elif session_state.XAI_tool == "KNN":
386
+ selected_xai_tool = load_knn_nns
387
+ CLASSIFIER_TAG = "knn"
388
+ elif session_state.XAI_tool == "CHM Nearest Neighbors":
389
+ selected_xai_tool = load_chm_nns
390
+ CLASSIFIER_TAG = "CHM"
391
+ elif session_state.XAI_tool == "CHM Correspondence":
392
+ selected_xai_tool = load_chm_corrs
393
+ CLASSIFIER_TAG = "CHM"
394
+ elif session_state.XAI_tool == "EMD Nearest Neighbors":
395
+ selected_xai_tool = load_emd_nns
396
+ CLASSIFIER_TAG = "EMD"
397
+ elif session_state.XAI_tool == "EMD Correspondence":
398
+ selected_xai_tool = load_emd_corrs
399
+ CLASSIFIER_TAG = "EMD"
400
+
401
+ resmaple_queries()
402
+ render_menu()
403
+
404
+
405
+ if __name__ == "__main__":
406
+ main()
download_utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ import random
5
+ import tarfile
6
+ import zipfile
7
+ from collections import Counter
8
+ from glob import glob
9
+
10
+ import gdown
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import streamlit as st
16
+ from PIL import Image
17
+
18
+ import SessionState
19
+
20
+
21
+ def download_files(
22
+ root_visualization_dir,
23
+ viz_url,
24
+ viz_archivefile,
25
+ demonstration_url,
26
+ demonst_zipfile,
27
+ picklefile_url,
28
+ prediction_root,
29
+ prediction_pickle,
30
+ ):
31
+ # Get Visualization
32
+ if not os.path.exists(root_visualization_dir):
33
+ gdown.download(viz_url, viz_archivefile, quiet=False)
34
+ os.makedirs(root_visualization_dir, exist_ok=True)
35
+
36
+ if viz_archivefile.endswith("tar.gz"):
37
+ tar = tarfile.open(viz_archivefile, "r:gz")
38
+ tar.extractall(path=root_visualization_dir)
39
+ tar.close()
40
+ elif viz_archivefile.endswith("zip"):
41
+ with zipfile.ZipFile(viz_archivefile, "r") as zip_ref:
42
+ zip_ref.extractall(root_visualization_dir)
43
+
44
+ # Get Demonstrations
45
+ if not os.path.exists(demonst_zipfile):
46
+ gdown.download(demonstration_url, demonst_zipfile, quiet=False)
47
+ # os.makedirs(roo_demonstration_dir, exist_ok=True)
48
+
49
+ with zipfile.ZipFile(demonst_zipfile, "r") as zip_ref:
50
+ zip_ref.extractall("./")
51
+
52
+ # Get Predictions
53
+ if not os.path.exists(prediction_pickle):
54
+ os.makedirs(prediction_root, exist_ok=True)
55
+ gdown.download(picklefile_url, prediction_pickle, quiet=False)
helper.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ def get_label_for_query(image_url, model_name):
4
+ fourway_label = image_url.split('/')[-2]
5
+
6
+ if fourway_label=='both_correct':
7
+ return 'Correct'
8
+
9
+ if fourway_label=='both_wrong':
10
+ return 'Wrong'
11
+
12
+ if fourway_label == 'chm_correct_knn_incorrect' and model_name == 'CHM':
13
+ return 'Correct'
14
+ elif fourway_label == 'knn_correct_chm_incorrect' and model_name == 'KNN':
15
+ return 'Correct'
16
+
17
+ return 'Wrong'
18
+
19
+ def get_category(image_url):
20
+ return image_url.split('/')[-2]
21
+
22
+ def translate_winds_to_names(winds):
23
+ return [folder_to_name[x] for x in winds]
image_utils.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ import random
5
+ from glob import glob
6
+
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ import streamlit as st
11
+ from PIL import Image
12
+
13
+
14
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
15
+ def load_query(image_path):
16
+ image = Image.open(image_path)
17
+ width, height = image.size
18
+
19
+ new_width = width
20
+ new_height = height
21
+
22
+ left = (width - new_width) / 2
23
+ top = (height - new_height) / 2
24
+ right = (width + new_width) / 2
25
+ bottom = (height + new_height) / 2
26
+
27
+ # Crop the center of the image
28
+ cropped_image = image.crop(
29
+ (left + 75, top + 145, right - 1790, bottom - (1140))
30
+ ).resize((300, 300))
31
+
32
+ return cropped_image
33
+
34
+
35
+ # CHM ############################################################################
36
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
37
+ def load_chm_nns(image_path):
38
+ image = Image.open(image_path)
39
+ width, height = image.size
40
+
41
+ new_width = width
42
+ new_height = height
43
+
44
+ left = (width - new_width) / 2
45
+ top = (height - new_height) / 2
46
+ right = (width + new_width) / 2
47
+ bottom = (height + new_height) / 2
48
+
49
+ # Crop the center of the image
50
+ cropped_image = image.crop((left + 485, top + 145, right - 15, bottom - (1140)))
51
+ return cropped_image
52
+
53
+
54
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
55
+ def load_chm_corrs(image_path):
56
+ image = Image.open(image_path)
57
+ width, height = image.size
58
+
59
+ new_width = width
60
+ new_height = height
61
+
62
+ left = (width - new_width) / 2
63
+ top = (height - new_height) / 2
64
+ right = (width + new_width) / 2
65
+ bottom = (height + new_height) / 2
66
+
67
+ # Crop the center of the image
68
+ cropped_image = image.crop((left + 485, top + 900, right - 15, bottom - (25 + 10)))
69
+ return cropped_image
70
+
71
+
72
+ # CHM ############################################################################
73
+
74
+ # KNN ############################################################################
75
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
76
+ def load_knn_nns(image_path):
77
+ image = Image.open(image_path)
78
+ width, height = image.size
79
+
80
+ new_width = width
81
+ new_height = height
82
+
83
+ left = (width - new_width) / 2
84
+ top = (height - new_height) / 2
85
+ right = (width + new_width) / 2
86
+ bottom = (height + new_height) / 2
87
+
88
+ # Crop the center of the image
89
+ cropped_image = image.crop((left + 485, top + 525, right - 10, bottom - (770)))
90
+ return cropped_image
91
+
92
+
93
+ # KNN ############################################################################
94
+
95
+ # EMD ############################################################################
96
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
97
+ def load_emd_nns(image_path):
98
+ image = Image.open(image_path)
99
+ width, height = image.size
100
+
101
+ new_width = width
102
+ new_height = height
103
+
104
+ left = (width - new_width) / 2
105
+ top = (height - new_height) / 2
106
+ right = (width + new_width) / 2
107
+ bottom = (height + new_height) / 2
108
+
109
+ # Crop the center of the image
110
+ cropped_image = image.crop(
111
+ (left + 10, top + 2075, right - 420, bottom - (925 + 25 + 10))
112
+ )
113
+ return cropped_image
114
+
115
+
116
+ @st.cache(allow_output_mutation=True, max_entries=10, ttl=3600)
117
+ def load_emd_corrs(image_path):
118
+ image = Image.open(image_path)
119
+ width, height = image.size
120
+
121
+ new_width = width
122
+ new_height = height
123
+
124
+ left = (width - new_width) / 2
125
+ top = (height - new_height) / 2
126
+ right = (width + new_width) / 2
127
+ bottom = (height + new_height) / 2
128
+
129
+ # Crop the center of the image
130
+ cropped_image = image.crop((left + 10, top + 2500, right - 20, bottom))
131
+ return cropped_image
132
+
133
+
134
+ # EMD ############################################################################
135
+
136
+
137
+ @st.cache()
138
+ def convert_df(df):
139
+ return df.to_csv().encode("utf-8")