import gc
import os
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F

from config import MODEL_CONFIG, SAMPLE_DATA_DIR, TARGET_LEN


def label_file(filename: str) -> int:
    """Extract the class label from a filename based on its naming convention."""
    name = Path(filename).name.lower()
    if name.startswith("sta"):
        return 0
    elif name.startswith("wea"):
        return 1
    else:
        # Return -1 as a sentinel for unknown prefixes instead of raising an error
        return -1
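
# Usage sketch (hypothetical filenames following the prefix convention above):
#   label_file("sta_sample_01.txt")  -> 0
#   label_file("wea_sample_07.txt")  -> 1
#   label_file("other_sample.txt")   -> -1  (unknown prefix)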


@st.cache_data
def load_state_dict(mtime, model_path):
    """Load a state dict; mtime participates in the cache key so that a
    changed file on disk invalidates the cached copy. (An underscore-prefixed
    parameter would be excluded from Streamlit's cache key.)"""
    try:
        return torch.load(model_path, map_location="cpu")
    except (FileNotFoundError, RuntimeError) as e:
        st.warning(f"Error loading state dict: {e}")
        return None
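
# Cache-invalidation sketch ("models/demo.pt" is a hypothetical path): since
# mtime is part of the cache key, replacing the checkpoint on disk changes the
# key and forces a fresh torch.load on the next call:
#   path = "models/demo.pt"
#   sd = load_state_dict(os.path.getmtime(path), path)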


@st.cache_resource
def load_model(model_name):
    """Load and cache the specified model with error handling."""
    try:
        config = MODEL_CONFIG[model_name]
        model_class = config["class"]
        model_path = config["path"]

        # Initialize the model architecture
        model = model_class(input_length=TARGET_LEN)

        # Fall back to random weights if the checkpoint is missing
        if not os.path.exists(model_path):
            st.warning(f"⚠️ Model weights not found: {model_path}")
            st.info("Using randomly initialized model for demonstration purposes.")
            return model, False

        # mtime feeds the cache key so a replaced checkpoint is picked up
        mtime = os.path.getmtime(model_path)
        state_dict = load_state_dict(mtime, model_path)
        if state_dict:
            model.load_state_dict(state_dict, strict=True)
            model.eval()
            return model, True
        return model, False
    except (FileNotFoundError, KeyError, RuntimeError) as e:
        st.error(f"❌ Error loading model {model_name}: {str(e)}")
        return None, False
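
# Usage sketch ("CNN" stands in for any key defined in MODEL_CONFIG):
#   model, weights_loaded = load_model("CNN")
#   if not weights_loaded:
#       st.caption("Running with random weights; results are illustrative only.")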


def cleanup_memory():
    """Free Python and CUDA memory after inference."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@st.cache_data
def run_inference(y_resampled, model_choice, _cache_key=None):
    """Run model inference and cache results with performance tracking.

    Note: Streamlit excludes underscore-prefixed parameters from the cache
    key, so the hash is computed from y_resampled and model_choice.
    """
    from utils.performance_tracker import get_performance_tracker, PerformanceMetrics

    model, model_loaded = load_model(model_choice)
    if model is None or not model_loaded:
        return None, None, None, None, None

    # Performance tracking setup
    tracker = get_performance_tracker()
    input_tensor = (
        torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    )

    # Track inference time and memory
    start_time = time.time()
    start_memory = _get_memory_usage()

    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
        prediction = torch.argmax(logits, dim=1).item()
        logits_list = logits.detach().numpy().tolist()[0]
        probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()

    inference_time = time.time() - start_time
    end_memory = _get_memory_usage()
    memory_usage = max(end_memory - start_memory, 0)

    # Log performance metrics; tracking failures must not break inference
    try:
        modality = st.session_state.get("modality_select", "raman")
        confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
        metrics = PerformanceMetrics(
            model_name=model_choice,
            prediction_time=inference_time,
            preprocessing_time=0.0,  # Updated by the calling function if available
            total_time=inference_time,
            memory_usage_mb=memory_usage,
            accuracy=None,  # Updated if ground truth is available
            confidence=confidence,
            timestamp=datetime.now().isoformat(),
            input_size=(
                len(y_resampled) if hasattr(y_resampled, "__len__") else TARGET_LEN
            ),
            modality=modality,
        )
        tracker.log_performance(metrics)
    except (AttributeError, ValueError, KeyError) as e:
        print(f"Performance tracking failed: {e}")

    cleanup_memory()
    return prediction, logits_list, probs, inference_time, logits
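
# Usage sketch (assumes y_resampled is a 1-D array already resampled to
# TARGET_LEN and the name is a MODEL_CONFIG key; "CNN" is a stand-in):
#   prediction, logits_list, probs, inference_time, logits = run_inference(
#       y_resampled, "CNN"
#   )
#   if prediction is not None:
#       st.write(f"Predicted class {prediction} ({max(probs):.1%} confidence)")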


def _get_memory_usage() -> float:
    """Get current process memory usage (RSS) in MB."""
    try:
        import psutil

        process = psutil.Process()
        return process.memory_info().rss / 1024 / 1024  # Bytes -> MB
    except ImportError:
        return 0.0  # psutil not available


@st.cache_data
def get_sample_files():
    """Get a sorted list of sample spectra, if any are bundled with the app."""
    sample_dir = Path(SAMPLE_DATA_DIR)
    if sample_dir.exists():
        return sorted(sample_dir.glob("*.txt"))
    return []
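
# Usage sketch: populate a picker with any bundled demo spectra
#   samples = get_sample_files()
#   if samples:
#       choice = st.selectbox("Sample spectra", samples, format_func=lambda p: p.name)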


def parse_spectrum_data(raw_text):
    """Parse spectrum data from text with robust error handling and validation."""
    x_vals, y_vals = [], []
    for line in raw_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):  # Skip empty lines and comments
            continue
        # Handle comma- or whitespace-separated columns and keep the first two
        # numeric tokens; float() also accepts scientific notation, which the
        # old digit-stripping filter rejected
        parts = line.replace(",", " ").split()
        numbers = []
        for p in parts:
            try:
                numbers.append(float(p))
            except ValueError:
                continue  # Skip non-numeric tokens but don't fail completely
        if len(numbers) >= 2:
            x_vals.append(numbers[0])
            y_vals.append(numbers[1])

    if len(x_vals) < 10:  # Minimum reasonable spectrum length
        raise ValueError(
            f"Insufficient data points: {len(x_vals)}. Need at least 10 points."
        )

    x = np.array(x_vals)
    y = np.array(y_vals)

    # Reject NaNs
    if np.any(np.isnan(x)) or np.any(np.isnan(y)):
        raise ValueError("Input data contains NaN values")

    # Require strictly increasing wavenumbers
    if not np.all(np.diff(x) > 0):
        raise ValueError("Wavenumbers must be strictly increasing")

    # Sanity-check the range for Raman spectroscopy
    if x.min() < 0 or x.max() > 10000 or (x.max() - x.min()) < 100:
        raise ValueError(
            f"Invalid wavenumber range: {x.min()} - {x.max()}. "
            "Expected ~400-4000 cm⁻¹ with span >100"
        )

    return x, y
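
# Usage sketch with made-up inline data (real input comes from an upload or a
# sample file); 20 points spanning 400-590 cm⁻¹ pass all validation checks:
#   raw = "\n".join(f"{400 + i * 10} {1.0 + 0.01 * i}" for i in range(20))
#   x, y = parse_spectrum_data(raw)
#   assert len(x) == len(y) == 20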