devjas1 committed
Commit 2c41fa3 · 1 Parent(s): ff443f3

(FEAT): Enhance results management with utility functions for session state initialization and reset

Files changed (2)
  1. .gitignore +1 -0
  2. utils/results_manager.py +61 -16
.gitignore CHANGED
@@ -25,3 +25,4 @@ datasets/**
 !datasets/.README.md
 # ---------------------------------------
 
+__pycache__.py
utils/results_manager.py CHANGED
@@ -9,6 +9,7 @@ from typing import Dict, List, Any, Optional
 from pathlib import Path
 import io
 
+
 class ResultsManager:
     """Manages session-wide results for multi-file inference"""
 
@@ -73,7 +74,7 @@ class ResultsManager:
         if not results:
             return pd.DataFrame()
 
-        #===Flatten the results for DataFrame===
+        # ===Flatten the results for DataFrame===
         df_data = []
         for result in results:
             row = {
@@ -99,7 +100,7 @@ class ResultsManager:
         if df.empty:
             return b""
 
-        #===Use StringIO to create CSV in memory===
+        # ===Use StringIO to create CSV in memory===
         csv_buffer = io.StringIO()
         df.to_csv(csv_buffer, index=False)
         return csv_buffer.getvalue().encode('utf-8')
@@ -128,9 +129,9 @@ class ResultsManager:
             "avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
             "files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
         }
-        #===Calculate accuracy if ground truth is available===
+        # ===Calculate accuracy if ground truth is available===
         correct_predictions = sum(
-            1 for r in results
+            1 for r in results
             if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
         )
         total_with_gt = stats["files_with_ground_truth"]
@@ -138,7 +139,7 @@ class ResultsManager:
             stats["accuracy"] = correct_predictions / total_with_gt
         else:
             stats["accuracy"] = None
-
+
         return stats
 
     @staticmethod
@@ -146,26 +147,70 @@ class ResultsManager:
         """Remove a result by filename. Returns True if removed, False if not found."""
         results = ResultsManager.get_results()
         original_length = len(results)
-
+
         # Filter out results with matching filename
         st.session_state[ResultsManager.RESULTS_KEY] = [
            r for r in results if r["filename"] != filename
         ]
-
+
         return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
-
+
+    @staticmethod
+    # ==UTILITY FUNCTIONS==
+    def init_session_state():
+        """Keep a persistent session state"""
+        defaults = {
+            "status_message": "Ready to analyze polymer spectra 🔬",
+            "status_type": "info",
+            "input_text": None,
+            "filename": None,
+            "input_source": None,  # "upload", "batch" or "sample"
+            "sample_select": "-- Select Sample --",
+            "input_mode": "Upload File",  # controls which pane is visible
+            "inference_run_once": False,
+            "x_raw": None, "y_raw": None, "y_resampled": None,
+            "log_messages": [],
+            "uploader_version": 0,
+            "current_upload_key": "upload_txt_0",
+            "active_tab": "Details",
+            "batch_mode": False,
+        }
+
+        # Init session state with defaults
+        for key, value in defaults.items():
+            if key not in st.session_state:
+                st.session_state[key] = value
+
+    @staticmethod
+    def reset_ephemeral_state():
+        """Comprehensive reset for the entire app state."""
+        # Define keys that should NOT be cleared by a full reset
+        keep_keys = {"model_select", "input_mode"}
+
+        for k in list(st.session_state.keys()):
+            if k not in keep_keys:
+                st.session_state.pop(k, None)
+
+        # Re-initialize the core state after clearing
+        ResultsManager.init_session_state()
+
+        # CRITICAL: Bump the uploader version to force a widget reset
+        st.session_state["uploader_version"] += 1
+        st.session_state["current_upload_key"] = f"upload_txt_{st.session_state['uploader_version']}"
+
     @staticmethod
     def display_results_table() -> None:
         """Display the results table in Streamlit UI"""
         df = ResultsManager.get_results_dataframe()
 
         if df.empty:
-            st.info("No inference results yet. Upload files and run analysis to see results here.")
+            st.info(
+                "No inference results yet. Upload files and run analysis to see results here.")
             return
 
         st.subheader(f"Inference Results ({len(df)} files)")
 
-        #==Summary stats==
+        # ==Summary stats==
         stats = ResultsManager.get_summary_stats()
         if stats:
             col1, col2, col3, col4 = st.columns(4)
@@ -174,17 +219,18 @@ class ResultsManager:
             with col2:
                 st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
             with col3:
-                st.metric("Stable/Weathered", f"{stats['stable_predictions']}/{stats['weathered_predictions']}")
+                st.metric(
+                    "Stable/Weathered", f"{stats['stable_predictions']}/{stats['weathered_predictions']}")
             with col4:
                 if stats["accuracy"] is not None:
                     st.metric("Accuracy", f"{stats['accuracy']:.3f}")
                 else:
                     st.metric("Accuracy", "N/A")
 
-        #==Results Table==
+        # ==Results Table==
         st.dataframe(df, use_container_width=True)
 
-        #==Export Button==
+        # ==Export Button==
         col1, col2, col3 = st.columns([1, 1, 2])
 
         with col1:
@@ -208,6 +254,5 @@ class ResultsManager:
             )
 
         with col3:
-            if st.button("Clear All Results", help="Clear all stored results"):
-                ResultsManager.clear_results()
-                st.rerun()
+            if st.button("Clear All Results", help="Clear all stored results", on_click=ResultsManager.reset_ephemeral_state):
+                st.rerun()
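
A minimal usage sketch of the new utilities from a Streamlit entry point. The file name app.py, the widgets, and the layout below are assumptions for illustration, not part of this commit; only ResultsManager.init_session_state and ResultsManager.reset_ephemeral_state come from utils/results_manager.py.

# app.py -- hypothetical entry point for illustration; not part of this commit.
import streamlit as st

from utils.results_manager import ResultsManager

# Seed st.session_state with the defaults before any widget reads it.
ResultsManager.init_session_state()

st.info(st.session_state["status_message"])

# The versioned key makes Streamlit treat the uploader as a new widget
# after a reset, which discards any previously selected file.
uploaded = st.file_uploader(
    "Upload spectrum (.txt)",
    key=st.session_state["current_upload_key"],
)

# on_click runs before the script reruns, so the next execution already
# sees the cleared state and the bumped uploader key.
st.button("Reset session", on_click=ResultsManager.reset_ephemeral_state)

Because reset_ephemeral_state keeps model_select and input_mode, the user's model choice and active input pane survive a full reset while everything else returns to the defaults defined in init_session_state.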