# polymer-aging-ml / modules/analyzer.py
# Author: devjas1 — commit 9318b04
# FEAT(analyzer): Introduce centralized plot styling helper for theme-aware
# visualizations; enhance render_visual_diagnostics method with improved
# aesthetics and interactive filtering
# In modules/analyzer.py
import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from datetime import datetime
from contextlib import contextmanager # Correctly imported for use with @contextmanager
from config import LABEL_MAP # Assuming LABEL_MAP is correctly defined in config.py
# --- ADD THESE IMPORTS AT THE TOP OF THE FILE ---
from utils.results_manager import ResultsManager
from modules.ui_components import create_spectrum_plot
import hashlib
# --- NEW: Centralized plot styling helper ---
@contextmanager
def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
    """
    A context manager to apply consistent Matplotlib styling and
    make plots theme-aware.

    Reads the active Streamlit theme (text/background colors, falling back
    to black-on-white) and applies it to a freshly created figure/axes pair.

    Args:
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        constrained_layout: Use Matplotlib's constrained layout engine.
        **kwargs: Extra keyword arguments forwarded to ``plt.subplots``.

    Yields:
        (fig, ax): The styled Matplotlib figure and axes. The figure is
        always closed on exit, even if the caller's block raises.
    """
    try:
        theme_opts = st.get_option("theme") or {}
    except RuntimeError:
        # Fallback to empty dict if theme config is not available
        # (e.g. when executed outside a Streamlit script run).
        theme_opts = {}
    text_color = theme_opts.get("textColor", "#000000")
    bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
    rc_overrides = {
        "figure.facecolor": bg_color,
        "axes.facecolor": bg_color,
        "text.color": text_color,
        "axes.labelcolor": text_color,
        "xtick.color": text_color,
        "ytick.color": text_color,
        "grid.color": text_color,
        "axes.edgecolor": text_color,
        "axes.titlecolor": text_color,  # Ensure title color matches
        # FIX: "figure.autolayout" deliberately NOT set here. It enables
        # tight_layout, which is incompatible with constrained_layout
        # (passed to plt.subplots below) and makes Matplotlib warn and
        # fall back unpredictably when both are active.
    }
    with plt.rc_context(rc_overrides):
        fig, ax = plt.subplots(
            figsize=figsize, constrained_layout=constrained_layout, **kwargs
        )
        try:
            yield fig, ax
        finally:
            # FIX: close in a finally block so the figure is released even
            # when the caller raises inside the with-body; otherwise the
            # figure registry leaks memory across Streamlit reruns.
            plt.close(fig)
