Spaces:
Running
Running
devjas1
commited on
Commit
·
2c41fa3
1
Parent(s):
ff443f3
(FEAT): Enhance results management with utility functions for session state initialization and reset
Browse files- .gitignore +1 -0
- utils/results_manager.py +61 -16
.gitignore
CHANGED
@@ -25,3 +25,4 @@ datasets/**
|
|
25 |
!datasets/.README.md
|
26 |
# ---------------------------------------
|
27 |
|
|
|
|
25 |
!datasets/.README.md
|
26 |
# ---------------------------------------
|
27 |
|
28 |
+
__pycache__.py
|
utils/results_manager.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Dict, List, Any, Optional
|
|
9 |
from pathlib import Path
|
10 |
import io
|
11 |
|
|
|
12 |
class ResultsManager:
|
13 |
"""Manages session-wide results for multi-file inference"""
|
14 |
|
@@ -73,7 +74,7 @@ class ResultsManager:
|
|
73 |
if not results:
|
74 |
return pd.DataFrame()
|
75 |
|
76 |
-
|
77 |
df_data = []
|
78 |
for result in results:
|
79 |
row = {
|
@@ -99,7 +100,7 @@ class ResultsManager:
|
|
99 |
if df.empty:
|
100 |
return b""
|
101 |
|
102 |
-
|
103 |
csv_buffer = io.StringIO()
|
104 |
df.to_csv(csv_buffer, index=False)
|
105 |
return csv_buffer.getvalue().encode('utf-8')
|
@@ -128,9 +129,9 @@ class ResultsManager:
|
|
128 |
"avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
|
129 |
"files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
|
130 |
}
|
131 |
-
|
132 |
correct_predictions = sum(
|
133 |
-
1 for r in results
|
134 |
if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
|
135 |
)
|
136 |
total_with_gt = stats["files_with_ground_truth"]
|
@@ -138,7 +139,7 @@ class ResultsManager:
|
|
138 |
stats["accuracy"] = correct_predictions / total_with_gt
|
139 |
else:
|
140 |
stats["accuracy"] = None
|
141 |
-
|
142 |
return stats
|
143 |
|
144 |
@staticmethod
|
@@ -146,26 +147,70 @@ class ResultsManager:
|
|
146 |
"""Remove a result by filename. Returns True if removed, False if not found."""
|
147 |
results = ResultsManager.get_results()
|
148 |
original_length = len(results)
|
149 |
-
|
150 |
# Filter out results with matching filename
|
151 |
st.session_state[ResultsManager.RESULTS_KEY] = [
|
152 |
r for r in results if r["filename"] != filename
|
153 |
]
|
154 |
-
|
155 |
return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
@staticmethod
|
158 |
def display_results_table() -> None:
|
159 |
"""Display the results table in Streamlit UI"""
|
160 |
df = ResultsManager.get_results_dataframe()
|
161 |
|
162 |
if df.empty:
|
163 |
-
st.info(
|
|
|
164 |
return
|
165 |
|
166 |
st.subheader(f"Inference Results ({len(df)} files)")
|
167 |
|
168 |
-
|
169 |
stats = ResultsManager.get_summary_stats()
|
170 |
if stats:
|
171 |
col1, col2, col3, col4 = st.columns(4)
|
@@ -174,17 +219,18 @@ class ResultsManager:
|
|
174 |
with col2:
|
175 |
st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
|
176 |
with col3:
|
177 |
-
st.metric(
|
|
|
178 |
with col4:
|
179 |
if stats["accuracy"] is not None:
|
180 |
st.metric("Accuracy", f"{stats['accuracy']:.3f}")
|
181 |
else:
|
182 |
st.metric("Accuracy", "N/A")
|
183 |
|
184 |
-
|
185 |
st.dataframe(df, use_container_width=True)
|
186 |
|
187 |
-
|
188 |
col1, col2, col3 = st.columns([1, 1, 2])
|
189 |
|
190 |
with col1:
|
@@ -208,6 +254,5 @@ class ResultsManager:
|
|
208 |
)
|
209 |
|
210 |
with col3:
|
211 |
-
if st.button("Clear All Results", help="Clear all stored results"):
|
212 |
-
|
213 |
-
st.rerun()
|
|
|
9 |
from pathlib import Path
|
10 |
import io
|
11 |
|
12 |
+
|
13 |
class ResultsManager:
|
14 |
"""Manages session-wide results for multi-file inference"""
|
15 |
|
|
|
74 |
if not results:
|
75 |
return pd.DataFrame()
|
76 |
|
77 |
+
# ===Flatten the results for DataFrame===
|
78 |
df_data = []
|
79 |
for result in results:
|
80 |
row = {
|
|
|
100 |
if df.empty:
|
101 |
return b""
|
102 |
|
103 |
+
# ===Use StringIO to create CSV in memory===
|
104 |
csv_buffer = io.StringIO()
|
105 |
df.to_csv(csv_buffer, index=False)
|
106 |
return csv_buffer.getvalue().encode('utf-8')
|
|
|
129 |
"avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
|
130 |
"files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
|
131 |
}
|
132 |
+
# ===Calculate accuracy if ground truth is available===
|
133 |
correct_predictions = sum(
|
134 |
+
1 for r in results
|
135 |
if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
|
136 |
)
|
137 |
total_with_gt = stats["files_with_ground_truth"]
|
|
|
139 |
stats["accuracy"] = correct_predictions / total_with_gt
|
140 |
else:
|
141 |
stats["accuracy"] = None
|
142 |
+
|
143 |
return stats
|
144 |
|
145 |
@staticmethod
|
|
|
147 |
"""Remove a result by filename. Returns True if removed, False if not found."""
|
148 |
results = ResultsManager.get_results()
|
149 |
original_length = len(results)
|
150 |
+
|
151 |
# Filter out results with matching filename
|
152 |
st.session_state[ResultsManager.RESULTS_KEY] = [
|
153 |
r for r in results if r["filename"] != filename
|
154 |
]
|
155 |
+
|
156 |
return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
# ==UTILITY FUNCTIONS==
|
160 |
+
def init_session_state():
|
161 |
+
"""Keep a persistent session state"""
|
162 |
+
defaults = {
|
163 |
+
"status_message": "Ready to analyze polymer spectra 🔬",
|
164 |
+
"status_type": "info",
|
165 |
+
"input_text": None,
|
166 |
+
"filename": None,
|
167 |
+
"input_source": None, # "upload", "batch" or "sample"
|
168 |
+
"sample_select": "-- Select Sample --",
|
169 |
+
"input_mode": "Upload File", # controls which pane is visible
|
170 |
+
"inference_run_once": False,
|
171 |
+
"x_raw": None, "y_raw": None, "y_resampled": None,
|
172 |
+
"log_messages": [],
|
173 |
+
"uploader_version": 0,
|
174 |
+
"current_upload_key": "upload_txt_0",
|
175 |
+
"active_tab": "Details",
|
176 |
+
"batch_mode": False,
|
177 |
+
}
|
178 |
+
|
179 |
+
# Init session state with defaults
|
180 |
+
for key, value in defaults.items():
|
181 |
+
if key not in st.session_state:
|
182 |
+
st.session_state[key] = value
|
183 |
+
|
184 |
+
@staticmethod
|
185 |
+
def reset_ephemeral_state():
|
186 |
+
"""Comprehensive reset for the entire app state."""
|
187 |
+
# Define keys that should NOT be cleared by a full reset
|
188 |
+
keep_keys = {"model_select", "input_mode"}
|
189 |
+
|
190 |
+
for k in list(st.session_state.keys()):
|
191 |
+
if k not in keep_keys:
|
192 |
+
st.session_state.pop(k, None)
|
193 |
+
|
194 |
+
# Re-initialize the core state after clearing
|
195 |
+
ResultsManager.init_session_state()
|
196 |
+
|
197 |
+
# CRITICAL: Bump the uploader version to force a widget reset
|
198 |
+
st.session_state["uploader_version"] += 1
|
199 |
+
st.session_state["current_upload_key"] = f"upload_txt_{st.session_state['uploader_version']}"
|
200 |
+
|
201 |
@staticmethod
|
202 |
def display_results_table() -> None:
|
203 |
"""Display the results table in Streamlit UI"""
|
204 |
df = ResultsManager.get_results_dataframe()
|
205 |
|
206 |
if df.empty:
|
207 |
+
st.info(
|
208 |
+
"No inference results yet. Upload files and run analysis to see results here.")
|
209 |
return
|
210 |
|
211 |
st.subheader(f"Inference Results ({len(df)} files)")
|
212 |
|
213 |
+
# ==Summary stats==
|
214 |
stats = ResultsManager.get_summary_stats()
|
215 |
if stats:
|
216 |
col1, col2, col3, col4 = st.columns(4)
|
|
|
219 |
with col2:
|
220 |
st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
|
221 |
with col3:
|
222 |
+
st.metric(
|
223 |
+
"Stable/Weathered", f"{stats['stable_predictions']}/{stats['weathered_predictions']}")
|
224 |
with col4:
|
225 |
if stats["accuracy"] is not None:
|
226 |
st.metric("Accuracy", f"{stats['accuracy']:.3f}")
|
227 |
else:
|
228 |
st.metric("Accuracy", "N/A")
|
229 |
|
230 |
+
# ==Results Table==
|
231 |
st.dataframe(df, use_container_width=True)
|
232 |
|
233 |
+
# ==Export Button==
|
234 |
col1, col2, col3 = st.columns([1, 1, 2])
|
235 |
|
236 |
with col1:
|
|
|
254 |
)
|
255 |
|
256 |
with col3:
|
257 |
+
if st.button("Clear All Results", help="Clear all stored results", on_click=ResultsManager.reset_ephemeral_state):
|
258 |
+
st.rerun()
|
|