Spaces:
Running
(FEAT)[UI/UX]: Add support for FTIR, multi-format upload, and model comparison tab
Browse filesSidebar:
- Added spectroscopy modality selection (Raman/FTIR) with explanatory info for each.
- Expanded model selection and improved project description to reflect FTIR and multi-model features.
Input column:
- File uploader now accepts .txt, .csv, and .json for single and batch uploads.
- Updated help text and file type validation.
New function 'render_comparison_tab':
- Allows users to select multiple models and upload/choose sample data for side-by-side prediction.
- Displays comparison results in tables and visualizations (confidence bar chart, agreement stats, performance metrics).
- Supports exporting results in JSON/full report formats.
- Shows historical comparison statistics with agreement matrix and heatmap.
New function render_performance_tab:
- Integrates performance dashboard from tracker utility.
- modules/ui_components.py +478 -51
@@ -13,9 +13,9 @@ from modules.callbacks import (
|
|
13 |
on_model_change,
|
14 |
on_input_mode_change,
|
15 |
on_sample_change,
|
|
|
16 |
reset_ephemeral_state,
|
17 |
log_message,
|
18 |
-
clear_batch_results,
|
19 |
)
|
20 |
from core_logic import (
|
21 |
get_sample_files,
|
@@ -24,7 +24,6 @@ from core_logic import (
|
|
24 |
parse_spectrum_data,
|
25 |
label_file,
|
26 |
)
|
27 |
-
from modules.callbacks import reset_results
|
28 |
from utils.results_manager import ResultsManager
|
29 |
from utils.confidence import calculate_softmax_confidence
|
30 |
from utils.multifile import process_multiple_files, display_batch_results
|
@@ -41,7 +40,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
41 |
"""Create spectrum visualization plot"""
|
42 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
43 |
|
44 |
-
#
|
45 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
46 |
ax[0].set_title("Raw Input Spectrum")
|
47 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
@@ -49,7 +48,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
49 |
ax[0].grid(True, alpha=0.3)
|
50 |
ax[0].legend()
|
51 |
|
52 |
-
#
|
53 |
ax[1].plot(
|
54 |
x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
|
55 |
)
|
@@ -60,7 +59,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
60 |
ax[1].legend()
|
61 |
|
62 |
fig.tight_layout()
|
63 |
-
#
|
64 |
buf = io.BytesIO()
|
65 |
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
66 |
buf.seek(0)
|
@@ -69,6 +68,9 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
69 |
return Image.open(buf)
|
70 |
|
71 |
|
|
|
|
|
|
|
72 |
def render_confidence_progress(
|
73 |
probs: np.ndarray,
|
74 |
labels: list[str] = ["Stable", "Weathered"],
|
@@ -114,7 +116,10 @@ def render_confidence_progress(
|
|
114 |
st.markdown("")
|
115 |
|
116 |
|
117 |
-
|
|
|
|
|
|
|
118 |
if d is None:
|
119 |
d = {}
|
120 |
if not d:
|
@@ -126,6 +131,9 @@ def render_kv_grid(d: dict = {}, ncols: int = 2):
|
|
126 |
st.caption(f"**{k}:** {v}")
|
127 |
|
128 |
|
|
|
|
|
|
|
129 |
def render_model_meta(model_choice: str):
|
130 |
info = MODEL_CONFIG.get(model_choice, {})
|
131 |
emoji = info.get("emoji", "")
|
@@ -143,6 +151,9 @@ def render_model_meta(model_choice: str):
|
|
143 |
st.caption(desc)
|
144 |
|
145 |
|
|
|
|
|
|
|
146 |
def get_confidence_description(logit_margin):
|
147 |
"""Get human-readable confidence description"""
|
148 |
if logit_margin > 1000:
|
@@ -155,13 +166,35 @@ def get_confidence_description(logit_margin):
|
|
155 |
return "LOW", "🔴"
|
156 |
|
157 |
|
|
|
|
|
|
|
158 |
def render_sidebar():
|
159 |
with st.sidebar:
|
160 |
# Header
|
161 |
st.header("AI-Driven Polymer Classification")
|
162 |
st.caption(
|
163 |
-
"Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
model_labels = [
|
166 |
f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
|
167 |
]
|
@@ -173,10 +206,10 @@ def render_sidebar():
|
|
173 |
)
|
174 |
model_choice = selected_label.split(" ", 1)[1]
|
175 |
|
176 |
-
#
|
177 |
render_model_meta(model_choice)
|
178 |
|
179 |
-
#
|
180 |
with st.expander("About This App", icon=":material/info:", expanded=False):
|
181 |
st.markdown(
|
182 |
"""
|
@@ -184,8 +217,9 @@ def render_sidebar():
|
|
184 |
|
185 |
**Purpose**: Classify polymer degradation using AI<br>
|
186 |
**Input**: Raman spectroscopy .txt files<br>
|
187 |
-
**Models**: CNN architectures for
|
188 |
-
**
|
|
|
189 |
|
190 |
|
191 |
**Contributors**<br>
|
@@ -207,11 +241,7 @@ def render_sidebar():
|
|
207 |
)
|
208 |
|
209 |
|
210 |
-
#
|
211 |
-
|
212 |
-
# In modules/ui_components.py
|
213 |
-
|
214 |
-
|
215 |
def render_input_column():
|
216 |
st.markdown("##### Data Input")
|
217 |
|
@@ -224,22 +254,20 @@ def render_input_column():
|
|
224 |
)
|
225 |
|
226 |
# == Input Mode Logic ==
|
227 |
-
# ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ...
|
228 |
-
# ==Upload tab==
|
229 |
if mode == "Upload File":
|
230 |
upload_key = st.session_state["current_upload_key"]
|
231 |
up = st.file_uploader(
|
232 |
-
"Upload
|
233 |
-
type="txt",
|
234 |
-
help="Upload
|
235 |
key=upload_key, # ← versioned key
|
236 |
)
|
237 |
|
238 |
-
#
|
239 |
if up is not None:
|
240 |
raw = up.read()
|
241 |
text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
242 |
-
#
|
243 |
if (
|
244 |
st.session_state.get("filename") != getattr(up, "name", None)
|
245 |
or st.session_state.get("input_source") != "upload"
|
@@ -255,23 +283,20 @@ def render_input_column():
|
|
255 |
st.session_state["status_type"] = "success"
|
256 |
reset_results("New file uploaded")
|
257 |
|
258 |
-
#
|
259 |
elif mode == "Batch Upload":
|
260 |
st.session_state["batch_mode"] = True
|
261 |
-
# --- START: BUG 1 & 3 FIX ---
|
262 |
# Use a versioned key to ensure the file uploader resets properly.
|
263 |
batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
|
264 |
uploaded_files = st.file_uploader(
|
265 |
-
"Upload multiple
|
266 |
-
type="txt",
|
267 |
accept_multiple_files=True,
|
268 |
-
help="Upload
|
269 |
key=batch_upload_key,
|
270 |
)
|
271 |
-
# --- END: BUG 1 & 3 FIX ---
|
272 |
|
273 |
if uploaded_files:
|
274 |
-
# --- START: Bug 1 Fix ---
|
275 |
# Use a dictionary to keep only unique files based on name and size
|
276 |
unique_files = {(file.name, file.size): file for file in uploaded_files}
|
277 |
unique_file_list = list(unique_files.values())
|
@@ -281,9 +306,7 @@ def render_input_column():
|
|
281 |
|
282 |
# Optionally, inform the user that duplicates were removed
|
283 |
if num_uploaded > num_unique:
|
284 |
-
st.info(
|
285 |
-
f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed."
|
286 |
-
)
|
287 |
|
288 |
# Use the unique list
|
289 |
st.session_state["batch_files"] = unique_file_list
|
@@ -291,7 +314,6 @@ def render_input_column():
|
|
291 |
f"{num_unique} ready for batch analysis"
|
292 |
)
|
293 |
st.session_state["status_type"] = "success"
|
294 |
-
# --- END: Bug 1 Fix ---
|
295 |
else:
|
296 |
st.session_state["batch_files"] = []
|
297 |
# This check prevents resetting the status if files are already staged
|
@@ -301,7 +323,7 @@ def render_input_column():
|
|
301 |
)
|
302 |
st.session_state["status_type"] = "info"
|
303 |
|
304 |
-
#
|
305 |
elif mode == "Sample Data":
|
306 |
st.session_state["batch_mode"] = False
|
307 |
sample_files = get_sample_files()
|
@@ -330,9 +352,6 @@ def render_input_column():
|
|
330 |
else:
|
331 |
st.info(msg)
|
332 |
|
333 |
-
# --- DE-NESTED LOGIC STARTS HERE ---
|
334 |
-
# This code now runs on EVERY execution, guaranteeing the buttons will appear.
|
335 |
-
|
336 |
# Safely get model choice from session state
|
337 |
model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
|
338 |
model = load_model(model_choice)
|
@@ -388,7 +407,7 @@ def render_input_column():
|
|
388 |
st.error(f"Error processing spectrum data: {e}")
|
389 |
|
390 |
|
391 |
-
#
|
392 |
|
393 |
|
394 |
def render_results_column():
|
@@ -410,7 +429,7 @@ def render_results_column():
|
|
410 |
filename = st.session_state.get("filename", "Unknown")
|
411 |
|
412 |
if all(v is not None for v in [x_raw, y_raw, y_resampled]):
|
413 |
-
#
|
414 |
if y_resampled is None:
|
415 |
raise ValueError(
|
416 |
"y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
|
@@ -437,14 +456,14 @@ def render_results_column():
|
|
437 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
|
438 |
)
|
439 |
|
440 |
-
#
|
441 |
true_label_idx = label_file(filename)
|
442 |
true_label_str = (
|
443 |
LABEL_MAP.get(true_label_idx, "Unknown")
|
444 |
if true_label_idx is not None
|
445 |
else "Unknown"
|
446 |
)
|
447 |
-
#
|
448 |
predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
|
449 |
|
450 |
# Enhanced confidence calculation
|
@@ -455,7 +474,7 @@ def render_results_column():
|
|
455 |
)
|
456 |
confidence_desc = confidence_level
|
457 |
else:
|
458 |
-
# Fallback to
|
459 |
logit_margin = abs(
|
460 |
(logits_list[0] - logits_list[1])
|
461 |
if logits_list is not None and len(logits_list) >= 2
|
@@ -487,7 +506,7 @@ def render_results_column():
|
|
487 |
},
|
488 |
)
|
489 |
|
490 |
-
#
|
491 |
model_choice = (
|
492 |
st.session_state.get("model_select", "").split(" ", 1)[1]
|
493 |
if "model_select" in st.session_state
|
@@ -505,7 +524,6 @@ def render_results_column():
|
|
505 |
if os.path.exists(model_path)
|
506 |
else "N/A"
|
507 |
)
|
508 |
-
# Removed unused variable 'input_tensor'
|
509 |
|
510 |
start_render = time.time()
|
511 |
|
@@ -590,17 +608,13 @@ def render_results_column():
|
|
590 |
""",
|
591 |
unsafe_allow_html=True,
|
592 |
)
|
593 |
-
# --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
|
594 |
|
595 |
st.divider()
|
596 |
|
597 |
-
#
|
598 |
-
# Secondary info is now a clean, single-line caption
|
599 |
st.caption(
|
600 |
f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
|
601 |
)
|
602 |
-
# --- END: CLEAN METADATA FOOTER ---
|
603 |
-
|
604 |
st.markdown("</div>", unsafe_allow_html=True)
|
605 |
|
606 |
elif active_tab == "Technical":
|
@@ -918,7 +932,7 @@ def render_results_column():
|
|
918 |
"""
|
919 |
)
|
920 |
else:
|
921 |
-
#
|
922 |
st.markdown(
|
923 |
"""
|
924 |
##### How to Get Started
|
@@ -948,3 +962,416 @@ def render_results_column():
|
|
948 |
- 🏭 Quality control in manufacturing
|
949 |
"""
|
950 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
on_model_change,
|
14 |
on_input_mode_change,
|
15 |
on_sample_change,
|
16 |
+
reset_results,
|
17 |
reset_ephemeral_state,
|
18 |
log_message,
|
|
|
19 |
)
|
20 |
from core_logic import (
|
21 |
get_sample_files,
|
|
|
24 |
parse_spectrum_data,
|
25 |
label_file,
|
26 |
)
|
|
|
27 |
from utils.results_manager import ResultsManager
|
28 |
from utils.confidence import calculate_softmax_confidence
|
29 |
from utils.multifile import process_multiple_files, display_batch_results
|
|
|
40 |
"""Create spectrum visualization plot"""
|
41 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
42 |
|
43 |
+
# Raw spectrum
|
44 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
45 |
ax[0].set_title("Raw Input Spectrum")
|
46 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
|
48 |
ax[0].grid(True, alpha=0.3)
|
49 |
ax[0].legend()
|
50 |
|
51 |
+
# Resampled spectrum
|
52 |
ax[1].plot(
|
53 |
x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
|
54 |
)
|
|
|
59 |
ax[1].legend()
|
60 |
|
61 |
fig.tight_layout()
|
62 |
+
# Convert to image
|
63 |
buf = io.BytesIO()
|
64 |
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
65 |
buf.seek(0)
|
|
|
68 |
return Image.open(buf)
|
69 |
|
70 |
|
71 |
+
# //////////////////////////////////////////
|
72 |
+
|
73 |
+
|
74 |
def render_confidence_progress(
|
75 |
probs: np.ndarray,
|
76 |
labels: list[str] = ["Stable", "Weathered"],
|
|
|
116 |
st.markdown("")
|
117 |
|
118 |
|
119 |
+
from typing import Optional
|
120 |
+
|
121 |
+
|
122 |
+
def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
|
123 |
if d is None:
|
124 |
d = {}
|
125 |
if not d:
|
|
|
131 |
st.caption(f"**{k}:** {v}")
|
132 |
|
133 |
|
134 |
+
# //////////////////////////////////////////
|
135 |
+
|
136 |
+
|
137 |
def render_model_meta(model_choice: str):
|
138 |
info = MODEL_CONFIG.get(model_choice, {})
|
139 |
emoji = info.get("emoji", "")
|
|
|
151 |
st.caption(desc)
|
152 |
|
153 |
|
154 |
+
# //////////////////////////////////////////
|
155 |
+
|
156 |
+
|
157 |
def get_confidence_description(logit_margin):
|
158 |
"""Get human-readable confidence description"""
|
159 |
if logit_margin > 1000:
|
|
|
166 |
return "LOW", "🔴"
|
167 |
|
168 |
|
169 |
+
# //////////////////////////////////////////
|
170 |
+
|
171 |
+
|
172 |
def render_sidebar():
|
173 |
with st.sidebar:
|
174 |
# Header
|
175 |
st.header("AI-Driven Polymer Classification")
|
176 |
st.caption(
|
177 |
+
"Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01"
|
178 |
+
)
|
179 |
+
|
180 |
+
# Modality Selection
|
181 |
+
st.markdown("##### Spectroscopy Modality")
|
182 |
+
modality = st.selectbox(
|
183 |
+
"Choose Modality",
|
184 |
+
["raman", "ftir"],
|
185 |
+
index=0,
|
186 |
+
key="modality_select",
|
187 |
+
format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}",
|
188 |
)
|
189 |
+
|
190 |
+
# Display modality info
|
191 |
+
if modality == "ftir":
|
192 |
+
st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction")
|
193 |
+
else:
|
194 |
+
st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing")
|
195 |
+
|
196 |
+
# Model selection
|
197 |
+
st.markdown("##### AI Model Selection")
|
198 |
model_labels = [
|
199 |
f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
|
200 |
]
|
|
|
206 |
)
|
207 |
model_choice = selected_label.split(" ", 1)[1]
|
208 |
|
209 |
+
# Compact metadata directly under dropdown
|
210 |
render_model_meta(model_choice)
|
211 |
|
212 |
+
# Collapsed info to reduce clutter
|
213 |
with st.expander("About This App", icon=":material/info:", expanded=False):
|
214 |
st.markdown(
|
215 |
"""
|
|
|
217 |
|
218 |
**Purpose**: Classify polymer degradation using AI<br>
|
219 |
**Input**: Raman spectroscopy .txt files<br>
|
220 |
+
**Models**: CNN architectures for classification<br>
|
221 |
+
**Modalities**: Raman and FTIR spectroscopy support<br>
|
222 |
+
**Features**: Multi-model comparison and analysis<br>
|
223 |
|
224 |
|
225 |
**Contributors**<br>
|
|
|
241 |
)
|
242 |
|
243 |
|
244 |
+
# //////////////////////////////////////////
|
|
|
|
|
|
|
|
|
245 |
def render_input_column():
|
246 |
st.markdown("##### Data Input")
|
247 |
|
|
|
254 |
)
|
255 |
|
256 |
# == Input Mode Logic ==
|
|
|
|
|
257 |
if mode == "Upload File":
|
258 |
upload_key = st.session_state["current_upload_key"]
|
259 |
up = st.file_uploader(
|
260 |
+
"Upload spectrum file (.txt, .csv, .json)",
|
261 |
+
type=["txt", "csv", "json"],
|
262 |
+
help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
|
263 |
key=upload_key, # ← versioned key
|
264 |
)
|
265 |
|
266 |
+
# Process change immediately
|
267 |
if up is not None:
|
268 |
raw = up.read()
|
269 |
text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
270 |
+
# only reparse if its a different file|source
|
271 |
if (
|
272 |
st.session_state.get("filename") != getattr(up, "name", None)
|
273 |
or st.session_state.get("input_source") != "upload"
|
|
|
283 |
st.session_state["status_type"] = "success"
|
284 |
reset_results("New file uploaded")
|
285 |
|
286 |
+
# Batch Upload tab
|
287 |
elif mode == "Batch Upload":
|
288 |
st.session_state["batch_mode"] = True
|
|
|
289 |
# Use a versioned key to ensure the file uploader resets properly.
|
290 |
batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
|
291 |
uploaded_files = st.file_uploader(
|
292 |
+
"Upload multiple spectrum files (.txt, .csv, .json)",
|
293 |
+
type=["txt", "csv", "json"],
|
294 |
accept_multiple_files=True,
|
295 |
+
help="Upload spectroscopy files in TXT, CSV, or JSON format.",
|
296 |
key=batch_upload_key,
|
297 |
)
|
|
|
298 |
|
299 |
if uploaded_files:
|
|
|
300 |
# Use a dictionary to keep only unique files based on name and size
|
301 |
unique_files = {(file.name, file.size): file for file in uploaded_files}
|
302 |
unique_file_list = list(unique_files.values())
|
|
|
306 |
|
307 |
# Optionally, inform the user that duplicates were removed
|
308 |
if num_uploaded > num_unique:
|
309 |
+
st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.")
|
|
|
|
|
310 |
|
311 |
# Use the unique list
|
312 |
st.session_state["batch_files"] = unique_file_list
|
|
|
314 |
f"{num_unique} ready for batch analysis"
|
315 |
)
|
316 |
st.session_state["status_type"] = "success"
|
|
|
317 |
else:
|
318 |
st.session_state["batch_files"] = []
|
319 |
# This check prevents resetting the status if files are already staged
|
|
|
323 |
)
|
324 |
st.session_state["status_type"] = "info"
|
325 |
|
326 |
+
# Sample tab
|
327 |
elif mode == "Sample Data":
|
328 |
st.session_state["batch_mode"] = False
|
329 |
sample_files = get_sample_files()
|
|
|
352 |
else:
|
353 |
st.info(msg)
|
354 |
|
|
|
|
|
|
|
355 |
# Safely get model choice from session state
|
356 |
model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
|
357 |
model = load_model(model_choice)
|
|
|
407 |
st.error(f"Error processing spectrum data: {e}")
|
408 |
|
409 |
|
410 |
+
# //////////////////////////////////////////
|
411 |
|
412 |
|
413 |
def render_results_column():
|
|
|
429 |
filename = st.session_state.get("filename", "Unknown")
|
430 |
|
431 |
if all(v is not None for v in [x_raw, y_raw, y_resampled]):
|
432 |
+
# Run inference
|
433 |
if y_resampled is None:
|
434 |
raise ValueError(
|
435 |
"y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
|
|
|
456 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
|
457 |
)
|
458 |
|
459 |
+
# Get ground truth
|
460 |
true_label_idx = label_file(filename)
|
461 |
true_label_str = (
|
462 |
LABEL_MAP.get(true_label_idx, "Unknown")
|
463 |
if true_label_idx is not None
|
464 |
else "Unknown"
|
465 |
)
|
466 |
+
# Get prediction
|
467 |
predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
|
468 |
|
469 |
# Enhanced confidence calculation
|
|
|
474 |
)
|
475 |
confidence_desc = confidence_level
|
476 |
else:
|
477 |
+
# Fallback to legacy method
|
478 |
logit_margin = abs(
|
479 |
(logits_list[0] - logits_list[1])
|
480 |
if logits_list is not None and len(logits_list) >= 2
|
|
|
506 |
},
|
507 |
)
|
508 |
|
509 |
+
# Precompute Stats
|
510 |
model_choice = (
|
511 |
st.session_state.get("model_select", "").split(" ", 1)[1]
|
512 |
if "model_select" in st.session_state
|
|
|
524 |
if os.path.exists(model_path)
|
525 |
else "N/A"
|
526 |
)
|
|
|
527 |
|
528 |
start_render = time.time()
|
529 |
|
|
|
608 |
""",
|
609 |
unsafe_allow_html=True,
|
610 |
)
|
|
|
611 |
|
612 |
st.divider()
|
613 |
|
614 |
+
# METADATA FOOTER
|
|
|
615 |
st.caption(
|
616 |
f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
|
617 |
)
|
|
|
|
|
618 |
st.markdown("</div>", unsafe_allow_html=True)
|
619 |
|
620 |
elif active_tab == "Technical":
|
|
|
932 |
"""
|
933 |
)
|
934 |
else:
|
935 |
+
# Getting Started
|
936 |
st.markdown(
|
937 |
"""
|
938 |
##### How to Get Started
|
|
|
962 |
- 🏭 Quality control in manufacturing
|
963 |
"""
|
964 |
)
|
965 |
+
|
966 |
+
|
967 |
+
# //////////////////////////////////////////
|
968 |
+
|
969 |
+
|
970 |
+
def render_comparison_tab():
|
971 |
+
"""Render the multi-model comparison interface"""
|
972 |
+
import streamlit as st
|
973 |
+
import matplotlib.pyplot as plt
|
974 |
+
from models.registry import choices, validate_model_list
|
975 |
+
from utils.results_manager import ResultsManager
|
976 |
+
from core_logic import get_sample_files, run_inference, parse_spectrum_data
|
977 |
+
from utils.preprocessing import preprocess_spectrum
|
978 |
+
from utils.multifile import parse_spectrum_data
|
979 |
+
import numpy as np
|
980 |
+
import time
|
981 |
+
|
982 |
+
st.markdown("### Multi-Model Comparison Analysis")
|
983 |
+
st.markdown(
|
984 |
+
"Compare predictions across different AI models for comprehensive analysis."
|
985 |
+
)
|
986 |
+
|
987 |
+
# Model selection for comparison
|
988 |
+
st.markdown("##### Select Models for Comparison")
|
989 |
+
|
990 |
+
available_models = choices()
|
991 |
+
selected_models = st.multiselect(
|
992 |
+
"Choose models to compare",
|
993 |
+
available_models,
|
994 |
+
default=(
|
995 |
+
available_models[:2] if len(available_models) >= 2 else available_models
|
996 |
+
),
|
997 |
+
help="Select 2 or more models to compare their predictions side-by-side",
|
998 |
+
)
|
999 |
+
|
1000 |
+
if len(selected_models) < 2:
|
1001 |
+
st.warning("⚠️ Please select at least 2 models for comparison.")
|
1002 |
+
|
1003 |
+
# Input selection for comparison
|
1004 |
+
col1, col2 = st.columns([1, 1.5])
|
1005 |
+
|
1006 |
+
with col1:
|
1007 |
+
st.markdown("###### Input Data")
|
1008 |
+
|
1009 |
+
# File upload for comparison
|
1010 |
+
comparison_file = st.file_uploader(
|
1011 |
+
"Upload spectrum for comparison",
|
1012 |
+
type=["txt", "csv", "json"],
|
1013 |
+
key="comparison_file_upload",
|
1014 |
+
help="Upload a spectrum file to test across all selected models",
|
1015 |
+
)
|
1016 |
+
|
1017 |
+
# Or select sample data
|
1018 |
+
selected_sample = None # Initialize with a default value
|
1019 |
+
sample_files = get_sample_files()
|
1020 |
+
if sample_files:
|
1021 |
+
sample_options = ["-- Select Sample --"] + [p.name for p in sample_files]
|
1022 |
+
selected_sample = st.selectbox(
|
1023 |
+
"Or choose sample data", sample_options, key="comparison_sample_select"
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
# Get modality from session state
|
1027 |
+
modality = st.session_state.get("modality_select", "raman")
|
1028 |
+
st.info(f"Using {modality.upper()} preprocessing parameters")
|
1029 |
+
|
1030 |
+
# Run comparison button
|
1031 |
+
run_comparison = st.button(
|
1032 |
+
"Run Multi-Model Comparison",
|
1033 |
+
type="primary",
|
1034 |
+
disabled=not (
|
1035 |
+
comparison_file
|
1036 |
+
or (sample_files and selected_sample != "-- Select Sample --")
|
1037 |
+
),
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
with col2:
|
1041 |
+
st.markdown("###### Comparison Results")
|
1042 |
+
|
1043 |
+
if run_comparison:
|
1044 |
+
# Determine input source
|
1045 |
+
input_text = None
|
1046 |
+
filename = "unknown"
|
1047 |
+
|
1048 |
+
if comparison_file:
|
1049 |
+
raw = comparison_file.read()
|
1050 |
+
input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
1051 |
+
filename = comparison_file.name
|
1052 |
+
elif sample_files and selected_sample != "-- Select Sample --":
|
1053 |
+
sample_path = next(p for p in sample_files if p.name == selected_sample)
|
1054 |
+
with open(sample_path, "r") as f:
|
1055 |
+
input_text = f.read()
|
1056 |
+
filename = selected_sample
|
1057 |
+
|
1058 |
+
if input_text:
|
1059 |
+
try:
|
1060 |
+
# Parse spectrum data
|
1061 |
+
x_raw, y_raw = parse_spectrum_data(
|
1062 |
+
str(input_text), filename or "unknown_filename"
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
# Store results
|
1066 |
+
comparison_results = {}
|
1067 |
+
processing_times = {}
|
1068 |
+
|
1069 |
+
progress_bar = st.progress(0)
|
1070 |
+
status_text = st.empty()
|
1071 |
+
|
1072 |
+
for i, model_name in enumerate(selected_models):
|
1073 |
+
status_text.text(f"Running inference with {model_name}...")
|
1074 |
+
|
1075 |
+
start_time = time.time()
|
1076 |
+
|
1077 |
+
# Preprocess spectrum with modality-specific parameters
|
1078 |
+
_, y_processed = preprocess_spectrum(
|
1079 |
+
x_raw, y_raw, modality=modality, target_len=500
|
1080 |
+
)
|
1081 |
+
|
1082 |
+
# Run inference
|
1083 |
+
prediction, logits_list, probs, inference_time, logits = (
|
1084 |
+
run_inference(y_processed, model_name)
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
processing_time = time.time() - start_time
|
1088 |
+
|
1089 |
+
if prediction is not None:
|
1090 |
+
# Map prediction to class name
|
1091 |
+
class_names = ["Stable", "Weathered"]
|
1092 |
+
predicted_class = (
|
1093 |
+
class_names[int(prediction)]
|
1094 |
+
if prediction < len(class_names)
|
1095 |
+
else f"Class_{prediction}"
|
1096 |
+
)
|
1097 |
+
confidence = (
|
1098 |
+
max(probs)
|
1099 |
+
if probs is not None and len(probs) > 0
|
1100 |
+
else 0.0
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
comparison_results[model_name] = {
|
1104 |
+
"prediction": prediction,
|
1105 |
+
"predicted_class": predicted_class,
|
1106 |
+
"confidence": confidence,
|
1107 |
+
"probs": probs if probs is not None else [],
|
1108 |
+
"logits": (
|
1109 |
+
logits_list if logits_list is not None else []
|
1110 |
+
),
|
1111 |
+
"processing_time": processing_time,
|
1112 |
+
}
|
1113 |
+
processing_times[model_name] = processing_time
|
1114 |
+
|
1115 |
+
progress_bar.progress((i + 1) / len(selected_models))
|
1116 |
+
|
1117 |
+
status_text.text("Comparison complete!")
|
1118 |
+
|
1119 |
+
# Display results
|
1120 |
+
if comparison_results:
|
1121 |
+
st.markdown("###### Model Predictions")
|
1122 |
+
|
1123 |
+
# Create comparison table
|
1124 |
+
import pandas as pd
|
1125 |
+
|
1126 |
+
table_data = []
|
1127 |
+
for model_name, result in comparison_results.items():
|
1128 |
+
row = {
|
1129 |
+
"Model": model_name,
|
1130 |
+
"Prediction": result["predicted_class"],
|
1131 |
+
"Confidence": f"{result['confidence']:.3f}",
|
1132 |
+
"Processing Time (s)": f"{result['processing_time']:.3f}",
|
1133 |
+
}
|
1134 |
+
table_data.append(row)
|
1135 |
+
|
1136 |
+
df = pd.DataFrame(table_data)
|
1137 |
+
st.dataframe(df, use_container_width=True)
|
1138 |
+
|
1139 |
+
# Show confidence comparison
|
1140 |
+
st.markdown("##### Confidence Comparison")
|
1141 |
+
conf_col1, conf_col2 = st.columns(2)
|
1142 |
+
|
1143 |
+
with conf_col1:
|
1144 |
+
# Bar chart of confidences
|
1145 |
+
models = list(comparison_results.keys())
|
1146 |
+
confidences = [
|
1147 |
+
comparison_results[m]["confidence"] for m in models
|
1148 |
+
]
|
1149 |
+
|
1150 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
1151 |
+
bars = ax.bar(
|
1152 |
+
models,
|
1153 |
+
confidences,
|
1154 |
+
alpha=0.7,
|
1155 |
+
color=["steelblue", "orange", "green", "red"][
|
1156 |
+
: len(models)
|
1157 |
+
],
|
1158 |
+
)
|
1159 |
+
ax.set_ylabel("Confidence")
|
1160 |
+
ax.set_title("Model Confidence Comparison")
|
1161 |
+
ax.set_ylim(0, 1)
|
1162 |
+
plt.xticks(rotation=45)
|
1163 |
+
|
1164 |
+
# Add value labels on bars
|
1165 |
+
for bar, conf in zip(bars, confidences):
|
1166 |
+
height = bar.get_height()
|
1167 |
+
ax.text(
|
1168 |
+
bar.get_x() + bar.get_width() / 2.0,
|
1169 |
+
height + 0.01,
|
1170 |
+
f"{conf:.3f}",
|
1171 |
+
ha="center",
|
1172 |
+
va="bottom",
|
1173 |
+
)
|
1174 |
+
|
1175 |
+
plt.tight_layout()
|
1176 |
+
st.pyplot(fig)
|
1177 |
+
|
1178 |
+
with conf_col2:
|
1179 |
+
# Agreement analysis
|
1180 |
+
predictions = [
|
1181 |
+
comparison_results[m]["prediction"] for m in models
|
1182 |
+
]
|
1183 |
+
unique_predictions = set(predictions)
|
1184 |
+
|
1185 |
+
if len(unique_predictions) == 1:
|
1186 |
+
st.success("✅ All models agree on the prediction!")
|
1187 |
+
else:
|
1188 |
+
st.warning("⚠️ Models disagree on the prediction")
|
1189 |
+
|
1190 |
+
# Show prediction distribution
|
1191 |
+
from collections import Counter
|
1192 |
+
|
1193 |
+
pred_counts = Counter(predictions)
|
1194 |
+
|
1195 |
+
st.markdown("**Prediction Distribution:**")
|
1196 |
+
for pred, count in pred_counts.items():
|
1197 |
+
class_name = (
|
1198 |
+
["Stable", "Weathered"][pred]
|
1199 |
+
if pred < 2
|
1200 |
+
else f"Class_{pred}"
|
1201 |
+
)
|
1202 |
+
percentage = (count / len(predictions)) * 100
|
1203 |
+
st.write(
|
1204 |
+
f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
|
1205 |
+
)
|
1206 |
+
|
1207 |
+
# Performance metrics
|
1208 |
+
st.markdown("##### Performance Metrics")
|
1209 |
+
perf_col1, perf_col2 = st.columns(2)
|
1210 |
+
|
1211 |
+
with perf_col1:
|
1212 |
+
avg_time = np.mean(list(processing_times.values()))
|
1213 |
+
fastest_model = min(
|
1214 |
+
processing_times.keys(),
|
1215 |
+
key=lambda k: processing_times[k],
|
1216 |
+
)
|
1217 |
+
slowest_model = max(
|
1218 |
+
processing_times.keys(),
|
1219 |
+
key=lambda k: processing_times[k],
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
st.metric("Average Processing Time", f"{avg_time:.3f}s")
|
1223 |
+
st.metric(
|
1224 |
+
"Fastest Model",
|
1225 |
+
f"{fastest_model}",
|
1226 |
+
f"{processing_times[fastest_model]:.3f}s",
|
1227 |
+
)
|
1228 |
+
st.metric(
|
1229 |
+
"Slowest Model",
|
1230 |
+
f"{slowest_model}",
|
1231 |
+
f"{processing_times[slowest_model]:.3f}s",
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
with perf_col2:
|
1235 |
+
most_confident = max(
|
1236 |
+
comparison_results.keys(),
|
1237 |
+
key=lambda k: comparison_results[k]["confidence"],
|
1238 |
+
)
|
1239 |
+
least_confident = min(
|
1240 |
+
comparison_results.keys(),
|
1241 |
+
key=lambda k: comparison_results[k]["confidence"],
|
1242 |
+
)
|
1243 |
+
|
1244 |
+
st.metric(
|
1245 |
+
"Most Confident",
|
1246 |
+
f"{most_confident}",
|
1247 |
+
f"{comparison_results[most_confident]['confidence']:.3f}",
|
1248 |
+
)
|
1249 |
+
st.metric(
|
1250 |
+
"Least Confident",
|
1251 |
+
f"{least_confident}",
|
1252 |
+
f"{comparison_results[least_confident]['confidence']:.3f}",
|
1253 |
+
)
|
1254 |
+
|
1255 |
+
# Store results in session state for potential export
|
1256 |
+
# Store results in session state for potential export
|
1257 |
+
st.session_state["last_comparison_results"] = {
|
1258 |
+
"filename": filename,
|
1259 |
+
"modality": modality,
|
1260 |
+
"models": comparison_results,
|
1261 |
+
"summary": {
|
1262 |
+
"agreement": len(unique_predictions) == 1,
|
1263 |
+
"avg_processing_time": avg_time,
|
1264 |
+
"fastest_model": fastest_model,
|
1265 |
+
"most_confident": most_confident,
|
1266 |
+
},
|
1267 |
+
}
|
1268 |
+
|
1269 |
+
except Exception as e:
|
1270 |
+
st.error(f"Error during comparison: {str(e)}")
|
1271 |
+
|
1272 |
+
# Show recent comparison results if available
|
1273 |
+
elif "last_comparison_results" in st.session_state:
|
1274 |
+
st.info(
|
1275 |
+
"Previous comparison results available. Upload a new file or select a sample to run new comparison."
|
1276 |
+
)
|
1277 |
+
|
1278 |
+
# Show comparison history
|
1279 |
+
comparison_stats = ResultsManager.get_comparison_stats()
|
1280 |
+
if comparison_stats:
|
1281 |
+
st.markdown("#### Comparison History")
|
1282 |
+
|
1283 |
+
with st.expander("View detailed comparison statistics", expanded=False):
|
1284 |
+
# Show model statistics table
|
1285 |
+
stats_data = []
|
1286 |
+
for model_name, stats in comparison_stats.items():
|
1287 |
+
row = {
|
1288 |
+
"Model": model_name,
|
1289 |
+
"Total Predictions": stats["total_predictions"],
|
1290 |
+
"Avg Confidence": f"{stats['avg_confidence']:.3f}",
|
1291 |
+
"Avg Processing Time": f"{stats['avg_processing_time']:.3f}s",
|
1292 |
+
"Accuracy": (
|
1293 |
+
f"{stats['accuracy']:.3f}"
|
1294 |
+
if stats["accuracy"] is not None
|
1295 |
+
else "N/A"
|
1296 |
+
),
|
1297 |
+
}
|
1298 |
+
stats_data.append(row)
|
1299 |
+
|
1300 |
+
if stats_data:
|
1301 |
+
import pandas as pd
|
1302 |
+
|
1303 |
+
stats_df = pd.DataFrame(stats_data)
|
1304 |
+
st.dataframe(stats_df, use_container_width=True)
|
1305 |
+
|
1306 |
+
# Show agreement matrix if multiple models
|
1307 |
+
agreement_matrix = ResultsManager.get_agreement_matrix()
|
1308 |
+
if not agreement_matrix.empty and len(agreement_matrix) > 1:
|
1309 |
+
st.markdown("**Model Agreement Matrix**")
|
1310 |
+
st.dataframe(agreement_matrix.round(3), use_container_width=True)
|
1311 |
+
|
1312 |
+
# Plot agreement heatmap
|
1313 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
1314 |
+
im = ax.imshow(
|
1315 |
+
agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1
|
1316 |
+
)
|
1317 |
+
|
1318 |
+
# Add text annotations
|
1319 |
+
for i in range(len(agreement_matrix)):
|
1320 |
+
for j in range(len(agreement_matrix.columns)):
|
1321 |
+
text = ax.text(
|
1322 |
+
j,
|
1323 |
+
i,
|
1324 |
+
f"{agreement_matrix.iloc[i, j]:.2f}",
|
1325 |
+
ha="center",
|
1326 |
+
va="center",
|
1327 |
+
color="black",
|
1328 |
+
)
|
1329 |
+
|
1330 |
+
ax.set_xticks(range(len(agreement_matrix.columns)))
|
1331 |
+
ax.set_yticks(range(len(agreement_matrix)))
|
1332 |
+
ax.set_xticklabels(agreement_matrix.columns, rotation=45)
|
1333 |
+
ax.set_yticklabels(agreement_matrix.index)
|
1334 |
+
ax.set_title("Model Agreement Matrix")
|
1335 |
+
|
1336 |
+
plt.colorbar(im, ax=ax, label="Agreement Rate")
|
1337 |
+
plt.tight_layout()
|
1338 |
+
st.pyplot(fig)
|
1339 |
+
|
1340 |
+
# Export functionality
|
1341 |
+
if "last_comparison_results" in st.session_state:
|
1342 |
+
st.markdown("##### Export Results")
|
1343 |
+
|
1344 |
+
export_col1, export_col2 = st.columns(2)
|
1345 |
+
|
1346 |
+
with export_col1:
|
1347 |
+
if st.button("📥 Export Comparison (JSON)"):
|
1348 |
+
import json
|
1349 |
+
|
1350 |
+
results = st.session_state["last_comparison_results"]
|
1351 |
+
json_str = json.dumps(results, indent=2, default=str)
|
1352 |
+
st.download_button(
|
1353 |
+
label="Download JSON",
|
1354 |
+
data=json_str,
|
1355 |
+
file_name=f"comparison_{results['filename'].split('.')[0]}.json",
|
1356 |
+
mime="application/json",
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
with export_col2:
|
1360 |
+
if st.button("📊 Export Full Report"):
|
1361 |
+
report = ResultsManager.export_comparison_report()
|
1362 |
+
st.download_button(
|
1363 |
+
label="Download Full Report",
|
1364 |
+
data=report,
|
1365 |
+
file_name="model_comparison_report.json",
|
1366 |
+
mime="application/json",
|
1367 |
+
)
|
1368 |
+
|
1369 |
+
|
1370 |
+
# //////////////////////////////////////////
|
1371 |
+
|
1372 |
+
|
1373 |
+
def render_performance_tab():
|
1374 |
+
"""Render the performance tracking and analysis tab."""
|
1375 |
+
from utils.performance_tracker import display_performance_dashboard
|
1376 |
+
|
1377 |
+
display_performance_dashboard()
|