devjas1 committed · Commit 73712cd · 1 Parent(s): 5543304

(FEAT)[Model Optimization Suite]: Add Model Optimization Suite for quantization, pruning, and benchmarking


- Created `ModelOptimizer` class with utilities for:
  - Dynamic quantization (`quantize_model`)
  - Magnitude-based pruning (`prune_model`)
  - Operation fusion (`optimize_for_inference`, `_fuse_conv_bn`)
  - Benchmarking (`benchmark_model`) and multi-technique comparison (`compare_optimizations`)
  - Optimization suggestions based on speed/size requirements (`suggest_optimizations`)
- Added reporting and model-saving functions:
  - `create_optimization_report`
  - `save_optimized_model`
- Enables detailed performance analysis, speed/size reduction, and export of optimized models.
- Designed for extensibility (see the usage sketch below).
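
A minimal usage sketch (not part of the commit): it assumes a trained PyTorch module `model` whose constructor accepts the input length, since the pruning/fusion helpers rebuild a copy via `type(model)(...)`; the input shape matches the benchmarking default.

```python
from utils.model_optimization import ModelOptimizer

# `model` is assumed to be a trained nn.Module loaded elsewhere in the project
optimizer = ModelOptimizer()

# Compare the original model against quantized, pruned, and fully optimized variants
results = optimizer.compare_optimizations(model, input_shape=(1, 1, 500))
print(results["quantize"].get("speedup"))

# Or apply a single technique directly
quantized = optimizer.quantize_model(model)
pruned = optimizer.prune_model(model, pruning_ratio=0.3)
```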

Files changed (1)
  1. utils/model_optimization.py +311 -0
utils/model_optimization.py ADDED
@@ -0,0 +1,311 @@
+ """
+ Model performance optimization utilities.
+ Includes model quantization, pruning, and optimization techniques.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.utils.prune as prune
+ from typing import Dict, Any, List, Optional, Tuple
+ import time
+ import numpy as np
+ from pathlib import Path
+
+
+ class ModelOptimizer:
+     """Utility class for optimizing trained models."""
+
+     def __init__(self):
+         self.optimization_history = []
+
+     def quantize_model(
+         self, model: nn.Module, dtype: torch.dtype = torch.qint8
+     ) -> nn.Module:
+         """Apply dynamic quantization to reduce model size and inference time."""
+         # Prepare for quantization
+         model.eval()
+
+         # Apply dynamic quantization
+         quantized_model = torch.quantization.quantize_dynamic(
+             model, {nn.Linear, nn.Conv1d}, dtype=dtype  # layer types to quantize
+         )
+
+         return quantized_model
+
+     def prune_model(
+         self, model: nn.Module, pruning_ratio: float = 0.2, structured: bool = False
+     ) -> nn.Module:
+         """Apply magnitude-based pruning to reduce model parameters."""
+         model_copy = type(model)(
+             model.input_length if hasattr(model, "input_length") else 500
+         )
+         model_copy.load_state_dict(model.state_dict())
+
+         # Collect modules to prune
+         modules_to_prune = []
+         for name, module in model_copy.named_modules():
+             if isinstance(module, (nn.Conv1d, nn.Linear)):
+                 modules_to_prune.append((module, "weight"))
+
+         if structured:
+             # Structured pruning (entire channels/filters)
+             for module, param_name in modules_to_prune:
+                 if isinstance(module, nn.Conv1d):
+                     prune.ln_structured(
+                         module, name=param_name, amount=pruning_ratio, n=2, dim=0
+                     )
+                 else:
+                     prune.l1_unstructured(module, name=param_name, amount=pruning_ratio)
+         else:
+             # Unstructured pruning
+             prune.global_unstructured(
+                 modules_to_prune,
+                 pruning_method=prune.L1Unstructured,
+                 amount=pruning_ratio,
+             )
+
+         # Make pruning permanent
+         for module, param_name in modules_to_prune:
+             prune.remove(module, param_name)
+
+         return model_copy
+
+     def optimize_for_inference(self, model: nn.Module) -> nn.Module:
+         """Apply multiple optimizations for faster inference."""
+         model.eval()
+
+         # Fuse operations where possible
+         optimized_model = self._fuse_conv_bn(model)
+
+         # Apply quantization
+         optimized_model = self.quantize_model(optimized_model)
+
+         return optimized_model
+
+     def _fuse_conv_bn(self, model: nn.Module) -> nn.Module:
+         """Fuse convolution and batch normalization layers."""
+         model_copy = type(model)(
+             model.input_length if hasattr(model, "input_length") else 500
+         )
+         model_copy.load_state_dict(model.state_dict())
+
+         # Simple fusion for sequential Conv1d + BatchNorm1d patterns
+         for name, module in model_copy.named_children():
+             if isinstance(module, nn.Sequential):
+                 self._fuse_sequential_conv_bn(module)
+
+         return model_copy
+
+     def _fuse_sequential_conv_bn(self, sequential: nn.Sequential):
+         """Fuse Conv1d + BatchNorm1d pairs in a sequential module, in place."""
+         i = 0
+         while i < len(sequential) - 1:
+             if isinstance(sequential[i], nn.Conv1d) and isinstance(
+                 sequential[i + 1], nn.BatchNorm1d
+             ):
+                 # Fold the BatchNorm parameters into the convolution, then
+                 # neutralize the BatchNorm so later indices stay valid.
+                 sequential[i] = self._fuse_conv_bn_layer(
+                     sequential[i], sequential[i + 1]
+                 )
+                 sequential[i + 1] = nn.Identity()
+                 i += 2
+             else:
+                 i += 1
+
+     def _fuse_conv_bn_layer(self, conv: nn.Conv1d, bn: nn.BatchNorm1d) -> nn.Conv1d:
+         """Fuse a single Conv1d and BatchNorm1d layer."""
+         # Create new conv layer
+         fused_conv = nn.Conv1d(
+             conv.in_channels,
+             conv.out_channels,
+             conv.kernel_size[0],
+             conv.stride[0] if isinstance(conv.stride, tuple) else conv.stride,
+             conv.padding[0] if isinstance(conv.padding, tuple) else conv.padding,
+             conv.dilation[0] if isinstance(conv.dilation, tuple) else conv.dilation,
+             conv.groups,
+             bias=True,  # Always add bias after fusion
+         )
+
+         # Calculate fused parameters
+         w_conv = conv.weight.clone()
+         w_bn = bn.weight.clone()
+         b_bn = bn.bias.clone()
+         mean_bn = (
+             bn.running_mean.clone()
+             if bn.running_mean is not None
+             else torch.zeros_like(bn.weight)
+         )
+         var_bn = (
+             bn.running_var.clone()
+             if bn.running_var is not None
+             else torch.ones_like(bn.weight)  # unit variance if no running stats
+         )
+         eps = bn.eps
+
+         # Fuse weights
+         factor = w_bn / torch.sqrt(var_bn + eps)
+         fused_conv.weight.data = w_conv * factor.reshape(-1, 1, 1)
+
+         # Fuse bias
+         if conv.bias is not None:
+             b_conv = conv.bias.clone()
+         else:
+             b_conv = torch.zeros_like(b_bn)
+
+         fused_conv.bias.data = (b_conv - mean_bn) * factor + b_bn
+
+         return fused_conv
+
+     def benchmark_model(
+         self,
+         model: nn.Module,
+         input_shape: Tuple[int, ...] = (1, 1, 500),
+         num_runs: int = 100,
+         warmup_runs: int = 10,
+     ) -> Dict[str, float]:
+         """Benchmark model performance."""
+         model.eval()
+
+         # Create dummy input
+         dummy_input = torch.randn(input_shape)
+
+         # Warmup
+         with torch.no_grad():
+             for _ in range(warmup_runs):
+                 _ = model(dummy_input)
+
+         # Benchmark
+         times = []
+         with torch.no_grad():
+             for _ in range(num_runs):
+                 start_time = time.time()
+                 _ = model(dummy_input)
+                 end_time = time.time()
+                 times.append(end_time - start_time)
+
+         # Calculate statistics
+         times = np.array(times)
+
+         # Count parameters
+         total_params = sum(p.numel() for p in model.parameters())
+         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+         # Calculate model size (approximate)
+         param_size = sum(p.numel() * p.element_size() for p in model.parameters())
+         buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
+         model_size_mb = (param_size + buffer_size) / (1024 * 1024)
+
+         return {
+             "mean_inference_time": float(np.mean(times)),
+             "std_inference_time": float(np.std(times)),
+             "min_inference_time": float(np.min(times)),
+             "max_inference_time": float(np.max(times)),
+             "fps": 1.0 / float(np.mean(times)),
+             "total_parameters": total_params,
+             "trainable_parameters": trainable_params,
+             "model_size_mb": model_size_mb,
+         }
+
+     def compare_optimizations(
+         self,
+         original_model: nn.Module,
+         optimizations: Optional[List[str]] = None,
+         input_shape: Tuple[int, ...] = (1, 1, 500),
+     ) -> Dict[str, Dict[str, Any]]:
+         """Benchmark the original model against each requested optimization."""
+         if optimizations is None:
+             optimizations = ["quantize", "prune", "full_optimize"]
+         results = {}
+
+         # Benchmark original model
+         results["original"] = self.benchmark_model(original_model, input_shape)
+
+         for opt in optimizations:
+             try:
+                 if opt == "quantize":
+                     optimized_model = self.quantize_model(original_model)
+                 elif opt == "prune":
+                     optimized_model = self.prune_model(
+                         original_model, pruning_ratio=0.3
+                     )
+                 elif opt == "full_optimize":
+                     optimized_model = self.optimize_for_inference(original_model)
+                 else:
+                     continue
+
+                 # Benchmark optimized model
+                 benchmark_results = self.benchmark_model(optimized_model, input_shape)
+
+                 # Calculate improvements
+                 speedup = (
+                     results["original"]["mean_inference_time"]
+                     / benchmark_results["mean_inference_time"]
+                 )
+                 size_reduction = (
+                     results["original"]["model_size_mb"]
+                     - benchmark_results["model_size_mb"]
+                 ) / results["original"]["model_size_mb"]
+                 param_reduction = (
+                     results["original"]["total_parameters"]
+                     - benchmark_results["total_parameters"]
+                 ) / results["original"]["total_parameters"]
+
+                 benchmark_results.update(
+                     {
+                         "speedup": speedup,
+                         "size_reduction_ratio": size_reduction,
+                         "parameter_reduction_ratio": param_reduction,
+                     }
+                 )
+
+                 results[opt] = benchmark_results
+
+             except (RuntimeError, ValueError, TypeError) as e:
+                 results[opt] = {"error": str(e)}
+
+         return results
+
+     def suggest_optimizations(
+         self,
+         model: nn.Module,
+         target_speed: Optional[float] = None,
+         target_size: Optional[float] = None,
+     ) -> List[str]:
+         """Suggest optimization strategies based on requirements."""
+         suggestions = []
+
+         # Get baseline metrics
+         baseline = self.benchmark_model(model)
+
+         if target_speed and baseline["mean_inference_time"] > target_speed:
+             suggestions.append("Apply quantization for 2-4x speedup")
+             suggestions.append("Use pruning to reduce model size by 20-50%")
+             suggestions.append(
+                 "Consider using EfficientSpectralCNN for real-time inference"
+             )
+
+         if target_size and baseline["model_size_mb"] > target_size:
+             suggestions.append("Apply magnitude-based pruning")
+             suggestions.append("Use quantization to reduce model size")
+             suggestions.append("Consider knowledge distillation to a smaller model")
+
+         # Model-specific suggestions
+         if baseline["total_parameters"] > 1000000:
+             suggestions.append(
+                 "Model is large - consider using efficient architectures"
+             )
+
+         return suggestions
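
A quick sketch of the suggestion helper for reviewers (the target values are illustrative only; `target_speed` is compared against mean seconds per forward pass and `target_size` against model size in MB, matching the keys produced by `benchmark_model`; `model` is again an assumed trained module):

```python
from utils.model_optimization import ModelOptimizer

# `model` is an assumed trained nn.Module from elsewhere in the project
optimizer = ModelOptimizer()
suggestions = optimizer.suggest_optimizations(
    model,
    target_speed=0.01,  # seconds per forward pass
    target_size=5.0,    # megabytes
)
for s in suggestions:
    print(f"- {s}")
```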