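"""Model loading and inference helpers for the Streamlit app.

Covers cached model and weight loading through the model registry, cached
inference with performance tracking, memory cleanup, filename-based label
extraction, and sample-file discovery.
"""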
import gc
import os
import time
from pathlib import Path

import streamlit as st
import torch
import torch.nn.functional as F

from config import SAMPLE_DATA_DIR, TARGET_LEN
from models.registry import build, choices

def label_file(filename: str) -> int:
    """Extract the class label from a filename prefix: "sta" -> 0, "wea" -> 1, else -1."""
    name = Path(filename).name.lower()
    if name.startswith("sta"):
        return 0
    elif name.startswith("wea"):
        return 1
    else:
        # Unknown prefix: return -1 as a sentinel instead of raising an error
        return -1
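
# Illustrative mapping (the filenames here are hypothetical) for the prefix
# convention handled by label_file:
#     label_file("sta_recording_01.txt") -> 0
#     label_file("wea_recording_01.txt") -> 1
#     label_file("other_recording.txt")  -> -1  (unrecognized prefix)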
@st.cache_data
def load_state_dict(mtime, model_path):
"""Load state dict with mtime in cache key to detect file changes"""
try:
return torch.load(model_path, map_location="cpu")
except (FileNotFoundError, RuntimeError) as e:
st.warning(f"Error loading state dict: {e}")
return None
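
# Illustrative call (the path is hypothetical): passing the file's mtime as the
# first argument makes it part of Streamlit's cache key, so the cached state
# dict is invalidated whenever the weights file on disk changes:
#     path = "model_weights/cnn_model.pth"
#     state_dict = load_state_dict(os.path.getmtime(path), path)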
@st.cache_resource
def load_model(model_name):
    # Build the model through the registry if the name is known
    if model_name in choices():
        model = build(model_name, TARGET_LEN)
# Try to load weights from standard locations
weight_paths = [
f"model_weights/{model_name}_model.pth",
f"outputs/{model_name}_model.pth",
f"model_weights/{model_name}.pth",
f"outputs/{model_name}.pth",
]
weights_loaded = False
for weight_path in weight_paths:
if os.path.exists(weight_path):
try:
mtime = os.path.getmtime(weight_path)
state_dict = load_state_dict(mtime, weight_path)
if state_dict:
model.load_state_dict(state_dict, strict=True)
model.eval()
weights_loaded = True
break # Exit loop after successful load
except (OSError, RuntimeError):
continue
if not weights_loaded:
st.warning(
f"⚠️ Model weights not found for '{model_name}'. Using randomly initialized model."
)
st.info(
"This model will provide random predictions for demonstration purposes."
)
return model, weights_loaded
    # Model not in the registry: report the error and return no model
    st.error(f"Unknown model '{model_name}'. Available models: {choices()}")
    return None, False
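
# Illustrative usage (the model name is hypothetical and must be one of
# models.registry.choices()):
#     model, weights_loaded = load_model("cnn")
#     if not weights_loaded:
#         # predictions come from randomly initialized weights
#         ...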
def cleanup_memory():
"""Clean up memory after inference"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@st.cache_data
def run_inference(y_resampled, model_choice, modality: str, cache_key=None):
"""Run model inference and cache results with performance tracking"""
from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
from datetime import datetime
model, model_loaded = load_model(model_choice)
if not model_loaded:
return None, None, None, None, None
# Performance tracking setup
tracker = get_performance_tracker()
input_tensor = (
torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
)
# Track inference performance
start_time = time.time()
start_memory = _get_memory_usage()
    if model is None:
        raise ValueError(
            "Model is not loaded. Please check the model configuration or weights."
        )
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
prediction = torch.argmax(logits, dim=1).item()
logits_list = logits.detach().numpy().tolist()[0]
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
inference_time = time.time() - start_time
end_memory = _get_memory_usage()
memory_usage = max(end_memory - start_memory, 0)
# Log performance metrics
try:
confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
metrics = PerformanceMetrics(
model_name=model_choice,
prediction_time=inference_time,
preprocessing_time=0.0, # Will be updated by calling function if available
total_time=inference_time,
memory_usage_mb=memory_usage,
accuracy=None, # Will be updated if ground truth is available
confidence=confidence,
timestamp=datetime.now().isoformat(),
input_size=(
len(y_resampled) if hasattr(y_resampled, "__len__") else TARGET_LEN
),
modality=modality,
)
tracker.log_performance(metrics)
except (AttributeError, ValueError, KeyError) as e:
# Don't fail inference if performance tracking fails
print(f"Performance tracking failed: {e}")
cleanup_memory()
return prediction, logits_list, probs, inference_time, logits
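
# Illustrative usage (argument values are hypothetical): y_resampled is expected
# to be a 1-D signal of TARGET_LEN samples; the batch and channel dimensions are
# added internally before the forward pass:
#     pred, logits_list, probs, elapsed, logits = run_inference(
#         y_resampled, model_choice="cnn", modality="ecg"
#     )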
def _get_memory_usage() -> float:
"""Get current memory usage in MB"""
try:
import psutil
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024 # Convert to MB
except ImportError:
return 0.0 # psutil not available
@st.cache_data
def get_sample_files():
"""Get list of sample files if available"""
sample_dir = Path(SAMPLE_DATA_DIR)
if sample_dir.exists():
return sorted(list(sample_dir.glob("*.txt")))
return []
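
# Illustrative pairing of the helpers above (directory contents depend on
# SAMPLE_DATA_DIR): list the bundled sample files and derive their labels
# from the filename convention:
#     for path in get_sample_files():
#         true_label = label_file(path.name)
#         # ...load and resample the signal, then call run_inference(...)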