devjas1 committed
Commit · 078ed21
1 Parent(s): 59c6133
(FEAT/REFAC)[Inference Logic]: Integrate performance tracking into model inference
- Updated 'run_inference' to log performance metrics ('time', 'memory', 'confidence', 'input size', 'modality')
- Performance metrics are captured using a new tracker, storing data such as model name, inference/preprocessing time, memory usage, and timestamp
- Added error handling so that inference does not fail if performance tracking fails
- Introduced '_get_memory_usage()' to retrieve memory usage in MB using 'psutil', with a fallback if it is unavailable
core_logic.py: +47 -1
core_logic.py
CHANGED
@@ -89,15 +89,25 @@ def cleanup_memory():
 
 @st.cache_data
 def run_inference(y_resampled, model_choice, _cache_key=None):
-    """Run model inference and cache results"""
+    """Run model inference and cache results with performance tracking"""
+    from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
+    from datetime import datetime
+
     model, model_loaded = load_model(model_choice)
     if not model_loaded:
         return None, None, None, None, None
 
+    # Performance tracking setup
+    tracker = get_performance_tracker()
+
     input_tensor = (
         torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
     )
+
+    # Track inference performance
     start_time = time.time()
+    start_memory = _get_memory_usage()
+
     model.eval()
     with torch.no_grad():
         if model is None:
@@ -108,10 +118,46 @@ def run_inference(y_resampled, model_choice, _cache_key=None):
         prediction = torch.argmax(logits, dim=1).item()
     logits_list = logits.detach().numpy().tolist()[0]
     probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
+
     inference_time = time.time() - start_time
+    end_memory = _get_memory_usage()
+    memory_usage = max(end_memory - start_memory, 0)
+
+    # Log performance metrics
+    try:
+        modality = st.session_state.get("modality_select", "raman")
+        confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
+
+        metrics = PerformanceMetrics(
+            model_name=model_choice,
+            prediction_time=inference_time,
+            preprocessing_time=0.0,  # Will be updated by calling function if available
+            total_time=inference_time,
+            memory_usage_mb=memory_usage,
+            accuracy=None,  # Will be updated if ground truth is available
+            confidence=confidence,
+            timestamp=datetime.now().isoformat(),
+            input_size=len(y_resampled) if hasattr(y_resampled, '__len__') else 500,
+            modality=modality,
+        )
+        tracker.log_performance(metrics)
+    except Exception as e:
+        # Don't fail inference if performance tracking fails
+        print(f"Performance tracking failed: {e}")
+
     cleanup_memory()
     return prediction, logits_list, probs, inference_time, logits
 
+
+def _get_memory_usage() -> float:
+    """Get current memory usage in MB"""
+    try:
+        import psutil
+        process = psutil.Process()
+        return process.memory_info().rss / 1024 / 1024  # Convert to MB
+    except ImportError:
+        return 0.0  # psutil not available
+
+
 
 @st.cache_data
 def get_sample_files():
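
The new code imports get_performance_tracker and PerformanceMetrics from utils.performance_tracker, a module that is not part of this diff. The sketch below is a minimal illustration of what such a module could provide, assuming an in-memory, process-wide tracker; the field names mirror the keyword arguments used in run_inference, while the class name PerformanceTracker, the _records list, and the as_dicts() helper are hypothetical.

# Hypothetical sketch of utils/performance_tracker.py (not part of this commit).
from dataclasses import dataclass, asdict
from typing import List, Optional


@dataclass
class PerformanceMetrics:
    """One record per inference call, matching the fields used in run_inference."""
    model_name: str
    prediction_time: float
    preprocessing_time: float
    total_time: float
    memory_usage_mb: float
    accuracy: Optional[float]
    confidence: float
    timestamp: str
    input_size: int
    modality: str


class PerformanceTracker:
    """Collects metrics in memory; a real implementation might persist or display them."""

    def __init__(self) -> None:
        self._records: List[PerformanceMetrics] = []

    def log_performance(self, metrics: PerformanceMetrics) -> None:
        # Append the record; callers treat failures as non-fatal (see the try/except above).
        self._records.append(metrics)

    def as_dicts(self) -> List[dict]:
        # Convenience accessor, e.g. for rendering a table of past runs in the app.
        return [asdict(m) for m in self._records]


_tracker: Optional[PerformanceTracker] = None


def get_performance_tracker() -> PerformanceTracker:
    # Module-level singleton so every call to run_inference logs into the same tracker.
    global _tracker
    if _tracker is None:
        _tracker = PerformanceTracker()
    return _tracker

A module-level singleton keeps the sketch simple; in a Streamlit app the tracker could instead live in st.session_state so that metrics are scoped to a single user session.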