devjas1 commited on
Commit
22d9362
Β·
1 Parent(s): 182c9ce

(CLEAN): remove 'torch.tensor(logits)' misuse to fix softmax warning

Browse files

- Replaced incorrect torch.tensor(logits) wrapping with logits.detach()
- Eliminated 'UserWarning' in streamlit run logs
- Applied 'flatten()' after softmax for consistent display shape
- Final polishing for Step 1 completion

Files changed (1) hide show
  1. app.py +31 -10
app.py CHANGED
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
9
  import matplotlib
10
  import numpy as np
11
  import torch
 
12
  import streamlit as st
13
  import os
14
  import sys
@@ -256,6 +257,11 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
256
 
257
  return Image.open(buf)
258
 
 
 
 
 
 
259
 
260
  def get_confidence_description(logit_margin):
261
  """Get human-readable confidence description"""
@@ -268,7 +274,6 @@ def get_confidence_description(logit_margin):
268
  else:
269
  return "LOW", "πŸ”΄"
270
 
271
-
272
  def log_message(msg: str):
273
  """Append a timestamped line to the in-app log, creating the buffer if needed."""
274
  if "log_messages" not in st.session_state or st.session_state["log_messages"] is None:
@@ -277,15 +282,10 @@ def log_message(msg: str):
277
  f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
278
  )
279
 
280
-
281
  def trigger_run():
282
  """Set a flag so we can detect button press reliably across reruns"""
283
  st.session_state['run_requested'] = True
284
 
285
-
286
-
287
-
288
-
289
  def on_sample_change():
290
  """Read selected sample once and persist as text."""
291
  sel = st.session_state.get("sample_select", "-- Select Sample --")
@@ -304,7 +304,6 @@ def on_sample_change():
304
  st.session_state["status_message"] = f"❌ Error loading sample: {e}"
305
  st.session_state["status_type"] = "error"
306
 
307
-
308
  def on_input_mode_change():
309
  """Reset sample when switching to Upload"""
310
  if st.session_state["input_mode"] == "Upload File":
@@ -312,12 +311,10 @@ def on_input_mode_change():
312
  # πŸ”§ Reset when switching modes to prevent stale right-column visuals
313
  reset_results("Switched input mode")
314
 
315
-
316
  def on_model_change():
317
  """Force the right column back to init state when the model changes"""
318
  reset_results("Model changed")
319
 
320
-
321
  def reset_results(reason: str = ""):
322
  """Clear previous inference artifacts so the right column returns to initial state."""
323
  st.session_state["inference_run_once"] = False
@@ -359,6 +356,23 @@ def reset_ephemeral_state():
359
 
360
  st.rerun()
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  # Main app
363
  def main():
364
  init_session_state()
@@ -617,6 +631,9 @@ def main():
617
  prediction = torch.argmax(logits, dim=1).item()
618
  logits_list = logits.detach().numpy().tolist()[0]
619
 
 
 
 
620
  inference_time = time.time() - start_time
621
  log_message(
622
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
@@ -671,7 +688,11 @@ def main():
671
  st.info(
672
  "ℹ️ **Ground Truth**: Unknown (filename doesn't follow naming convention)")
673
 
674
- # Detailed results tabs
 
 
 
 
675
  tab1, tab2, tab3 = st.tabs(
676
  ["πŸ“Š Details", "πŸ”¬ Technical", "πŸ“˜ Explanation"])
677
 
 
9
  import matplotlib
10
  import numpy as np
11
  import torch
12
+ import torch.nn.functional as F
13
  import streamlit as st
14
  import os
15
  import sys
 
257
 
258
  return Image.open(buf)
259
 
260
+ def render_confidence_bar(probabilities, class_labels):
261
+ bar = lambda p: "β–ˆ" * int(p * 20)
262
+ for label, prob in zip(class_labels, probabilities):
263
+ st.write(f"**{label}**: {bar(prob)} {prob*100:.1f}%")
264
+
265
 
266
  def get_confidence_description(logit_margin):
267
  """Get human-readable confidence description"""
 
274
  else:
275
  return "LOW", "πŸ”΄"
276
 
 
277
  def log_message(msg: str):
278
  """Append a timestamped line to the in-app log, creating the buffer if needed."""
279
  if "log_messages" not in st.session_state or st.session_state["log_messages"] is None:
 
282
  f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
283
  )
284
 
 
285
  def trigger_run():
286
  """Set a flag so we can detect button press reliably across reruns"""
287
  st.session_state['run_requested'] = True
288
 
 
 
 
 
289
  def on_sample_change():
290
  """Read selected sample once and persist as text."""
291
  sel = st.session_state.get("sample_select", "-- Select Sample --")
 
304
  st.session_state["status_message"] = f"❌ Error loading sample: {e}"
305
  st.session_state["status_type"] = "error"
306
 
 
307
  def on_input_mode_change():
308
  """Reset sample when switching to Upload"""
309
  if st.session_state["input_mode"] == "Upload File":
 
311
  # πŸ”§ Reset when switching modes to prevent stale right-column visuals
312
  reset_results("Switched input mode")
313
 
 
314
  def on_model_change():
315
  """Force the right column back to init state when the model changes"""
316
  reset_results("Model changed")
317
 
 
318
  def reset_results(reason: str = ""):
319
  """Clear previous inference artifacts so the right column returns to initial state."""
320
  st.session_state["inference_run_once"] = False
 
356
 
357
  st.rerun()
358
 
359
+ def plot_confidence_bar(probabilities: list[float], class_labels: list[str]) -> None:
360
+ """Renders a horizontal bar chart of prediction confidences per class."""
361
+ fig, ax = plt.subplots(figsize=(4, 1.5))
362
+ bars = ax.barh(class_labels, probabilities, color=[
363
+ "green" if i == np.argmax(probabilities) else "gray"
364
+ for i in range(len(probabilities))
365
+ ])
366
+ ax.set_xlabel("Confidence")
367
+ ax.set_title("Prediction Confidence")
368
+ ax.xaxis.set_ticks([0, 0.5, 1.0])
369
+ ax.set_xlim(0, 1.0)
370
+ for i, (label, prob) in enumerate(zip(class_labels, probabilities)):
371
+ ax.text(prob + 0.01, i, f"{prob*100:.1f}%", va='center', fontsize=8)
372
+
373
+ st.pyplot(fig)
374
+
375
+
376
  # Main app
377
  def main():
378
  init_session_state()
 
631
  prediction = torch.argmax(logits, dim=1).item()
632
  logits_list = logits.detach().numpy().tolist()[0]
633
 
634
+ probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
635
+
636
+
637
  inference_time = time.time() - start_time
638
  log_message(
639
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
 
688
  st.info(
689
  "ℹ️ **Ground Truth**: Unknown (filename doesn't follow naming convention)")
690
 
691
+ # ===display confidence results===
692
+ class_labels = ["Stable", "Weathered"]
693
+ plot_confidence_bar(probabilities=probs.tolist(), class_labels=class_labels)
694
+
695
+ # ===Detailed results tabs===
696
  tab1, tab2, tab3 = st.tabs(
697
  ["πŸ“Š Details", "πŸ”¬ Technical", "πŸ“˜ Explanation"])
698