devjas1 commited on
Commit
9318b04
Β·
1 Parent(s): 345529d

FEAT(analyzer): Introduce centralized plot styling helper for theme-aware visualizations; enhance render_visual_diagnostics method with improved aesthetics and interactive filtering

Browse files
Files changed (1) hide show
  1. modules/analyzer.py +144 -102
modules/analyzer.py CHANGED
@@ -17,21 +17,21 @@ from modules.ui_components import create_spectrum_plot
17
  import hashlib
18
 
19
 
20
- # --- NEW HELPER FUNCTION for theme-aware plots ---
21
  @contextmanager
22
- def theme_aware_plot():
23
- """A context manager to make Matplotlib plots respect Streamlit's theme."""
24
- # Get the current theme from Streamlit's config with error handling
 
 
25
  try:
26
  theme_opts = st.get_option("theme") or {}
27
  except RuntimeError:
28
  # Fallback to empty dict if theme config is not available
29
  theme_opts = {}
30
-
31
  text_color = theme_opts.get("textColor", "#000000")
32
  bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
33
 
34
- # Set Matplotlib's rcParams to match the theme
35
  with plt.rc_context(
36
  {
37
  "figure.facecolor": bg_color,
@@ -42,12 +42,18 @@ def theme_aware_plot():
42
  "ytick.color": text_color,
43
  "grid.color": text_color,
44
  "axes.edgecolor": text_color,
 
 
45
  }
46
  ):
47
- yield
 
 
 
 
48
 
49
 
50
- # --- END HELPER FUNCTION ---
51
 
52
 
53
  class BatchAnalysis:
@@ -105,12 +111,10 @@ class BatchAnalysis:
105
  ),
106
  )
107
 
108
- # In modules/analyzer.py
109
-
110
  def render_visual_diagnostics(self):
111
  """
112
- Renders the main diagnostic plots with improved aesthetics, layout,
113
- and automatic theme adaptation.
114
  """
115
  st.markdown("##### Visual Analysis")
116
  if not self.has_ground_truth:
@@ -118,22 +122,18 @@ class BatchAnalysis:
118
  return
119
 
120
  valid_gt_df = self.df.dropna(subset=["Ground Truth"])
121
-
122
- # Use a single row of columns for the two main plots
123
  plot_col1, plot_col2 = st.columns(2)
124
 
125
- # --- Chart 1: Confusion Matrix ---
126
- with plot_col1: # Content for the first column
127
- with st.container(border=True): # Group plot and buttons visually
128
  st.markdown("**Confusion Matrix**")
129
  cm = confusion_matrix(
130
  valid_gt_df["Ground Truth"],
131
  valid_gt_df["Prediction"],
132
  labels=list(LABEL_MAP.keys()),
133
  )
134
-
135
- with theme_aware_plot(): # Apply theme-aware styling
136
- fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
137
  sns.heatmap(
138
  cm,
139
  annot=True,
@@ -145,58 +145,98 @@ class BatchAnalysis:
145
  )
146
  ax.set_ylabel("Actual Class", fontsize=12)
147
  ax.set_xlabel("Predicted Class", fontsize=12)
148
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
 
 
149
  ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
150
- st.pyplot(fig, use_container_width=True) # Render the plot
151
-
152
- st.caption("Click a cell below to filter the data grid:")
153
-
154
- # Render CM filter buttons directly below the plot in the same column
155
- cm_labels = list(LABEL_MAP.values())
156
- for i, actual_label in enumerate(cm_labels):
157
- btn_cols_row = st.columns(
158
- len(cm_labels)
159
- ) # Create a row of columns for buttons
160
- for j, predicted_label in enumerate(cm_labels):
161
- cell_value = cm[i, j]
162
- btn_cols_row[j].button( # Button for each cell
163
- f"Actual: {actual_label}\nPred: {predicted_label} ({cell_value})",
164
- key=f"cm_cell_{i}_{j}",
165
- on_click=self._set_cm_filter,
166
- args=(i, j, actual_label, predicted_label),
167
- use_container_width=True,
168
- )
169
- # Clear filter button for CM
170
- if st.session_state.get("cm_filter_active", False):
171
- st.button(
172
- "Clear Matrix Filter",
173
- on_click=self._clear_cm_filter,
174
- key="clear_cm_filter_btn_below",
175
- )
176
 