# --- END NEW HELPER ---
class BatchAnalysis:
    """Streamlit dashboard over a batch-inference results DataFrame.

    Renders four tiers of output: KPI metrics, visual diagnostics
    (confusion matrix + confidence box plot), a filterable results grid,
    and a spectrum viewer for the row selected in the grid.

    Columns read by this class: "Predicted Class", "Prediction",
    "Confidence", "Filename", and optionally "Ground Truth". Cross-widget
    communication happens through st.session_state keys
    ("cm_*_filter", "cm_filter_active", "selected_spectrum_file",
    "results_grid_selection").
    """
    def __init__(self, df: pd.DataFrame):
        """Initializes the analysis object with the results DataFrame."""
        self.df = df
        # Empty input: skip all derived state; render() shows an info message.
        if self.df.empty:
            return
        self.total_files = len(self.df)
        # True only when the column exists AND holds at least one non-null value.
        self.has_ground_truth = (
            "Ground Truth" in self.df.columns
            and not self.df["Ground Truth"].isnull().all()
        )
        self._prepare_data()
        self.kpis = self._calculate_kpis()
    def _prepare_data(self):
        """Ensures data types are correct for analysis."""
        # errors="coerce" turns unparseable entries into NaN rather than raising.
        self.df["Confidence"] = pd.to_numeric(self.df["Confidence"], errors="coerce")
        if self.has_ground_truth:
            self.df["Ground Truth"] = pd.to_numeric(
                self.df["Ground Truth"], errors="coerce"
            )
    def _calculate_kpis(self) -> dict:
        """A private method to compute all the key performance indicators.

        Returns:
            dict with keys "Total Files" (int), "Avg. Confidence" (float),
            "Stable/Weathered" (pre-formatted "stable/weathered" string) and
            "Accuracy" (float, or the string "N/A" without ground truth).
        """
        stable_count = self.df[
            self.df["Predicted Class"] == "Stable (Unweathered)"
        ].shape[0]
        accuracy = "N/A"
        if self.has_ground_truth:
            # Only score rows where both label columns are present.
            valid_gt = self.df.dropna(subset=["Ground Truth", "Prediction"])
            accuracy = (valid_gt["Prediction"] == valid_gt["Ground Truth"]).mean()
        return {
            "Total Files": self.total_files,
            "Avg. Confidence": self.df["Confidence"].mean(),
            "Stable/Weathered": f"{stable_count}/{self.total_files - stable_count}",
            "Accuracy": accuracy,
        }
    def render_kpis(self):
        """Renders the top-level KPI metrics in a four-column row."""
        kpi_cols = st.columns(4)
        kpi_cols[0].metric("Total Files", f"{self.kpis['Total Files']}")
        kpi_cols[1].metric("Avg. Confidence", f"{self.kpis['Avg. Confidence']:.3f}")
        kpi_cols[2].metric("Stable/Weathered", self.kpis["Stable/Weathered"])
        kpi_cols[3].metric(
            "Accuracy",
            (
                # Accuracy is numeric only when ground truth was available;
                # otherwise _calculate_kpis stored the literal string "N/A".
                f"{self.kpis['Accuracy']:.3f}"
                if isinstance(self.kpis["Accuracy"], float)
                else "N/A"
            ),
        )
    def render_visual_diagnostics(self):
        """
        Renders diagnostic plots with corrected aesthetics and a robust,
        interactive drill-down filter using st.selectbox.

        Requires ground-truth data; shows an info message and returns early
        when it is absent. The drill-down writes cm_* keys into
        st.session_state, which render_interactive_grid reads.
        """
        st.markdown("##### Visual Analysis")
        if not self.has_ground_truth:
            st.info("Visual analysis requires Ground Truth data for this batch.")
            return
        valid_gt_df = self.df.dropna(subset=["Ground Truth"])
        plot_col1, plot_col2 = st.columns(2)
        # --- Chart 1: Confusion Matrix (Aesthetically Corrected) ---
        with plot_col1:
            with st.container(border=True):
                st.markdown("**Confusion Matrix**")
                # labels= fixes the row/column order to match LABEL_MAP keys,
                # so the tick labels below line up with the matrix cells.
                cm = confusion_matrix(
                    valid_gt_df["Ground Truth"],
                    valid_gt_df["Prediction"],
                    labels=list(LABEL_MAP.keys()),
                )
                with plot_style_context() as (fig, ax):
                    sns.heatmap(
                        cm,
                        annot=True,
                        fmt="g",
                        ax=ax,
                        cmap="Blues",
                        xticklabels=list(LABEL_MAP.values()),
                        yticklabels=list(LABEL_MAP.values()),
                    )
                    ax.set_ylabel("Actual Class", fontsize=12)
                    ax.set_xlabel("Predicted Class", fontsize=12)
                    # --- AESTHETIC FIX: Rotate X-labels vertically for a compact look ---
                    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
                    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
                    ax.set_title("Prediction vs. Actual (Counts)", fontsize=14)
                    st.pyplot(fig, use_container_width=True)
        # --- Chart 2: Confidence vs. Correctness Box Plot (Unchanged) ---
        with plot_col2:
            with st.container(border=True):
                st.markdown("**Confidence Analysis**")
                valid_gt_df["Result"] = np.where(
                    valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
                    "✅ Correct",
                    "❌ Incorrect",
                )
                with plot_style_context() as (fig, ax):
                    sns.boxplot(
                        x="Result",
                        y="Confidence",
                        data=valid_gt_df,
                        ax=ax,
                        palette={"✅ Correct": "lightgreen", "❌ Incorrect": "salmon"},
                    )
                    ax.set_ylabel("Model Confidence", fontsize=12)
                    ax.set_xlabel("Prediction Outcome", fontsize=12)
                    ax.set_title("Confidence Distribution by Outcome", fontsize=14)
                    st.pyplot(fig, use_container_width=True)
        st.divider()
        # --- FUNCTIONALITY FIX: Replace Button Grid with st.selectbox ---
        st.markdown("###### Interactive Confusion Matrix Drill-Down")
        st.caption(
            "Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab."
        )
        # Create a list of options for the selectbox from the confusion matrix
        cm = confusion_matrix(
            valid_gt_df["Ground Truth"],
            valid_gt_df["Prediction"],
            labels=list(LABEL_MAP.keys()),
        )
        cm_labels = list(LABEL_MAP.values())
        options = ["-- Select a cell to filter --"]
        # This nested loop creates the human-readable options for the dropdown
        for i, actual_label in enumerate(cm_labels):
            for j, predicted_label in enumerate(cm_labels):
                cell_value = cm[i, j]
                # We only add cells with content to the dropdown to avoid clutter
                if cell_value > 0:
                    option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)"
                    options.append(option_str)
        # The selectbox widget, which is more robust for state management
        selected_option = st.selectbox(
            "Drill-Down Filter",
            options=options,
            key="cm_selectbox",  # Give it a key to track its state
            index=0,  # Default to the placeholder
        )
        # Logic to activate or deactivate the filter based on selection
        if selected_option != "-- Select a cell to filter --":
            # Parse the selection to get the actual and predicted classes.
            # Format produced above: "Actual: X | Predicted: Y (N files)".
            parts = selected_option.split("|")
            actual_str = parts[0].replace("Actual: ", "").strip()
            # FIX: Split on " (" to get the full label without the file count
            predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip()
            # Find the corresponding numeric indices with error handling
            actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str]
            predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str]
            # Silently abort the drill-down if the parsed labels are unknown.
            if not actual_matching or not predicted_matching:
                return
            actual_idx = actual_matching[0]
            predicted_idx = predicted_matching[0]
            # Use a simplified callback-like update to session state
            st.session_state["cm_actual_filter"] = actual_idx
            st.session_state["cm_predicted_filter"] = predicted_idx
            st.session_state["cm_filter_label"] = (
                f"Actual: {actual_str}, Predicted: {predicted_str}"
            )
            st.session_state["cm_filter_active"] = True
        else:
            # If the user selects the placeholder, deactivate the filter
            if st.session_state.get("cm_filter_active", False):
                self._clear_cm_filter()
    def _set_cm_filter(
        self,
        actual_idx: int,
        predicted_idx: int,
        actual_label: str,
        predicted_label: str,
    ):
        """Callback to set the confusion matrix filter in session state.

        NOTE(review): not referenced anywhere in this file — the selectbox
        flow in render_visual_diagnostics writes these session-state keys
        inline instead. Presumably a leftover from the earlier button-grid
        UI; confirm no external caller exists before removing.
        """
        st.session_state["cm_actual_filter"] = actual_idx
        st.session_state["cm_predicted_filter"] = predicted_idx
        st.session_state["cm_filter_label"] = (
            f"Actual: {actual_label}, Predicted: {predicted_label}"
        )
        st.session_state["cm_filter_active"] = True
        # Streamlit will rerun automatically
    def _clear_cm_filter(self):
        """Callback to clear the confusion matrix filter from session state."""
        # Each key is removed individually; the guards avoid KeyError when
        # a key was never set (e.g. filter activated only partially).
        if "cm_actual_filter" in st.session_state:
            del st.session_state["cm_actual_filter"]
        if "cm_predicted_filter" in st.session_state:
            del st.session_state["cm_predicted_filter"]
        if "cm_filter_label" in st.session_state:
            del st.session_state["cm_filter_label"]
        if "cm_filter_active" in st.session_state:
            del st.session_state["cm_filter_active"]
    def render_interactive_grid(self):
        """
        Renders the filterable, detailed data grid with robust handling for
        row selection to prevent KeyError.

        Side effect: writes st.session_state.selected_spectrum_file (the
        "Filename" of the selected row, or None), which
        render_selected_spectrum reads afterwards.
        """
        st.markdown("##### Detailed Results Explorer")
        # Start with a full copy of the dataframe to apply filters to
        filtered_df = self.df.copy()
        # --- Filter Section (STREAMLINED LAYOUT) ---
        with st.container(border=True):
            st.markdown("**Filters**")
            filter_row1 = st.columns([1, 1])
            filter_row2 = st.columns(1)  # Slider takes full width
            # Filter 1: By Predicted Class
            selected_classes = filter_row1[0].multiselect(
                "Filter by Prediction:",
                options=self.df["Predicted Class"].unique(),
                default=self.df["Predicted Class"].unique(),
            )
            filtered_df = filtered_df[
                filtered_df["Predicted Class"].isin(selected_classes)
            ]
            # Filter 2: By Ground Truth Correctness (if available)
            if self.has_ground_truth:
                filtered_df["Correct"] = (
                    filtered_df["Prediction"] == filtered_df["Ground Truth"]
                )
                correctness_options = ["✅ Correct", "❌ Incorrect"]
                filtered_df["Result_Display"] = np.where(
                    filtered_df["Correct"], "✅ Correct", "❌ Incorrect"
                )
                selected_correctness = filter_row1[1].multiselect(
                    "Filter by Result:",
                    options=correctness_options,
                    default=correctness_options,
                )
                # Map the emoji labels back onto the boolean "Correct" column.
                filter_correctness_bools = [
                    True if c == "✅ Correct" else False for c in selected_correctness
                ]
                filtered_df = filtered_df[
                    filtered_df["Correct"].isin(filter_correctness_bools)
                ]
            # Filter 3: By Confidence Range (full width below others)
            min_conf, max_conf = filter_row2[0].slider(
                "Filter by Confidence Range:",
                min_value=0.0,
                max_value=1.0,
                value=(0.0, 1.0),
                step=0.01,
                format="%.2f",  # Format slider display for clarity
            )
            filtered_df = filtered_df[
                (filtered_df["Confidence"] >= min_conf)
                & (filtered_df["Confidence"] <= max_conf)
            ]
        # --- END FILTER SECTION ---
        # Apply Confusion Matrix Drill-Down Filter (if active)
        # (set by the selectbox in render_visual_diagnostics)
        if st.session_state.get("cm_filter_active", False):
            actual_idx = st.session_state["cm_actual_filter"]
            predicted_idx = st.session_state["cm_predicted_filter"]
            filter_label = st.session_state["cm_filter_label"]
            st.info(f"Filtering results for: **{filter_label}**")
            filtered_df = filtered_df[
                (filtered_df["Ground Truth"] == actual_idx)
                & (filtered_df["Prediction"] == predicted_idx)
            ]
        # --- Display the Filtered Data Table ---
        if filtered_df.empty:
            st.warning("No files match the current filter criteria.")
            st.session_state.selected_spectrum_file = None
        else:
            # Drop the helper columns added above before displaying;
            # errors="ignore" covers the no-ground-truth case.
            display_df = filtered_df.drop(
                columns=["Correct", "Result_Display"], errors="ignore"
            )
            st.dataframe(
                display_df,
                use_container_width=True,
                hide_index=True,
                on_select="rerun",
                selection_mode="single-row",
                key="results_grid_selection",
            )
            # --- ROBUST SELECTION HANDLING (THE FIX) ---
            selection_state = st.session_state.get("results_grid_selection")
            # Check if selection_state is a dictionary AND if it contains the 'rows' key
            if (
                isinstance(selection_state, dict)
                and "rows" in selection_state
                and selection_state["rows"]
            ):
                # Single-row mode: the first (only) entry is a positional index.
                selected_index = selection_state["rows"][0]
                if selected_index < len(filtered_df):
                    st.session_state.selected_spectrum_file = filtered_df.iloc[
                        selected_index
                    ]["Filename"]
                else:
                    # This can happen if the table is re-filtered and the old index is now out of bounds
                    st.session_state.selected_spectrum_file = None
            else:
                # If the selection is empty or in an unexpected format, clear the selection
                st.session_state.selected_spectrum_file = None
            # --- END ROBUST HANDLING ---
    # --- Spectrum viewer, driven by the grid selection above ---
    def render_selected_spectrum(self):
        """
        Renders an expander with the spectrum plot for the currently selected file.
        This is called after the data grid.
        """
        selected_file = st.session_state.get("selected_spectrum_file")
        # Only render if a file has been selected in the current session
        if selected_file:
            with st.expander(f"View Spectrum for: **{selected_file}**", expanded=True):
                # Retrieve the full, detailed record for the selected file
                spectrum_data = ResultsManager.get_spectrum_data_for_file(selected_file)
                # Check if the detailed data was successfully retrieved and contains all necessary arrays
                if spectrum_data and all(
                    spectrum_data.get(k) is not None
                    for k in ["x_raw", "y_raw", "x_resampled", "y_resampled"]
                ):
                    # Generate a unique cache key for the plot to avoid re-generating it unnecessarily.
                    # The arrays support .tobytes(), so presumably they are numpy arrays — TODO confirm.
                    cache_key = hashlib.md5(
                        (
                            f"{spectrum_data['x_raw'].tobytes()}"
                            f"{spectrum_data['y_raw'].tobytes()}"
                            f"{spectrum_data['x_resampled'].tobytes()}"
                            f"{spectrum_data['y_resampled'].tobytes()}"
                        ).encode()
                    ).hexdigest()
                    # Call the plotting function from ui_components
                    plot_image = create_spectrum_plot(
                        spectrum_data["x_raw"],
                        spectrum_data["y_raw"],
                        spectrum_data["x_resampled"],
                        spectrum_data["y_resampled"],
                        _cache_key=cache_key,
                    )
                    st.image(
                        plot_image,
                        caption=f"Raw vs. Resampled Spectrum for {selected_file}",
                        use_container_width=True,
                    )
                else:
                    st.warning(
                        f"Could not retrieve spectrum data for '{selected_file}'. The data might not have been stored during the initial run."
                    )
    # --- END SPECTRUM VIEWER ---
    def render(self):
        """
        The main public method to render the entire dashboard using a more
        organized and streamlined tab-based layout.
        """
        if self.df.empty:
            st.info(
                "The results table is empty. Please run an analysis on the 'Upload and Run' page."
            )
            return
        # --- Tier 1: KPIs (Always visible at the top) ---
        self.render_kpis()
        st.divider()
        # --- Tier 2: Tabbed Interface for Deeper Analysis ---
        tab1, tab2 = st.tabs(["📊 Visual Diagnostics", "🗂️ Results Explorer"])
        with tab1:
            # The visual diagnostics (Confusion Matrix, etc.) go here.
            self.render_visual_diagnostics()
        with tab2:
            # The interactive grid AND the spectrum viewer it controls go here.
            self.render_interactive_grid()
            self.render_selected_spectrum()