taesiri commited on
Commit
90dfd54
·
1 Parent(s): a327867
Files changed (2) hide show
  1. app.py +42 -16
  2. image_utils.py +6 -8
app.py CHANGED
@@ -30,16 +30,16 @@ selected_xai_tool = None
30
  folder_to_name = {}
31
  # class_descriptions = {}
32
  classifier_predictions = {}
33
- selected_dataset = "Task-1-CUB-iNat-HumanStudy"
34
 
35
  root_visualization_dir = "./visualizations/"
36
- viz_url = "https://static.taesiri.com/xai/CUB-Task1.zip"
37
  viz_archivefile = "CUB-Final.zip"
38
 
39
  demonstration_url = "https://static.taesiri.com/xai/CUB-Demonstrations.zip"
40
  demonst_zipfile = "demonstrations.zip"
41
 
42
- picklefile_url = "https://static.taesiri.com/xai/Task1-CUB-CHMOnly.pickle"
43
  prediction_root = "./predictions/"
44
  prediction_pickle = f"{prediction_root}predictions.pickle"
45
 
@@ -84,22 +84,48 @@ session_state = SessionState.get(
84
 
85
  def resmaple_queries():
86
  if session_state.first_run == 1:
87
- both_correct = glob(
88
- root_visualization_dir + selected_dataset + "/Both_correct/*.jpg"
89
- )
90
- both_wrong = glob(
91
- root_visualization_dir + selected_dataset + "/Both_wrong/*.jpg"
92
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  correct_samples = list(
95
- np.random.choice(a=both_correct, size=NUMBER_OF_TRIALS // 2, replace=False)
 
 
96
  )
97
  wrong_samples = list(
98
- np.random.choice(a=both_wrong, size=NUMBER_OF_TRIALS // 2, replace=False)
 
 
99
  )
100
 
101
  all_images = correct_samples + wrong_samples
102
  random.shuffle(all_images)
 
103
  session_state.queries = all_images
104
  session_state.first_run = -1
105
  # RESET INTERACTIONS
@@ -109,7 +135,7 @@ def resmaple_queries():
109
 
110
  def render_experiment(query):
111
  current_query = session_state.queries[query]
112
- query_id = os.path.basename(current_query)
113
 
114
  predicted_wnid = classifier_predictions[query_id][f"{CLASSIFIER_TAG}-predictions"]
115
  prediction_confidence = classifier_predictions[query_id][
@@ -350,8 +376,8 @@ def main():
350
  "Unselected",
351
  "NOXAI",
352
  "KNN",
353
- # "EMD Nearest Neighbors",
354
- # "EMD Correspondence",
355
  "CHM Nearest Neighbors",
356
  "CHM Correspondence",
357
  ]
@@ -380,11 +406,11 @@ def main():
380
  st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
381
 
382
  if session_state.XAI_tool == "NOXAI":
383
- CLASSIFIER_TAG = "knn"
384
  selected_xai_tool = None
385
  elif session_state.XAI_tool == "KNN":
386
  selected_xai_tool = load_knn_nns
387
- CLASSIFIER_TAG = "knn"
388
  elif session_state.XAI_tool == "CHM Nearest Neighbors":
389
  selected_xai_tool = load_chm_nns
390
  CLASSIFIER_TAG = "CHM"
 
30
  folder_to_name = {}
31
  # class_descriptions = {}
32
  classifier_predictions = {}
33
+ selected_dataset = "CUB-iNAt-Unified"
34
 
35
  root_visualization_dir = "./visualizations/"
36
+ viz_url = "https://static.taesiri.com/xai/CUB-iNAt-Unified.zip"
37
  viz_archivefile = "CUB-Final.zip"
38
 
39
  demonstration_url = "https://static.taesiri.com/xai/CUB-Demonstrations.zip"
40
  demonst_zipfile = "demonstrations.zip"
41
 
42
+ picklefile_url = "https://static.taesiri.com/xai/Task1-CUB-ALL.pickle"
43
  prediction_root = "./predictions/"
44
  prediction_pickle = f"{prediction_root}predictions.pickle"
45
 
 
84
 
85
  def resmaple_queries():
86
  if session_state.first_run == 1:
87
+
88
+ # EMD_Corrent = [k for k, v in classifier_predictions.items() if v["EMD-Output"]]
89
+ # EMD_Wrong = [k for k, v in classifier_predictions.items() if not v["EMD-Output"]]
90
+ # KNN_Corrent = [k for k, v in classifier_predictions.items() if v["KNN-Output"]]
91
+ # KNN_Wrong = [k for k, v in classifier_predictions.items() if not v["KNN-Output"]]
92
+ # CHM_Corrent = [k for k, v in classifier_predictions.items() if v["CHM-Output"]]
93
+ # CHM_Wrong = [k for k, v in classifier_predictions.items() if not v["CHM-Output"]]
94
+
95
+ Corret_predictions_idx = [
96
+ k
97
+ for k, v in classifier_predictions.items()
98
+ if v[f"{CLASSIFIER_TAG}-Output"]
99
+ ]
100
+ Wrong_predictions_idx = [
101
+ k
102
+ for k, v in classifier_predictions.items()
103
+ if not v[f"{CLASSIFIER_TAG}-Output"]
104
+ ]
105
+
106
+ correct_classified_plots = [
107
+ f"{root_visualization_dir}{selected_dataset}/cub-inat-{x}.jpeg"
108
+ for x in Corret_predictions_idx
109
+ ]
110
+ wrong_classified_plots = [
111
+ f"{root_visualization_dir}{selected_dataset}/cub-inat-{x}.jpeg"
112
+ for x in Wrong_predictions_idx
113
+ ]
114
 
115
  correct_samples = list(
116
+ np.random.choice(
117
+ a=correct_classified_plots, size=NUMBER_OF_TRIALS // 2, replace=False
118
+ )
119
  )
120
  wrong_samples = list(
121
+ np.random.choice(
122
+ a=wrong_classified_plots, size=NUMBER_OF_TRIALS // 2, replace=False
123
+ )
124
  )
125
 
126
  all_images = correct_samples + wrong_samples
127
  random.shuffle(all_images)
128
+
129
  session_state.queries = all_images
130
  session_state.first_run = -1
131
  # RESET INTERACTIONS
 
135
 
136
  def render_experiment(query):
137
  current_query = session_state.queries[query]
138
+ query_id = int(os.path.basename(current_query).split("-")[2].split(".")[0])
139
 
140
  predicted_wnid = classifier_predictions[query_id][f"{CLASSIFIER_TAG}-predictions"]
141
  prediction_confidence = classifier_predictions[query_id][
 
376
  "Unselected",
377
  "NOXAI",
378
  "KNN",
379
+ "EMD Nearest Neighbors",
380
+ "EMD Correspondence",
381
  "CHM Nearest Neighbors",
382
  "CHM Correspondence",
383
  ]
 
406
  st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
407
 
408
  if session_state.XAI_tool == "NOXAI":
409
+ CLASSIFIER_TAG = "KNN"
410
  selected_xai_tool = None
411
  elif session_state.XAI_tool == "KNN":
412
  selected_xai_tool = load_knn_nns
413
+ CLASSIFIER_TAG = "KNN"
414
  elif session_state.XAI_tool == "CHM Nearest Neighbors":
415
  selected_xai_tool = load_chm_nns
416
  CLASSIFIER_TAG = "CHM"
image_utils.py CHANGED
@@ -26,7 +26,7 @@ def load_query(image_path):
26
 
27
  # Crop the center of the image
28
  cropped_image = image.crop(
29
- (left + 75, top + 145, right - 1790, bottom - (1140))
30
  ).resize((300, 300))
31
 
32
  return cropped_image
@@ -47,7 +47,7 @@ def load_chm_nns(image_path):
47
  bottom = (height + new_height) / 2
48
 
49
  # Crop the center of the image
50
- cropped_image = image.crop((left + 485, top + 145, right - 15, bottom - (1140)))
51
  return cropped_image
52
 
53
 
@@ -65,7 +65,7 @@ def load_chm_corrs(image_path):
65
  bottom = (height + new_height) / 2
66
 
67
  # Crop the center of the image
68
- cropped_image = image.crop((left + 485, top + 900, right - 15, bottom - (25 + 10)))
69
  return cropped_image
70
 
71
 
@@ -86,7 +86,7 @@ def load_knn_nns(image_path):
86
  bottom = (height + new_height) / 2
87
 
88
  # Crop the center of the image
89
- cropped_image = image.crop((left + 485, top + 525, right - 10, bottom - (770)))
90
  return cropped_image
91
 
92
 
@@ -107,9 +107,7 @@ def load_emd_nns(image_path):
107
  bottom = (height + new_height) / 2
108
 
109
  # Crop the center of the image
110
- cropped_image = image.crop(
111
- (left + 10, top + 2075, right - 420, bottom - (925 + 25 + 10))
112
- )
113
  return cropped_image
114
 
115
 
@@ -127,7 +125,7 @@ def load_emd_corrs(image_path):
127
  bottom = (height + new_height) / 2
128
 
129
  # Crop the center of the image
130
- cropped_image = image.crop((left + 10, top + 2500, right - 20, bottom))
131
  return cropped_image
132
 
133
 
 
26
 
27
  # Crop the center of the image
28
  cropped_image = image.crop(
29
+ (left + 5, top + 40, right - 2125, bottom - (2805))
30
  ).resize((300, 300))
31
 
32
  return cropped_image
 
47
  bottom = (height + new_height) / 2
48
 
49
  # Crop the center of the image
50
+ cropped_image = image.crop((left + 525, top + 2830, right - 0, bottom - (10)))
51
  return cropped_image
52
 
53
 
 
65
  bottom = (height + new_height) / 2
66
 
67
  # Crop the center of the image
68
+ cropped_image = image.crop((left + 15, top + 1835, right - 45, bottom - 445))
69
  return cropped_image
70
 
71
 
 
86
  bottom = (height + new_height) / 2
87
 
88
  # Crop the center of the image
89
+ cropped_image = image.crop((left + 525, top + 40, right - 10, bottom - (2805)))
90
  return cropped_image
91
 
92
 
 
107
  bottom = (height + new_height) / 2
108
 
109
  # Crop the center of the image
110
+ cropped_image = image.crop((left + 525, top + 480, right - 5, bottom - (2365)))
 
 
111
  return cropped_image
112
 
113
 
 
125
  bottom = (height + new_height) / 2
126
 
127
  # Crop the center of the image
128
+ cropped_image = image.crop((left + 90, top + 880, right - 75, bottom - 1438))
129
  return cropped_image
130
 
131