Spaces:
Sleeping
Sleeping
devjas1
committed on
Commit
·
9318b04
1
Parent(s):
345529d
FEAT(analyzer): Introduce centralized plot styling helper for theme-aware visualizations; enhance render_visual_diagnostics method with improved aesthetics and interactive filtering
Browse files- modules/analyzer.py +144 -102
modules/analyzer.py
CHANGED
@@ -17,21 +17,21 @@ from modules.ui_components import create_spectrum_plot
|
|
17 |
import hashlib
|
18 |
|
19 |
|
20 |
-
# --- NEW
|
21 |
@contextmanager
|
22 |
-
def
|
23 |
-
"""
|
24 |
-
|
|
|
|
|
25 |
try:
|
26 |
theme_opts = st.get_option("theme") or {}
|
27 |
except RuntimeError:
|
28 |
# Fallback to empty dict if theme config is not available
|
29 |
theme_opts = {}
|
30 |
-
|
31 |
text_color = theme_opts.get("textColor", "#000000")
|
32 |
bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
|
33 |
|
34 |
-
# Set Matplotlib's rcParams to match the theme
|
35 |
with plt.rc_context(
|
36 |
{
|
37 |
"figure.facecolor": bg_color,
|
@@ -42,12 +42,18 @@ def theme_aware_plot():
|
|
42 |
"ytick.color": text_color,
|
43 |
"grid.color": text_color,
|
44 |
"axes.edgecolor": text_color,
|
|
|
|
|
45 |
}
|
46 |
):
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
-
# --- END HELPER
|
51 |
|
52 |
|
53 |
class BatchAnalysis:
|
@@ -105,12 +111,10 @@ class BatchAnalysis:
|
|
105 |
),
|
106 |
)
|
107 |
|
108 |
-
# In modules/analyzer.py
|
109 |
-
|
110 |
def render_visual_diagnostics(self):
|
111 |
"""
|
112 |
-
Renders
|
113 |
-
|
114 |
"""
|
115 |
st.markdown("##### Visual Analysis")
|
116 |
if not self.has_ground_truth:
|
@@ -118,22 +122,18 @@ class BatchAnalysis:
|
|
118 |
return
|
119 |
|
120 |
valid_gt_df = self.df.dropna(subset=["Ground Truth"])
|
121 |
-
|
122 |
-
# Use a single row of columns for the two main plots
|
123 |
plot_col1, plot_col2 = st.columns(2)
|
124 |
|
125 |
-
# --- Chart 1: Confusion Matrix ---
|
126 |
-
with plot_col1:
|
127 |
-
with st.container(border=True):
|
128 |
st.markdown("**Confusion Matrix**")
|
129 |
cm = confusion_matrix(
|
130 |
valid_gt_df["Ground Truth"],
|
131 |
valid_gt_df["Prediction"],
|
132 |
labels=list(LABEL_MAP.keys()),
|
133 |
)
|
134 |
-
|
135 |
-
with theme_aware_plot(): # Apply theme-aware styling
|
136 |
-
fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
|
137 |
sns.heatmap(
|
138 |
cm,
|
139 |
annot=True,
|
@@ -145,58 +145,98 @@ class BatchAnalysis:
|
|
145 |
)
|
146 |
ax.set_ylabel("Actual Class", fontsize=12)
|
147 |
ax.set_xlabel("Predicted Class", fontsize=12)
|
148 |
-
|
|
|
|
|
149 |
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
|
150 |
-
|
151 |
-
|
152 |
-
st.caption("Click a cell below to filter the data grid:")
|
153 |
-
|
154 |
-
# Render CM filter buttons directly below the plot in the same column
|
155 |
-
cm_labels = list(LABEL_MAP.values())
|
156 |
-
for i, actual_label in enumerate(cm_labels):
|
157 |
-
btn_cols_row = st.columns(
|
158 |
-
len(cm_labels)
|
159 |
-
) # Create a row of columns for buttons
|
160 |
-
for j, predicted_label in enumerate(cm_labels):
|
161 |
-
cell_value = cm[i, j]
|
162 |
-
btn_cols_row[j].button( # Button for each cell
|
163 |
-
f"Actual: {actual_label}\nPred: {predicted_label} ({cell_value})",
|
164 |
-
key=f"cm_cell_{i}_{j}",
|
165 |
-
on_click=self._set_cm_filter,
|
166 |
-
args=(i, j, actual_label, predicted_label),
|
167 |
-
use_container_width=True,
|
168 |
-
)
|
169 |
-
# Clear filter button for CM
|
170 |
-
if st.session_state.get("cm_filter_active", False):
|
171 |
-
st.button(
|
172 |
-
"Clear Matrix Filter",
|
173 |
-
on_click=self._clear_cm_filter,
|
174 |
-
key="clear_cm_filter_btn_below",
|
175 |
-
)
|
176 |
|
177 |
-
# --- Chart 2: Confidence vs. Correctness Box Plot ---
|
178 |
-
with plot_col2:
|
179 |
-
with st.container(border=True):
|
180 |
st.markdown("**Confidence Analysis**")
|
181 |
valid_gt_df["Result"] = np.where(
|
182 |
valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
|
183 |
-
"Correct",
|
184 |
-
"Incorrect",
|
185 |
)
|
186 |
-
|
187 |
-
with theme_aware_plot(): # Apply theme-aware styling
|
188 |
-
fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
|
189 |
sns.boxplot(
|
190 |
x="Result",
|
191 |
y="Confidence",
|
192 |
data=valid_gt_df,
|
193 |
ax=ax,
|
194 |
-
palette={"Correct": "
|
195 |
)
|
196 |
ax.set_ylabel("Model Confidence", fontsize=12)
|
197 |
-
ax.set_xlabel("Prediction
|
|
|
198 |
st.pyplot(fig, use_container_width=True)
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
def _set_cm_filter(
|
202 |
self,
|
@@ -235,56 +275,58 @@ class BatchAnalysis:
|
|
235 |
# Start with a full copy of the dataframe to apply filters to
|
236 |
filtered_df = self.df.copy()
|
237 |
|
238 |
-
# --- Filter Section ---
|
239 |
-
st.
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
# Filter 2: By Ground Truth Correctness (if available)
|
251 |
-
if self.has_ground_truth:
|
252 |
-
filtered_df["Correct"] = (
|
253 |
-
filtered_df["Prediction"] == filtered_df["Ground Truth"]
|
254 |
)
|
255 |
-
|
|
|
|
|
256 |
|
257 |
-
#
|
258 |
-
|
259 |
-
filtered_df["Correct"]
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
)
|
267 |
-
# Filter based on the boolean 'Correct' column
|
268 |
-
filter_correctness_bools = [
|
269 |
-
True if c == "✅ Correct" else False for c in selected_correctness
|
270 |
-
]
|
271 |
filtered_df = filtered_df[
|
272 |
-
filtered_df["
|
|
|
273 |
]
|
274 |
-
|
275 |
-
# --- NEW: Filter 3: By Confidence Range ---
|
276 |
-
min_conf, max_conf = filter_cols[2].slider(
|
277 |
-
"Filter by Confidence Range:",
|
278 |
-
min_value=0.0,
|
279 |
-
max_value=1.0,
|
280 |
-
value=(0.0, 1.0), # Default to the full range
|
281 |
-
step=0.01,
|
282 |
-
)
|
283 |
-
filtered_df = filtered_df[
|
284 |
-
(filtered_df["Confidence"] >= min_conf)
|
285 |
-
& (filtered_df["Confidence"] <= max_conf)
|
286 |
-
]
|
287 |
-
# --- END NEW FILTER ---
|
288 |
|
289 |
# Apply Confusion Matrix Drill-Down Filter (if active)
|
290 |
if st.session_state.get("cm_filter_active", False):
|
|
|
17 |
import hashlib
|
18 |
|
19 |
|
20 |
+
# --- NEW: Centralized plot styling helper ---
|
21 |
@contextmanager
|
22 |
+
def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
|
23 |
+
"""
|
24 |
+
A context manager to apply consistent Matplotlib styling and
|
25 |
+
make plots theme-aware.
|
26 |
+
"""
|
27 |
try:
|
28 |
theme_opts = st.get_option("theme") or {}
|
29 |
except RuntimeError:
|
30 |
# Fallback to empty dict if theme config is not available
|
31 |
theme_opts = {}
|
|
|
32 |
text_color = theme_opts.get("textColor", "#000000")
|
33 |
bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
|
34 |
|
|
|
35 |
with plt.rc_context(
|
36 |
{
|
37 |
"figure.facecolor": bg_color,
|
|
|
42 |
"ytick.color": text_color,
|
43 |
"grid.color": text_color,
|
44 |
"axes.edgecolor": text_color,
|
45 |
+
"axes.titlecolor": text_color, # Ensure title color matches
|
46 |
+
"figure.autolayout": True, # Auto-adjusts subplot params for a tight layout
|
47 |
}
|
48 |
):
|
49 |
+
fig, ax = plt.subplots(
|
50 |
+
figsize=figsize, constrained_layout=constrained_layout, **kwargs
|
51 |
+
)
|
52 |
+
yield fig, ax
|
53 |
+
plt.close(fig) # Always close figure to prevent memory leaks
|
54 |
|
55 |
|
56 |
+
# --- END NEW HELPER ---
|
57 |
|
58 |
|
59 |
class BatchAnalysis:
|
|
|
111 |
),
|
112 |
)
|
113 |
|
|
|
|
|
114 |
def render_visual_diagnostics(self):
|
115 |
"""
|
116 |
+
Renders diagnostic plots with corrected aesthetics and a robust,
|
117 |
+
interactive drill-down filter using st.selectbox.
|
118 |
"""
|
119 |
st.markdown("##### Visual Analysis")
|
120 |
if not self.has_ground_truth:
|
|
|
122 |
return
|
123 |
|
124 |
valid_gt_df = self.df.dropna(subset=["Ground Truth"])
|
|
|
|
|
125 |
plot_col1, plot_col2 = st.columns(2)
|
126 |
|
127 |
+
# --- Chart 1: Confusion Matrix (Aesthetically Corrected) ---
|
128 |
+
with plot_col1:
|
129 |
+
with st.container(border=True):
|
130 |
st.markdown("**Confusion Matrix**")
|
131 |
cm = confusion_matrix(
|
132 |
valid_gt_df["Ground Truth"],
|
133 |
valid_gt_df["Prediction"],
|
134 |
labels=list(LABEL_MAP.keys()),
|
135 |
)
|
136 |
+
with plot_style_context() as (fig, ax):
|
|
|
|
|
137 |
sns.heatmap(
|
138 |
cm,
|
139 |
annot=True,
|
|
|
145 |
)
|
146 |
ax.set_ylabel("Actual Class", fontsize=12)
|
147 |
ax.set_xlabel("Predicted Class", fontsize=12)
|
148 |
+
|
149 |
+
# --- AESTHETIC FIX: Rotate X-labels vertically for a compact look ---
|
150 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
151 |
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
|
152 |
+
ax.set_title("Prediction vs. Actual (Counts)", fontsize=14)
|
153 |
+
st.pyplot(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
+
# --- Chart 2: Confidence vs. Correctness Box Plot (Unchanged) ---
|
156 |
+
with plot_col2:
|
157 |
+
with st.container(border=True):
|
158 |
st.markdown("**Confidence Analysis**")
|
159 |
valid_gt_df["Result"] = np.where(
|
160 |
valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
|
161 |
+
"✅ Correct",
|
162 |
+
"❌ Incorrect",
|
163 |
)
|
164 |
+
with plot_style_context() as (fig, ax):
|
|
|
|
|
165 |
sns.boxplot(
|
166 |
x="Result",
|
167 |
y="Confidence",
|
168 |
data=valid_gt_df,
|
169 |
ax=ax,
|
170 |
+
palette={"✅ Correct": "lightgreen", "❌ Incorrect": "salmon"},
|
171 |
)
|
172 |
ax.set_ylabel("Model Confidence", fontsize=12)
|
173 |
+
ax.set_xlabel("Prediction Outcome", fontsize=12)
|
174 |
+
ax.set_title("Confidence Distribution by Outcome", fontsize=14)
|
175 |
st.pyplot(fig, use_container_width=True)
|
176 |
+
|
177 |
+
st.divider()
|
178 |
+
|
179 |
+
# --- FUNCTIONALITY FIX: Replace Button Grid with st.selectbox ---
|
180 |
+
st.markdown("###### Interactive Confusion Matrix Drill-Down")
|
181 |
+
st.caption(
|
182 |
+
"Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab."
|
183 |
+
)
|
184 |
+
|
185 |
+
# Create a list of options for the selectbox from the confusion matrix
|
186 |
+
cm = confusion_matrix(
|
187 |
+
valid_gt_df["Ground Truth"],
|
188 |
+
valid_gt_df["Prediction"],
|
189 |
+
labels=list(LABEL_MAP.keys()),
|
190 |
+
)
|
191 |
+
cm_labels = list(LABEL_MAP.values())
|
192 |
+
options = ["-- Select a cell to filter --"]
|
193 |
+
|
194 |
+
# This nested loop creates the human-readable options for the dropdown
|
195 |
+
for i, actual_label in enumerate(cm_labels):
|
196 |
+
for j, predicted_label in enumerate(cm_labels):
|
197 |
+
cell_value = cm[i, j]
|
198 |
+
# We only add cells with content to the dropdown to avoid clutter
|
199 |
+
if cell_value > 0:
|
200 |
+
option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)"
|
201 |
+
options.append(option_str)
|
202 |
+
|
203 |
+
# The selectbox widget, which is more robust for state management
|
204 |
+
selected_option = st.selectbox(
|
205 |
+
"Drill-Down Filter",
|
206 |
+
options=options,
|
207 |
+
key="cm_selectbox", # Give it a key to track its state
|
208 |
+
index=0, # Default to the placeholder
|
209 |
+
)
|
210 |
+
|
211 |
+
# Logic to activate or deactivate the filter based on selection
|
212 |
+
if selected_option != "-- Select a cell to filter --":
|
213 |
+
# Parse the selection to get the actual and predicted classes
|
214 |
+
parts = selected_option.split("|")
|
215 |
+
actual_str = parts[0].replace("Actual: ", "").strip()
|
216 |
+
# FIX: Split on " (" to get the full label without the file count
|
217 |
+
predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip()
|
218 |
+
|
219 |
+
# Find the corresponding numeric indices with error handling
|
220 |
+
actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str]
|
221 |
+
predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str]
|
222 |
+
|
223 |
+
if not actual_matching or not predicted_matching:
|
224 |
+
return
|
225 |
+
|
226 |
+
actual_idx = actual_matching[0]
|
227 |
+
predicted_idx = predicted_matching[0]
|
228 |
+
|
229 |
+
# Use a simplified callback-like update to session state
|
230 |
+
st.session_state["cm_actual_filter"] = actual_idx
|
231 |
+
st.session_state["cm_predicted_filter"] = predicted_idx
|
232 |
+
st.session_state["cm_filter_label"] = (
|
233 |
+
f"Actual: {actual_str}, Predicted: {predicted_str}"
|
234 |
+
)
|
235 |
+
st.session_state["cm_filter_active"] = True
|
236 |
+
else:
|
237 |
+
# If the user selects the placeholder, deactivate the filter
|
238 |
+
if st.session_state.get("cm_filter_active", False):
|
239 |
+
self._clear_cm_filter()
|
240 |
|
241 |
def _set_cm_filter(
|
242 |
self,
|
|
|
275 |
# Start with a full copy of the dataframe to apply filters to
|
276 |
filtered_df = self.df.copy()
|
277 |
|
278 |
+
# --- Filter Section (STREAMLINED LAYOUT) ---
|
279 |
+
with st.container(border=True):
|
280 |
+
st.markdown("**Filters**")
|
281 |
+
filter_row1 = st.columns([1, 1])
|
282 |
+
filter_row2 = st.columns(1) # Slider takes full width
|
283 |
+
|
284 |
+
# Filter 1: By Predicted Class
|
285 |
+
selected_classes = filter_row1[0].multiselect(
|
286 |
+
"Filter by Prediction:",
|
287 |
+
options=self.df["Predicted Class"].unique(),
|
288 |
+
default=self.df["Predicted Class"].unique(),
|
|
|
|
|
|
|
|
|
|
|
289 |
)
|
290 |
+
filtered_df = filtered_df[
|
291 |
+
filtered_df["Predicted Class"].isin(selected_classes)
|
292 |
+
]
|
293 |
|
294 |
+
# Filter 2: By Ground Truth Correctness (if available)
|
295 |
+
if self.has_ground_truth:
|
296 |
+
filtered_df["Correct"] = (
|
297 |
+
filtered_df["Prediction"] == filtered_df["Ground Truth"]
|
298 |
+
)
|
299 |
+
correctness_options = ["✅ Correct", "❌ Incorrect"]
|
300 |
+
filtered_df["Result_Display"] = np.where(
|
301 |
+
filtered_df["Correct"], "✅ Correct", "❌ Incorrect"
|
302 |
+
)
|
303 |
|
304 |
+
selected_correctness = filter_row1[1].multiselect(
|
305 |
+
"Filter by Result:",
|
306 |
+
options=correctness_options,
|
307 |
+
default=correctness_options,
|
308 |
+
)
|
309 |
+
filter_correctness_bools = [
|
310 |
+
True if c == "✅ Correct" else False for c in selected_correctness
|
311 |
+
]
|
312 |
+
filtered_df = filtered_df[
|
313 |
+
filtered_df["Correct"].isin(filter_correctness_bools)
|
314 |
+
]
|
315 |
+
|
316 |
+
# Filter 3: By Confidence Range (full width below others)
|
317 |
+
min_conf, max_conf = filter_row2[0].slider(
|
318 |
+
"Filter by Confidence Range:",
|
319 |
+
min_value=0.0,
|
320 |
+
max_value=1.0,
|
321 |
+
value=(0.0, 1.0),
|
322 |
+
step=0.01,
|
323 |
+
format="%.2f", # Format slider display for clarity
|
324 |
)
|
|
|
|
|
|
|
|
|
325 |
filtered_df = filtered_df[
|
326 |
+
(filtered_df["Confidence"] >= min_conf)
|
327 |
+
& (filtered_df["Confidence"] <= max_conf)
|
328 |
]
|
329 |
+
# --- END FILTER SECTION ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
# Apply Confusion Matrix Drill-Down Filter (if active)
|
332 |
if st.session_state.get("cm_filter_active", False):
|