Spaces:
Running
Running
devjas1
committed on
Commit
·
078ed21
1
Parent(s):
59c6133
(FEAT/REFAC)[Inference Logic]: Integrate performance tracking into model inference
Browse files- Updated 'run_inference' to log performance metrics ('time', 'memory', 'confidence', 'input size', 'modality')
- Performance metrics are captured using a new tracker, storing data such as model name, inference/preprocessing time, memory usage, and timestamp
- Added error handling to avoid inference failure if tracking fails
- Introduced '_get_memory_usage()' to retrieve memory usage in MB using 'psutil', with fallback if unavailable
- core_logic.py +47 -1
core_logic.py
CHANGED
@@ -89,15 +89,25 @@ def cleanup_memory():
|
|
89 |
|
90 |
@st.cache_data
|
91 |
def run_inference(y_resampled, model_choice, _cache_key=None):
|
92 |
-
"""Run model inference and cache results"""
|
|
|
|
|
|
|
93 |
model, model_loaded = load_model(model_choice)
|
94 |
if not model_loaded:
|
95 |
return None, None, None, None, None
|
96 |
|
|
|
|
|
|
|
97 |
input_tensor = (
|
98 |
torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
99 |
)
|
|
|
|
|
100 |
start_time = time.time()
|
|
|
|
|
101 |
model.eval()
|
102 |
with torch.no_grad():
|
103 |
if model is None:
|
@@ -108,10 +118,46 @@ def run_inference(y_resampled, model_choice, _cache_key=None):
|
|
108 |
prediction = torch.argmax(logits, dim=1).item()
|
109 |
logits_list = logits.detach().numpy().tolist()[0]
|
110 |
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
|
|
|
111 |
inference_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
cleanup_memory()
|
113 |
return prediction, logits_list, probs, inference_time, logits
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
@st.cache_data
|
117 |
def get_sample_files():
|
|
|
89 |
|
90 |
@st.cache_data
|
91 |
def run_inference(y_resampled, model_choice, _cache_key=None):
|
92 |
+
"""Run model inference and cache results with performance tracking"""
|
93 |
+
from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
|
94 |
+
from datetime import datetime
|
95 |
+
|
96 |
model, model_loaded = load_model(model_choice)
|
97 |
if not model_loaded:
|
98 |
return None, None, None, None, None
|
99 |
|
100 |
+
# Performance tracking setup
|
101 |
+
tracker = get_performance_tracker()
|
102 |
+
|
103 |
input_tensor = (
|
104 |
torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
105 |
)
|
106 |
+
|
107 |
+
# Track inference performance
|
108 |
start_time = time.time()
|
109 |
+
start_memory = _get_memory_usage()
|
110 |
+
|
111 |
model.eval()
|
112 |
with torch.no_grad():
|
113 |
if model is None:
|
|
|
118 |
prediction = torch.argmax(logits, dim=1).item()
|
119 |
logits_list = logits.detach().numpy().tolist()[0]
|
120 |
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
|
121 |
+
|
122 |
inference_time = time.time() - start_time
|
123 |
+
end_memory = _get_memory_usage()
|
124 |
+
memory_usage = max(end_memory - start_memory, 0)
|
125 |
+
|
126 |
+
# Log performance metrics
|
127 |
+
try:
|
128 |
+
modality = st.session_state.get("modality_select", "raman")
|
129 |
+
confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
|
130 |
+
|
131 |
+
metrics = PerformanceMetrics(
|
132 |
+
model_name=model_choice,
|
133 |
+
prediction_time=inference_time,
|
134 |
+
preprocessing_time=0.0, # Will be updated by calling function if available
|
135 |
+
total_time=inference_time,
|
136 |
+
memory_usage_mb=memory_usage,
|
137 |
+
accuracy=None, # Will be updated if ground truth is available
|
138 |
+
confidence=confidence,
|
139 |
+
timestamp=datetime.now().isoformat(),
|
140 |
+
input_size=len(y_resampled) if hasattr(y_resampled, '__len__') else 500,
|
141 |
+
modality=modality,
|
142 |
+
)
|
143 |
+
tracker.log_perfomance(metrics)
|
144 |
+
except Exception as e:
|
145 |
+
# Don't fail inference if performance tracking fails
|
146 |
+
print(f"Performance tracking failed: {e}")
|
147 |
+
|
148 |
cleanup_memory()
|
149 |
return prediction, logits_list, probs, inference_time, logits
|
150 |
|
151 |
+
def _get_memory_usage() -> float:
|
152 |
+
"""Get current memory usage in MB"""
|
153 |
+
try:
|
154 |
+
import psutil
|
155 |
+
process = psutil.Process()
|
156 |
+
return process.memory_info().rss / 1024 / 1024 # Convert to MB
|
157 |
+
except ImportError:
|
158 |
+
return 0.0 # psutil not available
|
159 |
+
|
160 |
+
|
161 |
|
162 |
@st.cache_data
|
163 |
def get_sample_files():
|