Spaces:
Sleeping
Sleeping
devjas1
FEAT(analyzer): Introduce centralized plot styling helper for theme-aware visualizations; enhance render_visual_diagnostics method with improved aesthetics and interactive filtering
9318b04
| # In modules/analyzer.py | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix | |
| import matplotlib.pyplot as plt | |
| from datetime import datetime | |
| from contextlib import contextmanager # Correctly imported for use with @contextmanager | |
| from config import LABEL_MAP # Assuming LABEL_MAP is correctly defined in config.py | |
| # --- ADD THESE IMPORTS AT THE TOP OF THE FILE --- | |
| from utils.results_manager import ResultsManager | |
| from modules.ui_components import create_spectrum_plot | |
| import hashlib | |
| # --- NEW: Centralized plot styling helper --- | |
@contextmanager
def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
    """
    A context manager to apply consistent Matplotlib styling and
    make plots theme-aware.

    Yields a ``(fig, ax)`` pair created with ``plt.subplots`` inside a
    ``plt.rc_context`` whose colors follow the Streamlit theme options
    ``theme.textColor`` / ``theme.backgroundColor`` (black on white when
    they are not configured). The figure is always closed on exit, even if
    the ``with`` body raises.

    Args:
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        constrained_layout: Forwarded to ``plt.subplots``.
        **kwargs: Any further ``plt.subplots`` keyword arguments.
    """

    def _theme_color(option, default):
        # Query one theme key at a time: "theme" itself is a config section,
        # not an option, so asking for it always fails.
        try:
            value = st.get_option(option)
        except RuntimeError:
            # Option not defined for this Streamlit version/config.
            return default
        return value or default

    text_color = _theme_color("theme.textColor", "#000000")
    bg_color = _theme_color("theme.backgroundColor", "#FFFFFF")
    with plt.rc_context(
        {
            "figure.facecolor": bg_color,
            "axes.facecolor": bg_color,
            "text.color": text_color,
            "axes.labelcolor": text_color,
            "xtick.color": text_color,
            "ytick.color": text_color,
            "grid.color": text_color,
            "axes.edgecolor": text_color,
            "axes.titlecolor": text_color,  # Ensure title color matches
            # Only used by Matplotlib when no explicit layout is requested;
            # the subplots call below always passes constrained_layout.
            "figure.autolayout": True,
        }
    ):
        fig, ax = plt.subplots(
            figsize=figsize, constrained_layout=constrained_layout, **kwargs
        )
        try:
            # The yield stays inside rc_context so artists created by the
            # caller (labels, annotations) also pick up the themed colors.
            yield fig, ax
        finally:
            plt.close(fig)  # Always close figure to prevent memory leaks