177
- # --- Chart 2: Confidence vs. Correctness Box Plot ---
178
- with plot_col2: # Content for the second column
179
- with st.container(border=True): # Group plot visually
180
  st.markdown("**Confidence Analysis**")
181
  valid_gt_df["Result"] = np.where(
182
  valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
183
- "Correct",
184
- "Incorrect",
185
  )
186
-
187
- with theme_aware_plot(): # Apply theme-aware styling
188
- fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
189
  sns.boxplot(
190
  x="Result",
191
  y="Confidence",
192
  data=valid_gt_df,
193
  ax=ax,
194
- palette={"Correct": "#64C764", "Incorrect": "#E57373"},
195
  )
196
  ax.set_ylabel("Model Confidence", fontsize=12)
197
- ax.set_xlabel("Prediction Result", fontsize=12)
 
198
  st.pyplot(fig, use_container_width=True)
199
- st.divider() # Divider after the entire visual section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  def _set_cm_filter(
202
  self,
@@ -235,56 +275,58 @@ class BatchAnalysis:
235
  # Start with a full copy of the dataframe to apply filters to
236
  filtered_df = self.df.copy()
237
 
238
- # --- Filter Section ---
239
- st.markdown("**Filters**")
240
- filter_cols = st.columns([2, 2, 3]) # Allocate more space for the slider
241
-
242
- # Filter 1: By Predicted Class
243
- selected_classes = filter_cols[0].multiselect(
244
- "Filter by Prediction:",
245
- options=self.df["Predicted Class"].unique(),
246
- default=self.df["Predicted Class"].unique(),
247
- )
248
- filtered_df = filtered_df[filtered_df["Predicted Class"].isin(selected_classes)]
249
-
250
- # Filter 2: By Ground Truth Correctness (if available)
251
- if self.has_ground_truth:
252
- filtered_df["Correct"] = (
253
- filtered_df["Prediction"] == filtered_df["Ground Truth"]
254
  )
255
- correctness_options = ["βœ… Correct", "❌ Incorrect"]
 
 
256
 
257
- # Create a temporary column for display in multiselect
258
- filtered_df["Result_Display"] = np.where(
259
- filtered_df["Correct"], "βœ… Correct", "❌ Incorrect"
260
- )
 
 
 
 
 
261
 
262
- selected_correctness = filter_cols[1].multiselect(
263
- "Filter by Result:",
264
- options=correctness_options,
265
- default=correctness_options,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  )
267
- # Filter based on the boolean 'Correct' column
268
- filter_correctness_bools = [
269
- True if c == "βœ… Correct" else False for c in selected_correctness
270
- ]
271
  filtered_df = filtered_df[
272
- filtered_df["Correct"].isin(filter_correctness_bools)
 
273
  ]
274
-
275
- # --- NEW: Filter 3: By Confidence Range ---
276
- min_conf, max_conf = filter_cols[2].slider(
277
- "Filter by Confidence Range:",
278
- min_value=0.0,
279
- max_value=1.0,
280
- value=(0.0, 1.0), # Default to the full range
281
- step=0.01,
282
- )
283
- filtered_df = filtered_df[
284
- (filtered_df["Confidence"] >= min_conf)
285
- & (filtered_df["Confidence"] <= max_conf)
286
- ]
287
- # --- END NEW FILTER ---
288
 
289
  # Apply Confusion Matrix Drill-Down Filter (if active)
290
  if st.session_state.get("cm_filter_active", False):
 
17
  import hashlib
18
 
19
 
20
+ # --- NEW: Centralized plot styling helper ---
21
  @contextmanager
22
+ def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
23
+ """
24
+ A context manager to apply consistent Matplotlib styling and
25
+ make plots theme-aware.
26
+ """
27
  try:
28
  theme_opts = st.get_option("theme") or {}
29
  except RuntimeError:
30
  # Fallback to empty dict if theme config is not available
31
  theme_opts = {}
 
32
  text_color = theme_opts.get("textColor", "#000000")
33
  bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
