Spaces:
Sleeping
Sleeping
devjas1
commited on
Commit
Β·
22d9362
1
Parent(s):
182c9ce
(CLEAN): remove 'torch.tensor(logits)' misuse to fix softmax warning
Browse files- Replaced incorrect torch.tensor(logits) wrapping with logits.detach()
- Eliminated 'UserWarning' in streamlit run logs
- Applied 'flatten()' after softmax for consistent display shape
- Final polishing for Step 1 completion
app.py
CHANGED
|
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
|
|
| 9 |
import matplotlib
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
|
|
|
| 12 |
import streamlit as st
|
| 13 |
import os
|
| 14 |
import sys
|
|
@@ -256,6 +257,11 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
|
|
| 256 |
|
| 257 |
return Image.open(buf)
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
def get_confidence_description(logit_margin):
|
| 261 |
"""Get human-readable confidence description"""
|
|
@@ -268,7 +274,6 @@ def get_confidence_description(logit_margin):
|
|
| 268 |
else:
|
| 269 |
return "LOW", "π΄"
|
| 270 |
|
| 271 |
-
|
| 272 |
def log_message(msg: str):
|
| 273 |
"""Append a timestamped line to the in-app log, creating the buffer if needed."""
|
| 274 |
if "log_messages" not in st.session_state or st.session_state["log_messages"] is None:
|
|
@@ -277,15 +282,10 @@ def log_message(msg: str):
|
|
| 277 |
f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
|
| 278 |
)
|
| 279 |
|
| 280 |
-
|
| 281 |
def trigger_run():
|
| 282 |
"""Set a flag so we can detect button press reliably across reruns"""
|
| 283 |
st.session_state['run_requested'] = True
|
| 284 |
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
def on_sample_change():
|
| 290 |
"""Read selected sample once and persist as text."""
|
| 291 |
sel = st.session_state.get("sample_select", "-- Select Sample --")
|
|
@@ -304,7 +304,6 @@ def on_sample_change():
|
|
| 304 |
st.session_state["status_message"] = f"β Error loading sample: {e}"
|
| 305 |
st.session_state["status_type"] = "error"
|
| 306 |
|
| 307 |
-
|
| 308 |
def on_input_mode_change():
|
| 309 |
"""Reset sample when switching to Upload"""
|
| 310 |
if st.session_state["input_mode"] == "Upload File":
|
|
@@ -312,12 +311,10 @@ def on_input_mode_change():
|
|
| 312 |
# π§ Reset when switching modes to prevent stale right-column visuals
|
| 313 |
reset_results("Switched input mode")
|
| 314 |
|
| 315 |
-
|
| 316 |
def on_model_change():
|
| 317 |
"""Force the right column back to init state when the model changes"""
|
| 318 |
reset_results("Model changed")
|
| 319 |
|
| 320 |
-
|
| 321 |
def reset_results(reason: str = ""):
|
| 322 |
"""Clear previous inference artifacts so the right column returns to initial state."""
|
| 323 |
st.session_state["inference_run_once"] = False
|
|
@@ -359,6 +356,23 @@ def reset_ephemeral_state():
|
|
| 359 |
|
| 360 |
st.rerun()
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
# Main app
|
| 363 |
def main():
|
| 364 |
init_session_state()
|
|
@@ -617,6 +631,9 @@ def main():
|
|
| 617 |
prediction = torch.argmax(logits, dim=1).item()
|
| 618 |
logits_list = logits.detach().numpy().tolist()[0]
|
| 619 |
|
|
|
|
|
|
|
|
|
|
| 620 |
inference_time = time.time() - start_time
|
| 621 |
log_message(
|
| 622 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
|
|
@@ -671,7 +688,11 @@ def main():
|
|
| 671 |
st.info(
|
| 672 |
"βΉοΈ **Ground Truth**: Unknown (filename doesn't follow naming convention)")
|
| 673 |
|
| 674 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
tab1, tab2, tab3 = st.tabs(
|
| 676 |
["π Details", "π¬ Technical", "π Explanation"])
|
| 677 |
|
|
|
|
| 9 |
import matplotlib
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
import streamlit as st
|
| 14 |
import os
|
| 15 |
import sys
|
|
|
|
| 257 |
|
| 258 |
return Image.open(buf)
|
| 259 |
|
| 260 |
+
def render_confidence_bar(probabilities, class_labels):
|
| 261 |
+
bar = lambda p: "β" * int(p * 20)
|
| 262 |
+
for label, prob in zip(class_labels, probabilities):
|
| 263 |
+
st.write(f"**{label}**: {bar(prob)} {prob*100:.1f}%")
|
| 264 |
+
|
| 265 |
|
| 266 |
def get_confidence_description(logit_margin):
|
| 267 |
"""Get human-readable confidence description"""
|
|
|
|
| 274 |
else:
|
| 275 |
return "LOW", "π΄"
|
| 276 |
|
|
|
|
| 277 |
def log_message(msg: str):
|
| 278 |
"""Append a timestamped line to the in-app log, creating the buffer if needed."""
|
| 279 |
if "log_messages" not in st.session_state or st.session_state["log_messages"] is None:
|
|
|
|
| 282 |
f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
|
| 283 |
)
|
| 284 |
|
|
|
|
| 285 |
def trigger_run():
|
| 286 |
"""Set a flag so we can detect button press reliably across reruns"""
|
| 287 |
st.session_state['run_requested'] = True
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
def on_sample_change():
|
| 290 |
"""Read selected sample once and persist as text."""
|
| 291 |
sel = st.session_state.get("sample_select", "-- Select Sample --")
|
|
|
|
| 304 |
st.session_state["status_message"] = f"β Error loading sample: {e}"
|
| 305 |
st.session_state["status_type"] = "error"
|
| 306 |
|
|
|
|
| 307 |
def on_input_mode_change():
|
| 308 |
"""Reset sample when switching to Upload"""
|
| 309 |
if st.session_state["input_mode"] == "Upload File":
|
|
|
|
| 311 |
# π§ Reset when switching modes to prevent stale right-column visuals
|
| 312 |
reset_results("Switched input mode")
|
| 313 |
|
|
|
|
| 314 |
def on_model_change():
|
| 315 |
"""Force the right column back to init state when the model changes"""
|
| 316 |
reset_results("Model changed")
|
| 317 |
|
|
|
|
| 318 |
def reset_results(reason: str = ""):
|
| 319 |
"""Clear previous inference artifacts so the right column returns to initial state."""
|
| 320 |
st.session_state["inference_run_once"] = False
|
|
|
|
| 356 |
|
| 357 |
st.rerun()
|
| 358 |
|
| 359 |
+
def plot_confidence_bar(probabilities: list[float], class_labels: list[str]) -> None:
|
| 360 |
+
"""Renders a horizontal bar chart of prediction confidences per class."""
|
| 361 |
+
fig, ax = plt.subplots(figsize=(4, 1.5))
|
| 362 |
+
bars = ax.barh(class_labels, probabilities, color=[
|
| 363 |
+
"green" if i == np.argmax(probabilities) else "gray"
|
| 364 |
+
for i in range(len(probabilities))
|
| 365 |
+
])
|
| 366 |
+
ax.set_xlabel("Confidence")
|
| 367 |
+
ax.set_title("Prediction Confidence")
|
| 368 |
+
ax.xaxis.set_ticks([0, 0.5, 1.0])
|
| 369 |
+
ax.set_xlim(0, 1.0)
|
| 370 |
+
for i, (label, prob) in enumerate(zip(class_labels, probabilities)):
|
| 371 |
+
ax.text(prob + 0.01, i, f"{prob*100:.1f}%", va='center', fontsize=8)
|
| 372 |
+
|
| 373 |
+
st.pyplot(fig)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
# Main app
|
| 377 |
def main():
|
| 378 |
init_session_state()
|
|
|
|
| 631 |
prediction = torch.argmax(logits, dim=1).item()
|
| 632 |
logits_list = logits.detach().numpy().tolist()[0]
|
| 633 |
|
| 634 |
+
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
|
| 635 |
+
|
| 636 |
+
|
| 637 |
inference_time = time.time() - start_time
|
| 638 |
log_message(
|
| 639 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
|
|
|
|
| 688 |
st.info(
|
| 689 |
"βΉοΈ **Ground Truth**: Unknown (filename doesn't follow naming convention)")
|
| 690 |
|
| 691 |
+
# ===display confidence results===
|
| 692 |
+
class_labels = ["Stable", "Weathered"]
|
| 693 |
+
plot_confidence_bar(probabilities=probs.tolist(), class_labels=class_labels)
|
| 694 |
+
|
| 695 |
+
# ===Detailed results tabs===
|
| 696 |
tab1, tab2, tab3 = st.tabs(
|
| 697 |
["π Details", "π¬ Technical", "π Explanation"])
|
| 698 |
|