| # --- END NEW HELPER --- | |
| class BatchAnalysis: | |
| def __init__(self, df: pd.DataFrame): | |
| """Initializes the analysis object with the results DataFrame.""" | |
| self.df = df | |
| if self.df.empty: | |
| return | |
| self.total_files = len(self.df) | |
| self.has_ground_truth = ( | |
| "Ground Truth" in self.df.columns | |
| and not self.df["Ground Truth"].isnull().all() | |
| ) | |
| self._prepare_data() | |
| self.kpis = self._calculate_kpis() | |
| def _prepare_data(self): | |
| """Ensures data types are correct for analysis.""" | |
| self.df["Confidence"] = pd.to_numeric(self.df["Confidence"], errors="coerce") | |
| if self.has_ground_truth: | |
| self.df["Ground Truth"] = pd.to_numeric( | |
| self.df["Ground Truth"], errors="coerce" | |
| ) | |
| def _calculate_kpis(self) -> dict: | |
| """A private method to compute all the key performance indicators.""" | |
| stable_count = self.df[ | |
| self.df["Predicted Class"] == "Stable (Unweathered)" | |
| ].shape[0] | |
| accuracy = "N/A" | |
| if self.has_ground_truth: | |
| valid_gt = self.df.dropna(subset=["Ground Truth", "Prediction"]) | |
| accuracy = (valid_gt["Prediction"] == valid_gt["Ground Truth"]).mean() | |
| return { | |
| "Total Files": self.total_files, | |
| "Avg. Confidence": self.df["Confidence"].mean(), | |
| "Stable/Weathered": f"{stable_count}/{self.total_files - stable_count}", | |
| "Accuracy": accuracy, | |
| } | |
| def render_kpis(self): | |
| """Renders the top-level KPI metrics.""" | |
| kpi_cols = st.columns(4) | |
| kpi_cols[0].metric("Total Files", f"{self.kpis['Total Files']}") | |
| kpi_cols[1].metric("Avg. Confidence", f"{self.kpis['Avg. Confidence']:.3f}") | |
| kpi_cols[2].metric("Stable/Weathered", self.kpis["Stable/Weathered"]) | |
| kpi_cols[3].metric( | |
| "Accuracy", | |
| ( | |
| f"{self.kpis['Accuracy']:.3f}" | |
| if isinstance(self.kpis["Accuracy"], float) | |
| else "N/A" | |
| ), | |
| ) | |
| def render_visual_diagnostics(self): | |
| """ | |
| Renders diagnostic plots with corrected aesthetics and a robust, | |
| interactive drill-down filter using st.selectbox. | |
| """ | |
| st.markdown("##### Visual Analysis") | |
| if not self.has_ground_truth: | |
| st.info("Visual analysis requires Ground Truth data for this batch.") | |
| return | |
| valid_gt_df = self.df.dropna(subset=["Ground Truth"]) | |
| plot_col1, plot_col2 = st.columns(2) | |
| # --- Chart 1: Confusion Matrix (Aesthetically Corrected) --- | |
| with plot_col1: | |
| with st.container(border=True): | |
| st.markdown("**Confusion Matrix**") | |
| cm = confusion_matrix( | |
| valid_gt_df["Ground Truth"], | |
| valid_gt_df["Prediction"], | |
| labels=list(LABEL_MAP.keys()), | |
| ) | |
| with plot_style_context() as (fig, ax): | |
| sns.heatmap( | |
| cm, | |
| annot=True, | |
| fmt="g", | |
| ax=ax, | |
| cmap="Blues", | |
| xticklabels=list(LABEL_MAP.values()), | |
| yticklabels=list(LABEL_MAP.values()), | |
| ) | |
| ax.set_ylabel("Actual Class", fontsize=12) | |
| ax.set_xlabel("Predicted Class", fontsize=12) | |
| # --- AESTHETIC FIX: Rotate X-labels vertically for a compact look --- | |
| ax.set_xticklabels(ax.get_xticklabels(), rotation=90) | |
| ax.set_yticklabels(ax.get_yticklabels(), rotation=0) | |
| ax.set_title("Prediction vs. Actual (Counts)", fontsize=14) | |
| st.pyplot(fig, use_container_width=True) | |
| # --- Chart 2: Confidence vs. Correctness Box Plot (Unchanged) --- | |
| with plot_col2: | |
| with st.container(border=True): | |
| st.markdown("**Confidence Analysis**") | |
| valid_gt_df["Result"] = np.where( | |
| valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"], | |
| "✅ Correct", | |
| "❌ Incorrect", | |
| ) | |
| with plot_style_context() as (fig, ax): | |
| sns.boxplot( | |
| x="Result", | |
| y="Confidence", | |
| data=valid_gt_df, | |
| ax=ax, | |
| palette={"✅ Correct": "lightgreen", "❌ Incorrect": "salmon"}, | |
| ) | |
| ax.set_ylabel("Model Confidence", fontsize=12) | |
| ax.set_xlabel("Prediction Outcome", fontsize=12) | |
| ax.set_title("Confidence Distribution by Outcome", fontsize=14) | |
| st.pyplot(fig, use_container_width=True) | |
| st.divider() | |
| # --- FUNCTIONALITY FIX: Replace Button Grid with st.selectbox --- | |
| st.markdown("###### Interactive Confusion Matrix Drill-Down") | |
| st.caption( | |
| "Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab." | |
| ) | |
| # Create a list of options for the selectbox from the confusion matrix | |
| cm = confusion_matrix( | |
| valid_gt_df["Ground Truth"], | |
| valid_gt_df["Prediction"], | |
| labels=list(LABEL_MAP.keys()), | |
| ) | |
| cm_labels = list(LABEL_MAP.values()) | |
| options = ["-- Select a cell to filter --"] | |
| # This nested loop creates the human-readable options for the dropdown | |
| for i, actual_label in enumerate(cm_labels): | |
| for j, predicted_label in enumerate(cm_labels): | |
| cell_value = cm[i, j] | |
| # We only add cells with content to the dropdown to avoid clutter | |
| if cell_value > 0: | |
| option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)" | |
| options.append(option_str) | |
| # The selectbox widget, which is more robust for state management | |
| selected_option = st.selectbox( | |
| "Drill-Down Filter", | |
| options=options, | |
| key="cm_selectbox", # Give it a key to track its state | |
| index=0, # Default to the placeholder | |
| ) | |
| # Logic to activate or deactivate the filter based on selection | |
| if selected_option != "-- Select a cell to filter --": | |
| # Parse the selection to get the actual and predicted classes | |
| parts = selected_option.split("|") | |
| actual_str = parts[0].replace("Actual: ", "").strip() | |
| # FIX: Split on " (" to get the full label without the file count | |
| predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip() | |
| # Find the corresponding numeric indices with error handling | |
| actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str] | |
| predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str] | |
| if not actual_matching or not predicted_matching: | |
| return | |
| actual_idx = actual_matching[0] | |
| predicted_idx = predicted_matching[0] | |
| # Use a simplified callback-like update to session state | |
| st.session_state["cm_actual_filter"] = actual_idx | |
| st.session_state["cm_predicted_filter"] = predicted_idx | |
| st.session_state["cm_filter_label"] = ( | |
| f"Actual: {actual_str}, Predicted: {predicted_str}" | |
| ) | |
| st.session_state["cm_filter_active"] = True | |
| else: | |
| # If the user selects the placeholder, deactivate the filter | |
| if st.session_state.get("cm_filter_active", False): | |
| self._clear_cm_filter() | |
| def _set_cm_filter( | |
| self, | |
| actual_idx: int, | |
| predicted_idx: int, | |
| actual_label: str, | |
| predicted_label: str, | |
| ): | |
| """Callback to set the confusion matrix filter in session state.""" | |
| st.session_state["cm_actual_filter"] = actual_idx | |
| st.session_state["cm_predicted_filter"] = predicted_idx | |
| st.session_state["cm_filter_label"] = ( | |
| f"Actual: {actual_label}, Predicted: {predicted_label}" | |
| ) | |
| st.session_state["cm_filter_active"] = True | |
| # Streamlit will rerun automatically | |
| def _clear_cm_filter(self): | |
| """Callback to clear the confusion matrix filter from session state.""" | |
| if "cm_actual_filter" in st.session_state: | |
| del st.session_state["cm_actual_filter"] | |
| if "cm_predicted_filter" in st.session_state: | |
| del st.session_state["cm_predicted_filter"] | |
| if "cm_filter_label" in st.session_state: | |
| del st.session_state["cm_filter_label"] | |
| if "cm_filter_active" in st.session_state: | |
| del st.session_state["cm_filter_active"] | |
| def render_interactive_grid(self): | |
| """ | |
| Renders the filterable, detailed data grid with robust handling for | |
| row selection to prevent KeyError. | |
| """ | |
| st.markdown("##### Detailed Results Explorer") | |
| # Start with a full copy of the dataframe to apply filters to | |
| filtered_df = self.df.copy() | |
| # --- Filter Section (STREAMLINED LAYOUT) --- | |
| with st.container(border=True): | |
| st.markdown("**Filters**") | |
| filter_row1 = st.columns([1, 1]) | |
| filter_row2 = st.columns(1) # Slider takes full width | |
| # Filter 1: By Predicted Class | |
| selected_classes = filter_row1[0].multiselect( | |
| "Filter by Prediction:", | |
| options=self.df["Predicted Class"].unique(), | |
| default=self.df["Predicted Class"].unique(), | |
| ) | |
| filtered_df = filtered_df[ | |
| filtered_df["Predicted Class"].isin(selected_classes) | |
| ] | |
| # Filter 2: By Ground Truth Correctness (if available) | |
| if self.has_ground_truth: | |
| filtered_df["Correct"] = ( | |
| filtered_df["Prediction"] == filtered_df["Ground Truth"] | |
| ) | |
| correctness_options = ["✅ Correct", "❌ Incorrect"] | |
| filtered_df["Result_Display"] = np.where( | |
| filtered_df["Correct"], "✅ Correct", "❌ Incorrect" | |
| ) | |
| selected_correctness = filter_row1[1].multiselect( | |
| "Filter by Result:", | |
| options=correctness_options, | |
| default=correctness_options, | |
| ) | |
| filter_correctness_bools = [ | |
| True if c == "✅ Correct" else False for c in selected_correctness | |
| ] | |
| filtered_df = filtered_df[ | |
| filtered_df["Correct"].isin(filter_correctness_bools) | |
| ] | |
| # Filter 3: By Confidence Range (full width below others) | |
| min_conf, max_conf = filter_row2[0].slider( | |
| "Filter by Confidence Range:", | |
| min_value=0.0, | |
| max_value=1.0, | |
| value=(0.0, 1.0), | |
| step=0.01, | |
| format="%.2f", # Format slider display for clarity | |
| ) | |
| filtered_df = filtered_df[ | |
| (filtered_df["Confidence"] >= min_conf) | |
| & (filtered_df["Confidence"] <= max_conf) | |
| ] | |
| # --- END FILTER SECTION --- | |
| # Apply Confusion Matrix Drill-Down Filter (if active) | |
| if st.session_state.get("cm_filter_active", False): | |
| actual_idx = st.session_state["cm_actual_filter"] | |
| predicted_idx = st.session_state["cm_predicted_filter"] | |
| filter_label = st.session_state["cm_filter_label"] | |
| st.info(f"Filtering results for: **{filter_label}**") | |
| filtered_df = filtered_df[ | |
| (filtered_df["Ground Truth"] == actual_idx) | |
| & (filtered_df["Prediction"] == predicted_idx) | |
| ] | |
| # --- Display the Filtered Data Table --- | |
| if filtered_df.empty: | |
| st.warning("No files match the current filter criteria.") | |
| st.session_state.selected_spectrum_file = None | |
| else: | |
| display_df = filtered_df.drop( | |
| columns=["Correct", "Result_Display"], errors="ignore" | |
| ) | |
| st.dataframe( | |
| display_df, | |
| use_container_width=True, | |
| hide_index=True, | |
| on_select="rerun", | |
| selection_mode="single-row", | |
| key="results_grid_selection", | |
| ) | |
| # --- ROBUST SELECTION HANDLING (THE FIX) --- | |
| selection_state = st.session_state.get("results_grid_selection") | |
| # Check if selection_state is a dictionary AND if it contains the 'rows' key | |
| if ( | |
| isinstance(selection_state, dict) | |
| and "rows" in selection_state | |
| and selection_state["rows"] | |
| ): | |
| selected_index = selection_state["rows"][0] | |
| if selected_index < len(filtered_df): | |
| st.session_state.selected_spectrum_file = filtered_df.iloc[ | |
| selected_index | |
| ]["Filename"] | |
| else: | |
| # This can happen if the table is re-filtered and the old index is now out of bounds | |
| st.session_state.selected_spectrum_file = None | |
| else: | |
| # If the selection is empty or in an unexpected format, clear the selection | |
| st.session_state.selected_spectrum_file = None | |
| # --- END ROBUST HANDLING --- | |
| # --- ADD THIS ENTIRE NEW METHOD --- | |
| def render_selected_spectrum(self): | |
| """ | |
| Renders an expander with the spectrum plot for the currently selected file. | |
| This is called after the data grid. | |
| """ | |
| selected_file = st.session_state.get("selected_spectrum_file") | |
| # Only render if a file has been selected in the current session | |
| if selected_file: | |
| with st.expander(f"View Spectrum for: **{selected_file}**", expanded=True): | |
| # Retrieve the full, detailed record for the selected file | |
| spectrum_data = ResultsManager.get_spectrum_data_for_file(selected_file) | |
| # Check if the detailed data was successfully retrieved and contains all necessary arrays | |
| if spectrum_data and all( | |
| spectrum_data.get(k) is not None | |
| for k in ["x_raw", "y_raw", "x_resampled", "y_resampled"] | |
| ): | |
| # Generate a unique cache key for the plot to avoid re-generating it unnecessarily | |
| cache_key = hashlib.md5( | |
| ( | |
| f"{spectrum_data['x_raw'].tobytes()}" | |
| f"{spectrum_data['y_raw'].tobytes()}" | |
| f"{spectrum_data['x_resampled'].tobytes()}" | |
| f"{spectrum_data['y_resampled'].tobytes()}" | |
| ).encode() | |
| ).hexdigest() | |
| # Call the plotting function from ui_components | |
| plot_image = create_spectrum_plot( | |
| spectrum_data["x_raw"], | |
| spectrum_data["y_raw"], | |
| spectrum_data["x_resampled"], | |
| spectrum_data["y_resampled"], | |
| _cache_key=cache_key, | |
| ) | |
| st.image( | |
| plot_image, | |
| caption=f"Raw vs. Resampled Spectrum for {selected_file}", | |
| use_container_width=True, | |
| ) | |
| else: | |
| st.warning( | |
| f"Could not retrieve spectrum data for '{selected_file}'. The data might not have been stored during the initial run." | |
| ) | |
| # --- END NEW METHOD --- | |
| def render(self): | |
| """ | |
| The main public method to render the entire dashboard using a more | |
| organized and streamlined tab-based layout. | |
| """ | |
| if self.df.empty: | |
| st.info( | |
| "The results table is empty. Please run an analysis on the 'Upload and Run' page." | |
| ) | |
| return | |
| # --- Tier 1: KPIs (Always visible at the top) --- | |
| self.render_kpis() | |
| st.divider() | |
| # --- Tier 2: Tabbed Interface for Deeper Analysis --- | |
| tab1, tab2 = st.tabs(["📊 Visual Diagnostics", "🗂️ Results Explorer"]) | |
| with tab1: | |
| # The visual diagnostics (Confusion Matrix, etc.) go here. | |
| self.render_visual_diagnostics() | |
| with tab2: | |
| # The interactive grid AND the spectrum viewer it controls go here. | |
| self.render_interactive_grid() | |
| self.render_selected_spectrum() | |