devjas1 commited on
Commit
6cfb4d3
·
1 Parent(s): 64728dc

(FEAT)(New Models): Add advanced spectral CNN architectures

Browse files

- Created `models/enhanced_cnn.py` for new model implementations.
- Added three model classes:
- `EnhancedCNN`: Combines attention blocks, multi-scale convolutions, and improved residual connections for robust spectral feature extraction.
- `EfficientSpectralCNN`: Lightweight, real-time model using depthwise separable convolutions for fast inference.
- `HybridSpectralNet`: Integrates CNN backbone with self-attention for hybrid spectral learning.
- All architectures are tailored for 1D polymer spectral data and inspired by SE-Net, ResNet, and Inception.
- Includes a factory function for easy model registration and instantiation.
- Enables extensible, high-performance model selection in the platform.

Files changed (1) hide show
  1. models/enhanced_cnn.py +405 -0
models/enhanced_cnn.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ All neural network blocks and architectures in models/enhanced_cnn.py are custom implementations, developed to expand the model registry for advanced polymer spectral classification. While inspired by established deep learning concepts (such as residual connections, attention mechanisms, and multi-scale convolutions), they are are unique to this project and tailored for 1D spectral data.
3
+
4
+ Registry expansion: The purpose is to enrich the available models.
5
+ Literature inspiration: SE-Net, ResNet, Inception.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class AttentionBlock1D(nn.Module):
14
+ """1D attention mechanism for spectral data."""
15
+
16
+ def __init__(self, channels: int, reduction: int = 8):
17
+ super().__init__()
18
+ self.channels = channels
19
+ self.global_pool = nn.AdaptiveAvgPool1d(1)
20
+ self.fc = nn.Sequential(
21
+ nn.Linear(channels, channels // reduction),
22
+ nn.ReLU(inplace=True),
23
+ nn.Linear(channels // reduction, channels),
24
+ nn.Sigmoid(),
25
+ )
26
+
27
+ def forward(self, x):
28
+ # x shape: [batch, channels, length]
29
+ b, c, _ = x.size()
30
+
31
+ # Global average pooling
32
+ y = self.global_pool(x).view(b, c)
33
+
34
+ # Fully connected layers
35
+ y = self.fc(y).view(b, c, 1)
36
+
37
+ # Apply attention weights
38
+ return x * y.expand_as(x)
39
+
40
+
41
+ class EnhancedResidualBlock1D(nn.Module):
42
+ """Enhanced residual block with attention and improved normalization."""
43
+
44
+ def __init__(
45
+ self,
46
+ in_channels: int,
47
+ out_channels: int,
48
+ kernel_size: int = 3,
49
+ use_attention: bool = True,
50
+ dropout_rate: float = 0.1,
51
+ ):
52
+ super().__init__()
53
+ padding = kernel_size // 2
54
+
55
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
56
+ self.bn1 = nn.BatchNorm1d(out_channels)
57
+ self.relu = nn.ReLU(inplace=True)
58
+
59
+ self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding)
60
+ self.bn2 = nn.BatchNorm1d(out_channels)
61
+
62
+ self.dropout = nn.Dropout1d(dropout_rate) if dropout_rate > 0 else nn.Identity()
63
+
64
+ # Skip connection
65
+ self.skip = (
66
+ nn.Identity()
67
+ if in_channels == out_channels
68
+ else nn.Sequential(
69
+ nn.Conv1d(in_channels, out_channels, kernel_size=1),
70
+ nn.BatchNorm1d(out_channels),
71
+ )
72
+ )
73
+
74
+ # Attention mechanism
75
+ self.attention = (
76
+ AttentionBlock1D(out_channels) if use_attention else nn.Identity()
77
+ )
78
+
79
+ def forward(self, x):
80
+ identity = self.skip(x)
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+ out = self.dropout(out)
86
+
87
+ out = self.conv2(out)
88
+ out = self.bn2(out)
89
+
90
+ # Apply attention
91
+ out = self.attention(out)
92
+
93
+ out = out + identity
94
+ return self.relu(out)
95
+
96
+
97
+ class MultiScaleConvBlock(nn.Module):
98
+ """Multi-scale convolution block for capturing features at different scales."""
99
+
100
+ def __init__(self, in_channels: int, out_channels: int):
101
+ super().__init__()
102
+
103
+ # Different kernel sizes for multi-scale feature extraction
104
+ self.conv1 = nn.Conv1d(in_channels, out_channels // 4, kernel_size=3, padding=1)
105
+ self.conv2 = nn.Conv1d(in_channels, out_channels // 4, kernel_size=5, padding=2)
106
+ self.conv3 = nn.Conv1d(in_channels, out_channels // 4, kernel_size=7, padding=3)
107
+ self.conv4 = nn.Conv1d(in_channels, out_channels // 4, kernel_size=9, padding=4)
108
+
109
+ self.bn = nn.BatchNorm1d(out_channels)
110
+ self.relu = nn.ReLU(inplace=True)
111
+
112
+ def forward(self, x):
113
+ # Parallel convolutions with different kernel sizes
114
+ out1 = self.conv1(x)
115
+ out2 = self.conv2(x)
116
+ out3 = self.conv3(x)
117
+ out4 = self.conv4(x)
118
+
119
+ # Concatenate along channel dimension
120
+ out = torch.cat([out1, out2, out3, out4], dim=1)
121
+ out = self.bn(out)
122
+ return self.relu(out)
123
+
124
+
125
+ class EnhancedCNN(nn.Module):
126
+ """Enhanced CNN with attention, multi-scale features, and improved architecture."""
127
+
128
+ def __init__(
129
+ self,
130
+ input_length: int = 500,
131
+ num_classes: int = 2,
132
+ dropout_rate: float = 0.2,
133
+ use_attention: bool = True,
134
+ ):
135
+ super().__init__()
136
+
137
+ self.input_length = input_length
138
+ self.num_classes = num_classes
139
+
140
+ # Initial feature extraction
141
+ self.initial_conv = nn.Sequential(
142
+ nn.Conv1d(1, 32, kernel_size=7, padding=3),
143
+ nn.BatchNorm1d(32),
144
+ nn.ReLU(inplace=True),
145
+ nn.MaxPool1d(kernel_size=2),
146
+ )
147
+
148
+ # Multi-scale feature extraction
149
+ self.multiscale_block = MultiScaleConvBlock(32, 64)
150
+ self.pool1 = nn.MaxPool1d(kernel_size=2)
151
+
152
+ # Enhanced residual blocks
153
+ self.res_block1 = EnhancedResidualBlock1D(64, 96, use_attention=use_attention)
154
+ self.pool2 = nn.MaxPool1d(kernel_size=2)
155
+
156
+ self.res_block2 = EnhancedResidualBlock1D(96, 128, use_attention=use_attention)
157
+ self.pool3 = nn.MaxPool1d(kernel_size=2)
158
+
159
+ self.res_block3 = EnhancedResidualBlock1D(128, 160, use_attention=use_attention)
160
+
161
+ # Global feature extraction
162
+ self.global_pool = nn.AdaptiveAvgPool1d(1)
163
+
164
+ # Calculate feature size after convolutions
165
+ self.feature_size = 160
166
+
167
+ # Enhanced classifier with dropout
168
+ self.classifier = nn.Sequential(
169
+ nn.Linear(self.feature_size, 256),
170
+ nn.BatchNorm1d(256),
171
+ nn.ReLU(inplace=True),
172
+ nn.Dropout(dropout_rate),
173
+ nn.Linear(256, 128),
174
+ nn.BatchNorm1d(128),
175
+ nn.ReLU(inplace=True),
176
+ nn.Dropout(dropout_rate),
177
+ nn.Linear(128, 64),
178
+ nn.BatchNorm1d(64),
179
+ nn.ReLU(inplace=True),
180
+ nn.Dropout(dropout_rate / 2),
181
+ nn.Linear(64, num_classes),
182
+ )
183
+
184
+ # Initialize weights
185
+ self._initialize_weights()
186
+
187
+ def _initialize_weights(self):
188
+ """Initialize model weights using Xavier initialization."""
189
+ for m in self.modules():
190
+ if isinstance(m, nn.Conv1d):
191
+ nn.init.xavier_uniform_(m.weight)
192
+ if m.bias is not None:
193
+ nn.init.constant_(m.bias, 0)
194
+ elif isinstance(m, nn.Linear):
195
+ nn.init.xavier_uniform_(m.weight)
196
+ nn.init.constant_(m.bias, 0)
197
+ elif isinstance(m, nn.BatchNorm1d):
198
+ nn.init.constant_(m.weight, 1)
199
+ nn.init.constant_(m.bias, 0)
200
+
201
+ def forward(self, x):
202
+ # Ensure input is 3D: [batch, channels, length]
203
+ if x.dim() == 2:
204
+ x = x.unsqueeze(1)
205
+
206
+ # Feature extraction
207
+ x = self.initial_conv(x)
208
+ x = self.multiscale_block(x)
209
+ x = self.pool1(x)
210
+
211
+ x = self.res_block1(x)
212
+ x = self.pool2(x)
213
+
214
+ x = self.res_block2(x)
215
+ x = self.pool3(x)
216
+
217
+ x = self.res_block3(x)
218
+
219
+ # Global pooling
220
+ x = self.global_pool(x)
221
+ x = x.view(x.size(0), -1)
222
+
223
+ # Classification
224
+ x = self.classifier(x)
225
+
226
+ return x
227
+
228
+ def get_feature_maps(self, x):
229
+ """Extract intermediate feature maps for visualization."""
230
+ if x.dim() == 2:
231
+ x = x.unsqueeze(1)
232
+
233
+ features = {}
234
+
235
+ x = self.initial_conv(x)
236
+ features["initial"] = x
237
+
238
+ x = self.multiscale_block(x)
239
+ features["multiscale"] = x
240
+ x = self.pool1(x)
241
+
242
+ x = self.res_block1(x)
243
+ features["res1"] = x
244
+ x = self.pool2(x)
245
+
246
+ x = self.res_block2(x)
247
+ features["res2"] = x
248
+ x = self.pool3(x)
249
+
250
+ x = self.res_block3(x)
251
+ features["res3"] = x
252
+
253
+ return features
254
+
255
+
256
+ class EfficientSpectralCNN(nn.Module):
257
+ """Efficient CNN designed for real-time inference with good performance."""
258
+
259
+ def __init__(self, input_length: int = 500, num_classes: int = 2):
260
+ super().__init__()
261
+
262
+ # Efficient feature extraction with depthwise separable convolutions
263
+ self.features = nn.Sequential(
264
+ # Initial convolution
265
+ nn.Conv1d(1, 32, kernel_size=7, padding=3),
266
+ nn.BatchNorm1d(32),
267
+ nn.ReLU(inplace=True),
268
+ nn.MaxPool1d(2),
269
+ # Depthwise separable convolutions
270
+ self._make_depthwise_sep_conv(32, 64),
271
+ nn.MaxPool1d(2),
272
+ self._make_depthwise_sep_conv(64, 96),
273
+ nn.MaxPool1d(2),
274
+ self._make_depthwise_sep_conv(96, 128),
275
+ nn.MaxPool1d(2),
276
+ # Final feature extraction
277
+ nn.Conv1d(128, 160, kernel_size=3, padding=1),
278
+ nn.BatchNorm1d(160),
279
+ nn.ReLU(inplace=True),
280
+ nn.AdaptiveAvgPool1d(1),
281
+ )
282
+
283
+ # Lightweight classifier
284
+ self.classifier = nn.Sequential(
285
+ nn.Linear(160, 64),
286
+ nn.ReLU(inplace=True),
287
+ nn.Dropout(0.1),
288
+ nn.Linear(64, num_classes),
289
+ )
290
+
291
+ self._initialize_weights()
292
+
293
+ def _make_depthwise_sep_conv(self, in_channels, out_channels):
294
+ """Create depthwise separable convolution block."""
295
+ return nn.Sequential(
296
+ # Depthwise convolution
297
+ nn.Conv1d(
298
+ in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels
299
+ ),
300
+ nn.BatchNorm1d(in_channels),
301
+ nn.ReLU(inplace=True),
302
+ # Pointwise convolution
303
+ nn.Conv1d(in_channels, out_channels, kernel_size=1),
304
+ nn.BatchNorm1d(out_channels),
305
+ nn.ReLU(inplace=True),
306
+ )
307
+
308
+ def _initialize_weights(self):
309
+ """Initialize model weights."""
310
+ for m in self.modules():
311
+ if isinstance(m, nn.Conv1d):
312
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
313
+ if m.bias is not None:
314
+ nn.init.constant_(m.bias, 0)
315
+ elif isinstance(m, nn.Linear):
316
+ nn.init.xavier_uniform_(m.weight)
317
+ nn.init.constant_(m.bias, 0)
318
+ elif isinstance(m, nn.BatchNorm1d):
319
+ nn.init.constant_(m.weight, 1)
320
+ nn.init.constant_(m.bias, 0)
321
+
322
+ def forward(self, x):
323
+ if x.dim() == 2:
324
+ x = x.unsqueeze(1)
325
+
326
+ x = self.features(x)
327
+ x = x.view(x.size(0), -1)
328
+ x = self.classifier(x)
329
+
330
+ return x
331
+
332
+
333
+ class HybridSpectralNet(nn.Module):
334
+ """Hybrid network combining CNN and attention mechanisms."""
335
+
336
+ def __init__(self, input_length: int = 500, num_classes: int = 2):
337
+ super().__init__()
338
+
339
+ # CNN backbone
340
+ self.cnn_backbone = nn.Sequential(
341
+ nn.Conv1d(1, 64, kernel_size=7, padding=3),
342
+ nn.BatchNorm1d(64),
343
+ nn.ReLU(inplace=True),
344
+ nn.MaxPool1d(2),
345
+ nn.Conv1d(64, 128, kernel_size=5, padding=2),
346
+ nn.BatchNorm1d(128),
347
+ nn.ReLU(inplace=True),
348
+ nn.MaxPool1d(2),
349
+ nn.Conv1d(128, 256, kernel_size=3, padding=1),
350
+ nn.BatchNorm1d(256),
351
+ nn.ReLU(inplace=True),
352
+ )
353
+
354
+ # Self-attention layer
355
+ self.attention = nn.MultiheadAttention(
356
+ embed_dim=256, num_heads=8, dropout=0.1, batch_first=True
357
+ )
358
+
359
+ # Final pooling and classification
360
+ self.global_pool = nn.AdaptiveAvgPool1d(1)
361
+ self.classifier = nn.Sequential(
362
+ nn.Linear(256, 128),
363
+ nn.ReLU(inplace=True),
364
+ nn.Dropout(0.2),
365
+ nn.Linear(128, num_classes),
366
+ )
367
+
368
+ def forward(self, x):
369
+ if x.dim() == 2:
370
+ x = x.unsqueeze(1)
371
+
372
+ # CNN feature extraction
373
+ x = self.cnn_backbone(x)
374
+
375
+ # Prepare for attention: [batch, length, channels]
376
+ x = x.transpose(1, 2)
377
+
378
+ # Self-attention
379
+ attn_out, _ = self.attention(x, x, x)
380
+
381
+ # Back to [batch, channels, length]
382
+ x = attn_out.transpose(1, 2)
383
+
384
+ # Global pooling and classification
385
+ x = self.global_pool(x)
386
+ x = x.view(x.size(0), -1)
387
+ x = self.classifier(x)
388
+
389
+ return x
390
+
391
+
392
+ def create_enhanced_model(model_type: str = "enhanced", **kwargs):
393
+ """Factory function to create enhanced models."""
394
+ models = {
395
+ "enhanced": EnhancedCNN,
396
+ "efficient": EfficientSpectralCNN,
397
+ "hybrid": HybridSpectralNet,
398
+ }
399
+
400
+ if model_type not in models:
401
+ raise ValueError(
402
+ f"Unknown model type: {model_type}. Available: {list(models.keys())}"
403
+ )
404
+
405
+ return models[model_type](**kwargs)