34
 
 
35
  with plt.rc_context(
36
  {
37
  "figure.facecolor": bg_color,
 
42
  "ytick.color": text_color,
43
  "grid.color": text_color,
44
  "axes.edgecolor": text_color,
45
+ "axes.titlecolor": text_color, # Ensure title color matches
46
+ "figure.autolayout": True, # Auto-adjusts subplot params for a tight layout
47
  }
48
  ):
49
+ fig, ax = plt.subplots(
50
+ figsize=figsize, constrained_layout=constrained_layout, **kwargs
51
+ )
52
+ yield fig, ax
53
+ plt.close(fig) # Always close figure to prevent memory leaks
54
 
55
 
56
+ # --- END NEW HELPER ---
57
 
58
 
59
  class BatchAnalysis:
 
111
  ),
112
  )
113
 
 
 
114
  def render_visual_diagnostics(self):
115
  """
116
+ Renders diagnostic plots with corrected aesthetics and a robust,
117
+ interactive drill-down filter using st.selectbox.
118
  """
119
  st.markdown("##### Visual Analysis")
120
  if not self.has_ground_truth:
 
122
  return
123
 
124
  valid_gt_df = self.df.dropna(subset=["Ground Truth"])
 
 
125
  plot_col1, plot_col2 = st.columns(2)
126
 
127
+ # --- Chart 1: Confusion Matrix (Aesthetically Corrected) ---
128
+ with plot_col1:
129
+ with st.container(border=True):
130
  st.markdown("**Confusion Matrix**")
131
  cm = confusion_matrix(
132
  valid_gt_df["Ground Truth"],
133
  valid_gt_df["Prediction"],
134
  labels=list(LABEL_MAP.keys()),
135
  )
136
+ with plot_style_context() as (fig, ax):
 
 
137
  sns.heatmap(
138
  cm,
139
  annot=True,
 
145
  )
146
  ax.set_ylabel("Actual Class", fontsize=12)
147
  ax.set_xlabel("Predicted Class", fontsize=12)
148
+
149
+ # --- AESTHETIC FIX: Rotate X-labels vertically for a compact look ---
150
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
151
  ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
152
+ ax.set_title("Prediction vs. Actual (Counts)", fontsize=14)
153
+ st.pyplot(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # --- Chart 2: Confidence vs. Correctness Box Plot (Unchanged) ---
156
+ with plot_col2:
157
+ with st.container(border=True):
158
  st.markdown("**Confidence Analysis**")
159
  valid_gt_df["Result"] = np.where(
160
  valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
161
+ "βœ… Correct",
162
+ "❌ Incorrect",
163
  )
164
+ with plot_style_context() as (fig, ax):
 
 
165
  sns.boxplot(
166
  x="Result",
167
  y="Confidence",
168
  data=valid_gt_df,
169
  ax=ax,
170
+ palette={"βœ… Correct": "lightgreen", "❌ Incorrect": "salmon"},
171
  )
172
  ax.set_ylabel("Model Confidence", fontsize=12)
173
+ ax.set_xlabel("Prediction Outcome", fontsize=12)
174
+ ax.set_title("Confidence Distribution by Outcome", fontsize=14)
175
  st.pyplot(fig, use_container_width=True)
176
+
177
+ st.divider()
178
+
179
+ # --- FUNCTIONALITY FIX: Replace Button Grid with st.selectbox ---
180
+ st.markdown("###### Interactive Confusion Matrix Drill-Down")
181
+ st.caption(
182
+ "Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab."
183
+ )
184
+
185
+ # Create a list of options for the selectbox from the confusion matrix
186
+ cm = confusion_matrix(
187
+ valid_gt_df["Ground Truth"],
188
+ valid_gt_df["Prediction"],
189
+ labels=list(LABEL_MAP.keys()),
190
+ )
191
+ cm_labels = list(LABEL_MAP.values())
192
+ options = ["-- Select a cell to filter --"]
193
+
194
+ # This nested loop creates the human-readable options for the dropdown
195
+ for i, actual_label in enumerate(cm_labels):
196
+ for j, predicted_label in enumerate(cm_labels):
197
+ cell_value = cm[i, j]
198
+ # We only add cells with content to the dropdown to avoid clutter
199
+ if cell_value > 0:
200
+ option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)"
201
+ options.append(option_str)
202
+
203
+ # The selectbox widget, which is more robust for state management
204
+ selected_option = st.selectbox(
205
+ "Drill-Down Filter",
206
+ options=options,
207
+ key="cm_selectbox", # Give it a key to track its state
208
+ index=0, # Default to the placeholder
209
+ )
210
+
211
+ # Logic to activate or deactivate the filter based on selection
212
+ if selected_option != "-- Select a cell to filter --":
213
+ # Parse the selection to get the actual and predicted classes
214
+ parts = selected_option.split("|")
215
+ actual_str = parts[0].replace("Actual: ", "").strip()
216
+ # FIX: Split on " (" to get the full label without the file count
217
+ predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip()
218
+
219
+ # Find the corresponding numeric indices with error handling
220
+ actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str]
221
+ predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str]
222
+
223
+ if not actual_matching or not predicted_matching:
224
+ return
225
+
226
+ actual_idx = actual_matching[0]
227
+ predicted_idx = predicted_matching[0]
228
+
229
+ # Use a simplified callback-like update to session state
230
+ st.session_state["cm_actual_filter"] = actual_idx
231
+ st.session_state["cm_predicted_filter"] = predicted_idx
232
+ st.session_state["cm_filter_label"] = (
233
+ f"Actual: {actual_str}, Predicted: {predicted_str}"
234
+ )
235
+ st.session_state["cm_filter_active"] = True
236
+ else:
237
+ # If the user selects the placeholder, deactivate the filter
238
+ if st.session_state.get("cm_filter_active", False):
239
+ self._clear_cm_filter()
240
 
