FEAT(analyzer): Introduce centralized plot styling helper for theme-aware visualizations; enhance render_visual_diagnostics method with improved aesthetics and interactive filtering
9318b04
# In modules/analyzer.py
import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from datetime import datetime
from contextlib import contextmanager  # Decorates the plot_style_context helper below
from config import LABEL_MAP  # Assumes LABEL_MAP (class index -> display name) is defined in config.py

# --- Results persistence and spectrum plotting helpers ---
from utils.results_manager import ResultsManager
from modules.ui_components import create_spectrum_plot
import hashlib


# --- NEW: Centralized plot styling helper ---
@contextmanager
def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
    """
    A context manager that applies consistent Matplotlib styling and
    makes plots theme-aware.
    """
    # Query the individual theme options; "theme" by itself is a config section,
    # not a single option, and the values may be unset.
    try:
        text_color = st.get_option("theme.textColor") or "#000000"
        bg_color = st.get_option("theme.backgroundColor") or "#FFFFFF"
    except RuntimeError:
        # Fall back to black-on-white if theme config is not available
        text_color, bg_color = "#000000", "#FFFFFF"
    with plt.rc_context(
        {
            "figure.facecolor": bg_color,
            "axes.facecolor": bg_color,
            "text.color": text_color,
            "axes.labelcolor": text_color,
            "xtick.color": text_color,
            "ytick.color": text_color,
            "grid.color": text_color,
            "axes.edgecolor": text_color,
            "axes.titlecolor": text_color,  # Ensure title color matches
            "figure.autolayout": True,  # Auto-adjusts subplot params for a tight layout
        }
    ):
        fig, ax = plt.subplots(
            figsize=figsize, constrained_layout=constrained_layout, **kwargs
        )
        try:
            yield fig, ax
        finally:
            plt.close(fig)  # Always close the figure to prevent memory leaks
# --- END NEW HELPER ---
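
# Usage sketch for the helper above (illustrative only; `some_df` is a
# hypothetical DataFrame, not part of this module):
#
#     with plot_style_context(figsize=(6, 3)) as (fig, ax):
#         ax.plot(some_df["x"], some_df["y"])
#         st.pyplot(fig, use_container_width=True)
#
# The figure is closed automatically when the block exits.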


class BatchAnalysis:
    def __init__(self, df: pd.DataFrame):
        """Initializes the analysis object with the results DataFrame."""
        self.df = df
        if self.df.empty:
            return
        self.total_files = len(self.df)
        self.has_ground_truth = (
            "Ground Truth" in self.df.columns
            and not self.df["Ground Truth"].isnull().all()
        )
        self._prepare_data()
        self.kpis = self._calculate_kpis()

    def _prepare_data(self):
        """Ensures data types are correct for analysis."""
        self.df["Confidence"] = pd.to_numeric(self.df["Confidence"], errors="coerce")
        if self.has_ground_truth:
            self.df["Ground Truth"] = pd.to_numeric(
                self.df["Ground Truth"], errors="coerce"
            )

    def _calculate_kpis(self) -> dict:
        """A private method to compute all the key performance indicators."""
        stable_count = self.df[
            self.df["Predicted Class"] == "Stable (Unweathered)"
        ].shape[0]
        accuracy = "N/A"
        if self.has_ground_truth:
            valid_gt = self.df.dropna(subset=["Ground Truth", "Prediction"])
            accuracy = (valid_gt["Prediction"] == valid_gt["Ground Truth"]).mean()
        return {
            "Total Files": self.total_files,
            "Avg. Confidence": self.df["Confidence"].mean(),
            "Stable/Weathered": f"{stable_count}/{self.total_files - stable_count}",
            "Accuracy": accuracy,
        }
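
    # Illustrative shape of the dict returned by _calculate_kpis (values are
    # hypothetical; "Accuracy" is the string "N/A" when no ground truth exists):
    #
    #     {
    #         "Total Files": 42,
    #         "Avg. Confidence": 0.87,
    #         "Stable/Weathered": "30/12",
    #         "Accuracy": 0.93,
    #     }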

    def render_kpis(self):
        """Renders the top-level KPI metrics."""
        kpi_cols = st.columns(4)
        kpi_cols[0].metric("Total Files", f"{self.kpis['Total Files']}")
        kpi_cols[1].metric("Avg. Confidence", f"{self.kpis['Avg. Confidence']:.3f}")
        kpi_cols[2].metric("Stable/Weathered", self.kpis["Stable/Weathered"])
        kpi_cols[3].metric(
            "Accuracy",
            (
                f"{self.kpis['Accuracy']:.3f}"
                if isinstance(self.kpis["Accuracy"], float)
                else "N/A"
            ),
        )

    def render_visual_diagnostics(self):
        """
        Renders diagnostic plots with theme-aware styling and a robust,
        interactive drill-down filter using st.selectbox.
        """
        st.markdown("##### Visual Analysis")
        if not self.has_ground_truth:
            st.info("Visual analysis requires Ground Truth data for this batch.")
            return
        valid_gt_df = self.df.dropna(subset=["Ground Truth"])
        plot_col1, plot_col2 = st.columns(2)
        # --- Chart 1: Confusion Matrix ---
        with plot_col1:
            with st.container(border=True):
                st.markdown("**Confusion Matrix**")
                cm = confusion_matrix(
                    valid_gt_df["Ground Truth"],
                    valid_gt_df["Prediction"],
                    labels=list(LABEL_MAP.keys()),
                )
                with plot_style_context() as (fig, ax):
                    sns.heatmap(
                        cm,
                        annot=True,
                        fmt="g",
                        ax=ax,
                        cmap="Blues",
                        xticklabels=list(LABEL_MAP.values()),
                        yticklabels=list(LABEL_MAP.values()),
                    )
                    ax.set_ylabel("Actual Class", fontsize=12)
                    ax.set_xlabel("Predicted Class", fontsize=12)
                    # Rotate X-labels vertically for a compact look
                    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
                    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
                    ax.set_title("Prediction vs. Actual (Counts)", fontsize=14)
                    st.pyplot(fig, use_container_width=True)
        # --- Chart 2: Confidence vs. Correctness Box Plot ---
        with plot_col2:
            with st.container(border=True):
                st.markdown("**Confidence Analysis**")
                valid_gt_df["Result"] = np.where(
                    valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
                    "✅ Correct",
                    "❌ Incorrect",
                )
                with plot_style_context() as (fig, ax):
                    sns.boxplot(
                        x="Result",
                        y="Confidence",
                        data=valid_gt_df,
                        ax=ax,
                        palette={"✅ Correct": "lightgreen", "❌ Incorrect": "salmon"},
                    )
                    ax.set_ylabel("Model Confidence", fontsize=12)
                    ax.set_xlabel("Prediction Outcome", fontsize=12)
                    ax.set_title("Confidence Distribution by Outcome", fontsize=14)
                    st.pyplot(fig, use_container_width=True)
        st.divider()
        # --- Interactive drill-down via st.selectbox ---
        st.markdown("###### Interactive Confusion Matrix Drill-Down")
        st.caption(
            "Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab."
        )
        # Reuse the confusion matrix computed for the heatmap above and build
        # human-readable selectbox options from it.
        cm_labels = list(LABEL_MAP.values())
        options = ["-- Select a cell to filter --"]
        # The nested loop creates one dropdown option per populated cell
        for i, actual_label in enumerate(cm_labels):
            for j, predicted_label in enumerate(cm_labels):
                cell_value = cm[i, j]
                # Only cells with content are added, to avoid clutter
                if cell_value > 0:
                    option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)"
                    options.append(option_str)
        # The selectbox widget is more robust for state management than a button grid
        selected_option = st.selectbox(
            "Drill-Down Filter",
            options=options,
            key="cm_selectbox",  # Keyed so Streamlit tracks its state across reruns
            index=0,  # Default to the placeholder
        )
        # Activate or deactivate the filter based on the selection
        if selected_option != "-- Select a cell to filter --":
            # Parse the selection to recover the actual and predicted class labels
            parts = selected_option.split("|")
            actual_str = parts[0].replace("Actual: ", "").strip()
            # Split on " (" to drop the trailing file count from the label
            predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip()
            # Map the display labels back to numeric class indices, with error handling
            actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str]
            predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str]
            if not actual_matching or not predicted_matching:
                return
            self._set_cm_filter(
                actual_matching[0], predicted_matching[0], actual_str, predicted_str
            )
        else:
            # If the user selects the placeholder, deactivate the filter
            if st.session_state.get("cm_filter_active", False):
                self._clear_cm_filter()

    def _set_cm_filter(
        self,
        actual_idx: int,
        predicted_idx: int,
        actual_label: str,
        predicted_label: str,
    ):
        """Callback to set the confusion matrix filter in session state."""
        st.session_state["cm_actual_filter"] = actual_idx
        st.session_state["cm_predicted_filter"] = predicted_idx
        st.session_state["cm_filter_label"] = (
            f"Actual: {actual_label}, Predicted: {predicted_label}"
        )
        st.session_state["cm_filter_active"] = True
        # Streamlit will rerun automatically

    def _clear_cm_filter(self):
        """Callback to clear the confusion matrix filter from session state."""
        for key in (
            "cm_actual_filter",
            "cm_predicted_filter",
            "cm_filter_label",
            "cm_filter_active",
        ):
            st.session_state.pop(key, None)
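
    # Session-state contract used by the drill-down (keys set by _set_cm_filter
    # and consumed by render_interactive_grid):
    #   cm_filter_active    -> bool, True while a cell filter is applied
    #   cm_actual_filter    -> numeric Ground Truth class index
    #   cm_predicted_filter -> numeric Prediction class index
    #   cm_filter_label     -> human-readable description shown above the grid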

    def render_interactive_grid(self):
        """
        Renders the filterable, detailed data grid with robust handling for
        row selection to prevent KeyError.
        """
        st.markdown("##### Detailed Results Explorer")
        # Start with a full copy of the dataframe to apply filters to
        filtered_df = self.df.copy()
        # --- Filter Section (STREAMLINED LAYOUT) ---
        with st.container(border=True):
            st.markdown("**Filters**")
            filter_row1 = st.columns([1, 1])
            filter_row2 = st.columns(1)  # Slider takes full width
            # Filter 1: By Predicted Class
            selected_classes = filter_row1[0].multiselect(
                "Filter by Prediction:",
                options=self.df["Predicted Class"].unique(),
                default=self.df["Predicted Class"].unique(),
            )
            filtered_df = filtered_df[
                filtered_df["Predicted Class"].isin(selected_classes)
            ]
            # Filter 2: By Ground Truth Correctness (if available)
            if self.has_ground_truth:
                filtered_df["Correct"] = (
                    filtered_df["Prediction"] == filtered_df["Ground Truth"]
                )
                correctness_options = ["✅ Correct", "❌ Incorrect"]
                filtered_df["Result_Display"] = np.where(
                    filtered_df["Correct"], "✅ Correct", "❌ Incorrect"
                )
                selected_correctness = filter_row1[1].multiselect(
                    "Filter by Result:",
                    options=correctness_options,
                    default=correctness_options,
                )
                filter_correctness_bools = [
                    c == "✅ Correct" for c in selected_correctness
                ]
                filtered_df = filtered_df[
                    filtered_df["Correct"].isin(filter_correctness_bools)
                ]
            # Filter 3: By Confidence Range (full width below others)
            min_conf, max_conf = filter_row2[0].slider(
                "Filter by Confidence Range:",
                min_value=0.0,
                max_value=1.0,
                value=(0.0, 1.0),
                step=0.01,
                format="%.2f",  # Format slider display for clarity
            )
            filtered_df = filtered_df[
                (filtered_df["Confidence"] >= min_conf)
                & (filtered_df["Confidence"] <= max_conf)
            ]
        # --- END FILTER SECTION ---
        # Apply Confusion Matrix Drill-Down Filter (if active)
        if st.session_state.get("cm_filter_active", False):
            actual_idx = st.session_state["cm_actual_filter"]
            predicted_idx = st.session_state["cm_predicted_filter"]
            filter_label = st.session_state["cm_filter_label"]
            st.info(f"Filtering results for: **{filter_label}**")
            filtered_df = filtered_df[
                (filtered_df["Ground Truth"] == actual_idx)
                & (filtered_df["Prediction"] == predicted_idx)
            ]
        # --- Display the Filtered Data Table ---
        if filtered_df.empty:
            st.warning("No files match the current filter criteria.")
            st.session_state.selected_spectrum_file = None
        else:
            display_df = filtered_df.drop(
                columns=["Correct", "Result_Display"], errors="ignore"
            )
            st.dataframe(
                display_df,
                use_container_width=True,
                hide_index=True,
                on_select="rerun",
                selection_mode="single-row",
                key="results_grid_selection",
            )
            # --- ROBUST SELECTION HANDLING ---
            # The selection event nests the chosen row positions under a
            # "selection" key, e.g. {"selection": {"rows": [...], "columns": [...]}},
            # so look there first and fall back to a top-level "rows" key.
            selection_state = st.session_state.get("results_grid_selection")
            rows = []
            if isinstance(selection_state, dict):
                selection = selection_state.get("selection", selection_state)
                rows = list(selection.get("rows", []))
            if rows:
                selected_index = rows[0]
                if selected_index < len(filtered_df):
                    st.session_state.selected_spectrum_file = filtered_df.iloc[
                        selected_index
                    ]["Filename"]
                else:
                    # The table was re-filtered and the old index is now out of bounds
                    st.session_state.selected_spectrum_file = None
            else:
                # If the selection is empty or in an unexpected format, clear it
                st.session_state.selected_spectrum_file = None
            # --- END ROBUST HANDLING ---

    # --- NEW: Spectrum viewer for the selected file ---
    def render_selected_spectrum(self):
        """
        Renders an expander with the spectrum plot for the currently selected file.
        This is called after the data grid.
        """
        selected_file = st.session_state.get("selected_spectrum_file")
        # Only render if a file has been selected in the current session
        if selected_file:
            with st.expander(f"View Spectrum for: **{selected_file}**", expanded=True):
                # Retrieve the full, detailed record for the selected file
                spectrum_data = ResultsManager.get_spectrum_data_for_file(selected_file)
                # Check that the detailed data was retrieved and contains all required arrays
                if spectrum_data and all(
                    spectrum_data.get(k) is not None
                    for k in ["x_raw", "y_raw", "x_resampled", "y_resampled"]
                ):
                    # Generate a cache key for the plot so it is not regenerated unnecessarily
                    cache_key = hashlib.md5(
                        (
                            f"{spectrum_data['x_raw'].tobytes()}"
                            f"{spectrum_data['y_raw'].tobytes()}"
                            f"{spectrum_data['x_resampled'].tobytes()}"
                            f"{spectrum_data['y_resampled'].tobytes()}"
                        ).encode()
                    ).hexdigest()
                    # Call the plotting function from ui_components
                    plot_image = create_spectrum_plot(
                        spectrum_data["x_raw"],
                        spectrum_data["y_raw"],
                        spectrum_data["x_resampled"],
                        spectrum_data["y_resampled"],
                        _cache_key=cache_key,
                    )
                    st.image(
                        plot_image,
                        caption=f"Raw vs. Resampled Spectrum for {selected_file}",
                        use_container_width=True,
                    )
                else:
                    st.warning(
                        f"Could not retrieve spectrum data for '{selected_file}'. "
                        "The data might not have been stored during the initial run."
                    )
    # --- END NEW: Spectrum viewer ---
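
    # Assumed shape of the record returned by ResultsManager.get_spectrum_data_for_file
    # (inferred from the checks above; the authoritative schema lives in
    # utils/results_manager.py):
    #   {"x_raw": np.ndarray, "y_raw": np.ndarray,
    #    "x_resampled": np.ndarray, "y_resampled": np.ndarray, ...}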

    def render(self):
        """
        The main public method to render the entire dashboard using a more
        organized and streamlined tab-based layout.
        """
        if self.df.empty:
            st.info(
                "The results table is empty. Please run an analysis on the 'Upload and Run' page."
            )
            return
        # --- Tier 1: KPIs (Always visible at the top) ---
        self.render_kpis()
        st.divider()
        # --- Tier 2: Tabbed Interface for Deeper Analysis ---
        tab1, tab2 = st.tabs(["📊 Visual Diagnostics", "🗂️ Results Explorer"])
        with tab1:
            # The visual diagnostics (Confusion Matrix, etc.) go here.
            self.render_visual_diagnostics()
        with tab2:
            # The interactive grid AND the spectrum viewer it controls go here.
            self.render_interactive_grid()
            self.render_selected_spectrum()
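

# Usage sketch (illustrative): a results page would build the dashboard from the
# stored batch results. `ResultsManager.get_results_dataframe()` is a hypothetical
# accessor used only for this example; substitute whatever the app actually uses
# to expose the results DataFrame.
#
#     results_df = ResultsManager.get_results_dataframe()
#     BatchAnalysis(results_df).render()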