# Hugging Face Space (status: Running) by devjas1
# Commit 77734fd: FIX(results_manager): correct dashboard page link case sensitivity in ResultsManager
"""Session results management for multi-file inference.

Handles the in-memory results table and its export functionality.
"""
import streamlit as st | |
import pandas as pd | |
import json | |
from datetime import datetime | |
from typing import Dict, List, Any, Optional | |
import numpy as np | |
from pathlib import Path | |
import io | |
def local_css(file_name):
    """Inject the stylesheet stored at *file_name* into the page.

    The file is read as UTF-8 text and emitted inside a ``<style>`` tag via
    ``st.markdown`` with HTML rendering enabled. Any error from ``open``
    (e.g. ``FileNotFoundError``) propagates to the caller.
    """
    with open(file_name, encoding="utf-8") as css_file:
        stylesheet = css_file.read()
    st.markdown(f"<style>{stylesheet}</style>", unsafe_allow_html=True)
class ResultsManager:
    """Manages session-wide results for multi-file inference.

    All state lives in ``st.session_state``. The results table is a list of
    plain dicts stored under ``RESULTS_KEY``, one dict per inference run.
    Every method is a ``staticmethod``; the class is used as a namespace
    (e.g. ``ResultsManager.add_results(...)``).
    """

    # Session-state key under which the list of result dicts is stored.
    RESULTS_KEY = "inference_results"

    # Keys a result entry must carry (all non-None) for its spectrum to be
    # returned by get_spectrum_data_for_file.
    _SPECTRUM_KEYS = ("x_raw", "y_raw", "x_resampled", "y_resampled")

    @staticmethod
    def init_results_table() -> None:
        """Create an empty results table in session state if none exists yet."""
        if ResultsManager.RESULTS_KEY not in st.session_state:
            st.session_state[ResultsManager.RESULTS_KEY] = []

    @staticmethod
    def add_results(
        filename: str,
        model_name: str,
        prediction: int,
        predicted_class: str,
        confidence: float,
        logits: List[float],
        ground_truth: Optional[int] = None,
        processing_time: float = 0.0,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append a single inference result to the session results table.

        Args:
            filename: Name of the processed input file.
            model_name: Model identifier (stored under the ``"model"`` key).
            prediction: Predicted class index. ``get_summary_stats`` counts
                0 as "stable" and 1 as "weathered".
            predicted_class: Human-readable label for ``prediction``.
            confidence: Confidence score for the prediction.
            logits: Raw model outputs. The first two are shown as the
                "Stable Logit" / "Weathered Logit" table columns.
            ground_truth: Known label, or ``None`` when unknown.
            processing_time: Time spent on inference, shown in seconds.
            metadata: Extra free-form information; stored as ``{}`` if falsy.
        """
        ResultsManager.init_results_table()
        result = {
            # Local wall-clock time; naive datetime formatted for display.
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "filename": filename,
            "model": model_name,
            "prediction": prediction,
            "predicted_class": predicted_class,
            "confidence": confidence,
            "logits": logits,
            "ground_truth": ground_truth,
            "processing_time": processing_time,
            "metadata": metadata or {},
        }
        st.session_state[ResultsManager.RESULTS_KEY].append(result)

    @staticmethod
    def get_results() -> List[Dict[str, Any]]:
        """Return the live list of stored results (mutations affect the session)."""
        ResultsManager.init_results_table()
        return st.session_state[ResultsManager.RESULTS_KEY]

    @staticmethod
    def get_results_count() -> int:
        """Return the number of stored results."""
        return len(ResultsManager.get_results())

    @staticmethod
    def clear_results() -> None:
        """Replace the stored results with an empty list."""
        st.session_state[ResultsManager.RESULTS_KEY] = []

    @staticmethod
    def _find_spectrum(
        results: List[Dict[str, Any]], filename: str
    ) -> Optional[Dict[str, Any]]:
        """Return spectrum arrays from the first complete entry for *filename*.

        Entries for the same file that lack any spectrum key are skipped
        instead of ending the search. Such entries may come from runs made
        before spectra were stored.
        """
        for entry in results:
            if entry.get("filename") != filename:
                continue
            if all(entry.get(k) is not None for k in ResultsManager._SPECTRUM_KEYS):
                return {k: entry[k] for k in ResultsManager._SPECTRUM_KEYS}
        return None

    @staticmethod
    def get_spectrum_data_for_file(filename: str) -> Optional[Dict[str, np.ndarray]]:
        """Retrieve raw and resampled spectrum data for *filename*.

        Returns a dict with ``x_raw``, ``y_raw``, ``x_resampled`` and
        ``y_resampled``, taken from the first stored entry for that file that
        has all four values present. Returns ``None`` if no such entry
        exists. Note that ``add_results`` itself does not write these keys;
        they must be added to the entry by other code.
        """
        return ResultsManager._find_spectrum(ResultsManager.get_results(), filename)

    @staticmethod
    def _result_to_row(result: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one stored result into a display row (numbers pre-formatted)."""
        logits = result.get("logits")
        if logits is None:  # tolerate entries stored without logits
            logits = []
        return {
            "Timestamp": result["timestamp"],
            "Filename": result["filename"],
            "Model": result["model"],
            "Prediction": result["prediction"],
            "Predicted Class": result["predicted_class"],
            "Confidence": f"{result['confidence']:.3f}",
            "Stable Logit": f"{logits[0]:.3f}" if len(logits) > 0 else "N/A",
            "Weathered Logit": f"{logits[1]:.3f}" if len(logits) > 1 else "N/A",
            "Ground Truth": (
                result["ground_truth"]
                if result["ground_truth"] is not None
                else "Unknown"
            ),
            "Processing Time (s)": f"{result['processing_time']:.3f}",
        }

    @staticmethod
    def get_results_dataframe() -> pd.DataFrame:
        """Convert results to a pandas DataFrame for display and export.

        Returns an empty DataFrame when no results are stored. Numeric
        columns such as confidence and logits are formatted as 3-decimal strings.
        """
        results = ResultsManager.get_results()
        if not results:
            return pd.DataFrame()
        return pd.DataFrame([ResultsManager._result_to_row(r) for r in results])

    @staticmethod
    def export_to_csv() -> bytes:
        """Export the results table to UTF-8 encoded CSV (``b""`` when empty)."""
        df = ResultsManager.get_results_dataframe()
        if df.empty:
            return b""
        # ===Use StringIO to create CSV in memory===
        csv_buffer = io.StringIO()
        df.to_csv(csv_buffer, index=False)
        return csv_buffer.getvalue().encode("utf-8")

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Fallback serializer for ``json.dumps``.

        NumPy arrays become lists and NumPy scalars become native Python
        values, so they survive a round trip. Anything else falls back to
        ``str(obj)``, as before.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    @staticmethod
    def export_to_json() -> str:
        """Export the raw stored results as indented JSON (``"[]"`` when empty)."""
        results = ResultsManager.get_results()
        return json.dumps(results, indent=2, default=ResultsManager._json_default)

    @staticmethod
    def _summarize(results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute summary statistics for *results* (``{}`` when empty)."""
        if not results:
            return {}
        total = len(results)
        with_gt = [r for r in results if r["ground_truth"] is not None]
        correct = sum(1 for r in with_gt if r["prediction"] == r["ground_truth"])
        return {
            "total_files": total,
            "models_used": list(set(r["model"] for r in results)),
            "stable_predictions": sum(1 for r in results if r["prediction"] == 0),
            "weathered_predictions": sum(1 for r in results if r["prediction"] == 1),
            "avg_confidence": sum(r["confidence"] for r in results) / total,
            "avg_processing_time": sum(r["processing_time"] for r in results)
            / total,
            "files_with_ground_truth": len(with_gt),
            # ===Accuracy only over entries that have ground truth===
            "accuracy": correct / len(with_gt) if with_gt else None,
        }

    @staticmethod
    def get_summary_stats() -> Dict[str, Any]:
        """Get summary statistics for the stored results.

        Returns ``{}`` when there are no results. Otherwise the dict contains
        counts, averages, the models used, and ``accuracy``. ``accuracy`` is
        ``None`` when no entry has ground truth.
        """
        return ResultsManager._summarize(ResultsManager.get_results())

    @staticmethod
    def remove_result_by_filename(filename: str) -> bool:
        """Remove every result for *filename*; return True if any were removed."""
        results = ResultsManager.get_results()
        original_length = len(results)
        st.session_state[ResultsManager.RESULTS_KEY] = [
            r for r in results if r["filename"] != filename
        ]
        return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length

    # ==UTILITY FUNCTIONS==
    @staticmethod
    def init_session_state():
        """Populate missing session-state keys with defaults (existing values kept)."""
        defaults = {
            "status_message": "Ready to analyze polymer spectra 🔬",
            "status_type": "info",
            "input_text": None,
            "filename": None,
            "input_source": None,  # "upload", "batch" or "sample"
            "sample_select": "-- Select Sample --",
            "input_mode": "Upload File",  # controls which pane is visible
            "inference_run_once": False,
            "x_raw": None,
            "y_raw": None,
            "y_resampled": None,
            "log_messages": [],
            "uploader_version": 0,
            "current_upload_key": "upload_txt_0",
            "active_tab": "Details",
            "batch_mode": False,
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    @staticmethod
    def reset_ephemeral_state():
        """Comprehensive reset for the entire app state.

        Drops every session-state key except ``model_select`` and
        ``input_mode``, including the results table. It then restores a
        minimal status and increments the uploader version and upload key.
        """
        # Read before clearing: "uploader_version" is not a preserved key.
        current_version = st.session_state.get("uploader_version", 0)
        # Keys that should NOT be cleared by a full reset
        keep_keys = {"model_select", "input_mode"}
        for k in list(st.session_state.keys()):
            if k not in keep_keys:
                st.session_state.pop(k, None)
        st.session_state["status_message"] = "Ready to analyze polymer spectra"
        st.session_state["status_type"] = "info"
        st.session_state["batch_files"] = []
        # NOTE(review): init_session_state defaults this to False; confirm
        # that True is intended after a reset.
        st.session_state["inference_run_once"] = True
        # CRITICAL: Increment the preserved version and re-assign it.
        # Presumably "current_upload_key" is used as the file-uploader widget
        # key, so a new value yields a fresh uploader — confirm with caller.
        st.session_state["uploader_version"] = current_version + 1
        st.session_state["current_upload_key"] = (
            f"upload_txt_{st.session_state['uploader_version']}"
        )

    @staticmethod
    def display_results_table() -> None:
        """Render the results table, summary metrics, dashboard link and exports."""
        df = ResultsManager.get_results_dataframe()
        if df.empty:
            st.info(
                "No inference results yet. Upload files and run analysis to see results here."
            )
            return

        local_css("static/style.css")
        st.subheader(f"Inference Results ({len(df)} files)")

        # ==Summary stats==
        stats = ResultsManager.get_summary_stats()
        if stats:
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total Files", stats["total_files"])
            with col2:
                st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
            with col3:
                st.metric(
                    "Stable/Weathered",
                    f"{stats['stable_predictions']}/{stats['weathered_predictions']}",
                )
            with col4:
                if stats["accuracy"] is not None:
                    st.metric("Accuracy", f"{stats['accuracy']:.3f}")
                else:
                    st.metric("Accuracy", "N/A")

        # ==Results Table==
        st.dataframe(df, use_container_width=True)

        with st.container(border=None, key="page-link-container"):
            st.page_link(
                "pages/2_Dashboard.py",
                label="Inference Analysis Dashboard",
                help="Dive deeper into your batch results.",
                use_container_width=False,
            )

        # ==Export Buttons==
        # One timestamp per render, so the CSV and JSON downloads share a name stem.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        with st.container(border=None, key="buttons-container"):
            col1, col2, col3 = st.columns([1, 1, 1])
            with col1:
                csv_data = ResultsManager.export_to_csv()
                if csv_data:
                    with st.container(border=None, key="csv-button"):
                        st.download_button(
                            label="Download CSV",
                            data=csv_data,
                            file_name=f"polymer_results_{stamp}.csv",
                            mime="text/csv",
                            help="Export Results to CSV",
                            use_container_width=True,
                            type="tertiary",
                        )
            with col2:
                json_data = ResultsManager.export_to_json()
                if json_data:
                    with st.container(border=None, key="json-button"):
                        st.download_button(
                            label="Download JSON",
                            data=json_data,
                            file_name=f"polymer_results_{stamp}.json",
                            mime="application/json",
                            help="Export Results to JSON",
                            type="tertiary",
                            use_container_width=True,
                        )
            with col3:
                with st.container(border=None, key="clearall-button"):
                    st.button(
                        label="Clear All Results",
                        help="Clear all stored results",
                        on_click=ResultsManager.reset_ephemeral_state,
                        use_container_width=True,
                        type="tertiary",
                    )