241
  def _set_cm_filter(
242
  self,
 
275
  # Start with a full copy of the dataframe to apply filters to
276
  filtered_df = self.df.copy()
277
 
278
+ # --- Filter Section (STREAMLINED LAYOUT) ---
279
+ with st.container(border=True):
280
+ st.markdown("**Filters**")
281
+ filter_row1 = st.columns([1, 1])
282
+ filter_row2 = st.columns(1) # Slider takes full width
283
+
284
+ # Filter 1: By Predicted Class
285
+ selected_classes = filter_row1[0].multiselect(
286
+ "Filter by Prediction:",
287
+ options=self.df["Predicted Class"].unique(),
288
+ default=self.df["Predicted Class"].unique(),
 
 
 
 
 
289
  )
290
+ filtered_df = filtered_df[
291
+ filtered_df["Predicted Class"].isin(selected_classes)
292
+ ]
293
 
294
+ # Filter 2: By Ground Truth Correctness (if available)
295
+ if self.has_ground_truth:
296
+ filtered_df["Correct"] = (
297
+ filtered_df["Prediction"] == filtered_df["Ground Truth"]
298
+ )
299
+ correctness_options = ["βœ… Correct", "❌ Incorrect"]
300
+ filtered_df["Result_Display"] = np.where(
301
+ filtered_df["Correct"], "βœ… Correct", "❌ Incorrect"
302
+ )
303
 
304
+ selected_correctness = filter_row1[1].multiselect(
305
+ "Filter by Result:",
306
+ options=correctness_options,
307
+ default=correctness_options,
308
+ )
309
+ filter_correctness_bools = [
310
+ True if c == "βœ… Correct" else False for c in selected_correctness
311
+ ]
312
+ filtered_df = filtered_df[
313
+ filtered_df["Correct"].isin(filter_correctness_bools)
314
+ ]
315
+
316
+ # Filter 3: By Confidence Range (full width below others)
317
+ min_conf, max_conf = filter_row2[0].slider(
318
+ "Filter by Confidence Range:",
319
+ min_value=0.0,
320
+ max_value=1.0,
321
+ value=(0.0, 1.0),
322
+ step=0.01,
323
+ format="%.2f", # Format slider display for clarity
324
  )
 
 
 
 
325
  filtered_df = filtered_df[
326
+ (filtered_df["Confidence"] >= min_conf)
327
+ & (filtered_df["Confidence"] <= max_conf)
328
  ]
329
+ # --- END FILTER SECTION ---
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  # Apply Confusion Matrix Drill-Down Filter (if active)
332
  if st.session_state.get("cm_filter_active", False):