devjas1 committed
Commit 078ed21 · 1 Parent(s): 59c6133

(FEAT/REFAC)[Inference Logic]: Integrate performance tracking into model inference

- Updated 'run_inference' to log performance metrics ('time', 'memory', 'confidence', 'input size', 'modality')
- Performance metrics are captured using a new tracker, storing data such as model name, inference/preprocessing time, memory usage, and timestamp (a sketch of the assumed tracker module follows below)
- Added error handling so that inference does not fail if tracking fails
- Introduced '_get_memory_usage()' to retrieve memory usage in MB using 'psutil', with a fallback if it is unavailable
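
Note: the diff imports 'get_performance_tracker' and 'PerformanceMetrics' from 'utils.performance_tracker', a module that is not part of this commit. A minimal sketch of what it might look like, assuming a dataclass record and an in-memory singleton tracker; everything here beyond the two imported names and the field names used in the diff is hypothetical:

# utils/performance_tracker.py -- hypothetical sketch, not the committed module
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PerformanceMetrics:
    """One record per inference call; fields match the keywords used in run_inference."""
    model_name: str
    prediction_time: float
    preprocessing_time: float
    total_time: float
    memory_usage_mb: float
    accuracy: Optional[float]
    confidence: float
    timestamp: str
    input_size: int
    modality: str

class PerformanceTracker:
    """Collects metrics in memory; a real implementation might persist them to disk."""

    def __init__(self) -> None:
        self.records: List[PerformanceMetrics] = []

    def log_performance(self, metrics: PerformanceMetrics) -> None:
        self.records.append(metrics)

_TRACKER = PerformanceTracker()

def get_performance_tracker() -> PerformanceTracker:
    # Module-level singleton so every caller logs into the same tracker
    return _TRACKER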

Files changed (1): core_logic.py (+47 -1)
core_logic.py CHANGED
@@ -89,15 +89,25 @@ def cleanup_memory():
 
 @st.cache_data
 def run_inference(y_resampled, model_choice, _cache_key=None):
-    """Run model inference and cache results"""
+    """Run model inference and cache results with performance tracking"""
+    from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
+    from datetime import datetime
+
     model, model_loaded = load_model(model_choice)
     if not model_loaded:
         return None, None, None, None, None
 
+    # Performance tracking setup
+    tracker = get_performance_tracker()
+
     input_tensor = (
         torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
     )
+
+    # Track inference performance
     start_time = time.time()
+    start_memory = _get_memory_usage()
+
     model.eval()
     with torch.no_grad():
         if model is None:
@@ -108,10 +118,46 @@ def run_inference(y_resampled, model_choice, _cache_key=None):
     prediction = torch.argmax(logits, dim=1).item()
     logits_list = logits.detach().numpy().tolist()[0]
     probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
+
     inference_time = time.time() - start_time
+    end_memory = _get_memory_usage()
+    memory_usage = max(end_memory - start_memory, 0)
+
+    # Log performance metrics
+    try:
+        modality = st.session_state.get("modality_select", "raman")
+        confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
+
+        metrics = PerformanceMetrics(
+            model_name=model_choice,
+            prediction_time=inference_time,
+            preprocessing_time=0.0,  # Will be updated by calling function if available
+            total_time=inference_time,
+            memory_usage_mb=memory_usage,
+            accuracy=None,  # Will be updated if ground truth is available
+            confidence=confidence,
+            timestamp=datetime.now().isoformat(),
+            input_size=len(y_resampled) if hasattr(y_resampled, '__len__') else 500,
+            modality=modality,
+        )
+        tracker.log_performance(metrics)
+    except Exception as e:
+        # Don't fail inference if performance tracking fails
+        print(f"Performance tracking failed: {e}")
+
     cleanup_memory()
     return prediction, logits_list, probs, inference_time, logits
 
+
+def _get_memory_usage() -> float:
+    """Get current memory usage in MB"""
+    try:
+        import psutil
+        process = psutil.Process()
+        return process.memory_info().rss / 1024 / 1024  # Convert to MB
+    except ImportError:
+        return 0.0  # psutil not available
+
+
 
 @st.cache_data
 def get_sample_files():
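
A note on the memory numbers: '_get_memory_usage()' reads the process RSS, which can shrink as well as grow between the two samples, which is why the commit clamps the delta with max(end_memory - start_memory, 0). A standalone sketch of the same pattern, assuming 'psutil' is installed; the 50 MB allocation is only there to make the delta visible:

# Standalone demonstration of the RSS-delta pattern used in run_inference
import psutil

def rss_mb() -> float:
    """Resident set size of the current process in MB."""
    return psutil.Process().memory_info().rss / 1024 / 1024

before = rss_mb()
buf = bytearray(50 * 1024 * 1024)  # allocate ~50 MB of zero-filled memory
after = rss_mb()
# RSS can fluctuate independently of this allocation, hence the clamp to zero
print(f"approximate delta: {max(after - before, 0):.1f} MB")
del buf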