import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Optional

class VoxelCNNEncoder(nn.Module):
    """
    Multi-channel 3D CNN encoder for voxelized obstacle data.
    Processes environment obstacles, start position, and goal position.
    """
    def __init__(self,
                 input_channels=3,  # obstacles + start + goal
                 filters_1=32,
                 kernel_size_1=(3, 3, 3),
                 pool_size_1=(2, 2, 2),
                 filters_2=64,
                 kernel_size_2=(3, 3, 3),
                 pool_size_2=(2, 2, 2),
                 filters_3=128,
                 kernel_size_3=(3, 3, 3),
                 pool_size_3=(2, 2, 2),
                 dense_units=512,
                 input_voxel_dim=(32, 32, 32),
                 dropout_rate=0.2
                ):
        super(VoxelCNNEncoder, self).__init__()

        self.input_voxel_dim = input_voxel_dim
        self.input_channels = input_channels

        # First 3D Convolutional Block (Conv-BN-ReLU)
        padding_1 = tuple([(k - 1) // 2 for k in kernel_size_1])
        self.conv1 = nn.Conv3d(input_channels, filters_1, kernel_size_1, padding=padding_1)
        self.bn1 = nn.BatchNorm3d(filters_1)
        self.pool1 = nn.MaxPool3d(pool_size_1)
        self.dropout1 = nn.Dropout3d(dropout_rate)

        # Second 3D Convolutional Block (Conv-BN-ReLU)
        padding_2 = tuple([(k - 1) // 2 for k in kernel_size_2])
        self.conv2 = nn.Conv3d(filters_1, filters_2, kernel_size_2, padding=padding_2)
        self.bn2 = nn.BatchNorm3d(filters_2)
        self.pool2 = nn.MaxPool3d(pool_size_2)
        self.dropout2 = nn.Dropout3d(dropout_rate)

        # Third 3D Convolutional Block (Conv-BN-ReLU)
        padding_3 = tuple([(k - 1) // 2 for k in kernel_size_3])
        self.conv3 = nn.Conv3d(filters_2, filters_3, kernel_size_3, padding=padding_3)
        self.bn3 = nn.BatchNorm3d(filters_3)
        self.pool3 = nn.MaxPool3d(pool_size_3)
        self.dropout3 = nn.Dropout3d(dropout_rate)

        # Calculate flattened size
        self._to_linear_input_size = self._get_conv_output_size()

        # Dense layers with residual connection
        self.fc1 = nn.Linear(self._to_linear_input_size, dense_units)
        self.fc2 = nn.Linear(dense_units, dense_units)
        self.dropout_fc = nn.Dropout(dropout_rate)

    def _get_conv_output_size(self):
        # Only the shape-affecting layers (conv + pool) are traced here; running the
        # BatchNorm layers on a dummy batch would needlessly update their running stats.
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.input_channels, *self.input_voxel_dim)
            x = self.pool1(self.conv1(dummy_input))
            x = self.pool2(self.conv2(x))
            x = self.pool3(self.conv3(x))
            return x.numel()

    def forward(self, x):
        # First conv block (Conv-BN-ReLU)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        # Second conv block (Conv-BN-ReLU)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        # Third conv block (Conv-BN-ReLU)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.pool3(x)
        x = self.dropout3(x)

        # Flatten and dense layers
        x = x.view(x.size(0), -1)
        x1 = F.relu(self.fc1(x))
        x1 = self.dropout_fc(x1)
        x2 = F.relu(self.fc2(x1))
        
        # Residual connection
        return x1 + x2
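
# Shape sketch for the encoder above (a rough trace, assuming the default
# 32x32x32 grid, "same" convolutions, 2x2x2 pooling at every stage, and dense_units=512):
#   input          (B, 3, 32, 32, 32)
#   conv1 + pool1  (B, 32, 16, 16, 16)
#   conv2 + pool2  (B, 64,  8,  8,  8)
#   conv3 + pool3  (B, 128, 4,  4,  4)   -> flattened to 8192 features
#   fc1 / fc2      (B, 512) environment feature vector (with the residual sum x1 + x2)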


class PositionEncoder(nn.Module):
    """
    Encodes start and goal positions with learned embeddings.
    """
    def __init__(self, voxel_dim=(32, 32, 32), position_embed_dim=64):
        super(PositionEncoder, self).__init__()
        self.voxel_dim = voxel_dim
        self.position_embed_dim = position_embed_dim
        
        # Calculate dimensions for each axis to sum to position_embed_dim
        dim_per_axis = position_embed_dim // 3
        remainder = position_embed_dim % 3
        
        x_dim = dim_per_axis + (1 if remainder > 0 else 0)
        y_dim = dim_per_axis + (1 if remainder > 1 else 0)
        z_dim = dim_per_axis
        
        # Learned position embeddings for each dimension
        self.x_embed = nn.Embedding(voxel_dim[0], x_dim)
        self.y_embed = nn.Embedding(voxel_dim[1], y_dim)
        self.z_embed = nn.Embedding(voxel_dim[2], z_dim)
        
        # Project the concatenated start/goal embeddings back down to position_embed_dim
        self.fc = nn.Linear(2 * position_embed_dim, position_embed_dim)
        
    def forward(self, positions):
        """
        positions: (batch_size, 2, 3) - [start_pos, goal_pos] with (x, y, z)
        """
        batch_size = positions.size(0)
        
        # Extract coordinates
        # Clamp coordinates defensively to valid index ranges to avoid embedding OOB
        x_coords = positions[:, :, 0].long().clamp_(0, self.voxel_dim[0] - 1)  # (batch_size, 2)
        y_coords = positions[:, :, 1].long().clamp_(0, self.voxel_dim[1] - 1)  # (batch_size, 2)
        z_coords = positions[:, :, 2].long().clamp_(0, self.voxel_dim[2] - 1)  # (batch_size, 2)
        
        # Get embeddings
        x_emb = self.x_embed(x_coords)  # (batch_size, 2, x_dim)
        y_emb = self.y_embed(y_coords)  # (batch_size, 2, y_dim)
        z_emb = self.z_embed(z_coords)  # (batch_size, 2, z_dim)
        
        # Concatenate embeddings
        pos_emb = torch.cat([x_emb, y_emb, z_emb], dim=-1)  # (batch_size, 2, position_embed_dim)
        
        # Flatten start and goal embeddings
        pos_emb = pos_emb.view(batch_size, -1)  # (batch_size, 2 * position_embed_dim)
        
        return F.relu(self.fc(pos_emb))
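
# Note on the embedding split above: with the default position_embed_dim=64 the
# per-axis widths work out to 22 + 21 + 21 = 64 (the x axis absorbs the remainder),
# so concatenating the x/y/z embeddings for start and goal gives (B, 2, 64), which
# is flattened to (B, 128) before the final projection back to 64.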


class PathPlannerTransformer(nn.Module):
    """
    Transformer-based path planner that generates action sequences.
    Token IDs are assigned so actions and special tokens never collide:
    - Actions: 0-5 (Forward, Back, Left, Right, Up, Down)
    - START: 6
    - END: 7
    - PAD: 8 (used only for teacher forcing inputs; targets still use -1 for ignore)
    """
    def __init__(self, 
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_heads=8,
                 num_layers=4,
                 max_sequence_length=100,
                 num_actions=6,  # Forward, Back, Left, Right, Up, Down
                 use_end_token=True):
        super(PathPlannerTransformer, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.max_sequence_length = max_sequence_length
        self.num_actions = num_actions
        self.use_end_token = use_end_token
        
        # Token IDs: actions occupy 0..num_actions-1, followed by START, END, PAD
        self.start_token_id = num_actions  # 6
        self.end_token_id = num_actions + 1 if use_end_token else None  # 7
        # Reserve a PAD token for embedding inputs during teacher forcing
        self.pad_token_id = (num_actions + 2) if use_end_token else (num_actions + 1)
        # Total tokens include PAD
        self.total_tokens = (num_actions + 3) if use_end_token else (num_actions + 2)
        
        # Feature fusion
        self.feature_fusion = nn.Linear(env_feature_dim + pos_feature_dim, hidden_dim)
        
        # Action embeddings
        self.action_embed = nn.Embedding(self.total_tokens, hidden_dim)
        
        # Positional encoding - register as buffer for proper device handling
        self.register_buffer('pos_encoding', self._create_positional_encoding(max_sequence_length, hidden_dim))
        
        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim, self.total_tokens)
        
        # Turn head, supervised via BCE-with-logits against turn labels derived from the targets
        self.turn_penalty_head = nn.Linear(hidden_dim, 1)
        
    def _create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)
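
    # The tensor built above is the standard sinusoidal positional encoding:
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # It is registered as a buffer in __init__ so it follows the module's device
    # without being treated as a trainable parameter.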
    
    def forward(self, env_features, pos_features, target_actions=None):
        """
        env_features: (batch_size, env_feature_dim)
        pos_features: (batch_size, pos_feature_dim)
        target_actions: (batch_size, seq_len) - for training (action IDs 0-5, END token 7, and -1 for padding)
        """
        batch_size = env_features.size(0)
        
        # Fuse environment and position features
        fused_features = self.feature_fusion(torch.cat([env_features, pos_features], dim=1))
        
        # Create memory (encoder output) by repeating fused features
        memory = fused_features.unsqueeze(1).repeat(1, self.max_sequence_length, 1)
        
        if target_actions is not None:
            # Training mode: use teacher forcing
            seq_len = target_actions.size(1)
            
            # Create input sequence (START token + target_actions[:-1])
            start_tokens = torch.full((batch_size, 1), self.start_token_id, 
                                    dtype=torch.long, device=target_actions.device)
            input_seq = torch.cat([start_tokens, target_actions[:, :-1]], dim=1)
            # Replace padding (-1) in teacher-forced inputs with PAD token id to avoid OOB in embedding
            input_seq = torch.where(input_seq < 0, torch.full_like(input_seq, self.pad_token_id), input_seq)
            
            # Embed actions and add positional encoding
            embedded = self.action_embed(input_seq)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]
            
            # Generate attention mask (causal mask)
            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(embedded.device)
            
            # Transformer decoder forward pass
            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )
            
            # Output projections
            action_logits = self.output_proj(output)
            # Turn logits for supervised turn classification
            turn_logits = self.turn_penalty_head(output)
            
            return action_logits, turn_logits
        else:
            # Inference mode: generate sequence autoregressively
            return self._generate_path(memory, batch_size)
    
    def _generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
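
    # For example, for sz=3 the causal mask produced above is:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position t may only attend to positions <= t.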
    
    def _generate_path(self, memory, batch_size):
        """
        Generate path sequence autoregressively, handling batches correctly.
        Fixes bugs related to premature stopping and inclusion of special tokens.
        """
        device = memory.device
        
        # Start with START token
        input_seq = torch.full((batch_size, 1), self.start_token_id, dtype=torch.long, device=device)
        
        # Keep track of sequences that have generated an END token
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(self.max_sequence_length):
            # Embed current sequence
            embedded = self.action_embed(input_seq)
            seq_len = embedded.size(1)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]

            # Generate causal mask
            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)

            # Forward pass
            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )
            
            # Get next action probabilities from the last token in the sequence
            next_action_logits = self.output_proj(output[:, -1, :])
            next_actions = torch.argmax(next_action_logits, dim=-1, keepdim=True)
            
            # Append the predicted actions to the sequence
            input_seq = torch.cat([input_seq, next_actions], dim=1)
            
            # Update the finished mask for any sequence that just produced an END token
            if self.use_end_token:
                is_finished |= (next_actions.squeeze(-1) == self.end_token_id)
            
            # If all sequences in the batch are finished, we can stop early
            if is_finished.all():
                break
        
        # Post-processing to create a clean, dense tensor of valid actions
        # Remove the initial START token from all sequences
        raw_paths = input_seq[:, 1:]
        
        clean_paths_list = []
        max_len = 0
        
        for i in range(batch_size):
            path = []
            for token_id in raw_paths[i]:
                # Stop decoding for this path if an END token is found
                if self.use_end_token and token_id.item() == self.end_token_id:
                    break
                # Only include valid movement actions in the final path
                if token_id.item() < self.num_actions:
                    path.append(token_id.item())
            
            clean_paths_list.append(path)
            if len(path) > max_len:
                max_len = len(path)
        
        # Return an empty tensor if no valid actions were generated
        if max_len == 0:
            return torch.zeros(batch_size, 0, dtype=torch.long, device=device)
        
        # Pad all paths to the length of the longest path in the batch
        # We use the END token ID for padding, as downstream functions like
        # check_collisions are designed to ignore non-action tokens.
        pad_value = self.end_token_id if self.use_end_token else self.num_actions
        padded_paths = torch.full((batch_size, max_len), pad_value, dtype=torch.long, device=device)
        
        for i, path in enumerate(clean_paths_list):
            if len(path) > 0:
                padded_paths[i, :len(path)] = torch.tensor(path, dtype=torch.long, device=device)
                
        return padded_paths
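
# Token vocabulary used by PathPlannerTransformer with the defaults
# (num_actions=6, use_end_token=True):
#   0-5  movement actions (Forward, Back, Left, Right, Up, Down)
#   6    START (prepended for teacher forcing and autoregressive decoding)
#   7    END   (terminates a path; also used to pad generated paths)
#   8    PAD   (substituted for -1 targets when embedding teacher-forced inputs)
# which is why action_embed and output_proj are sized total_tokens = 9.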


class PathfindingNetwork(nn.Module):
    """
    Complete pathfinding network combining CNN encoder, position encoder, and transformer planner.
    """
    def __init__(self, 
                 voxel_dim=(32, 32, 32),
                 input_channels=3,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_actions=6,
                 use_end_token=True):
        super(PathfindingNetwork, self).__init__()
        
        self.voxel_dim = voxel_dim
        self.num_actions = num_actions
        
        self.voxel_encoder = VoxelCNNEncoder(
            input_channels=input_channels,
            dense_units=env_feature_dim,
            input_voxel_dim=voxel_dim
        )
        
        self.position_encoder = PositionEncoder(
            voxel_dim=voxel_dim,
            position_embed_dim=pos_feature_dim
        )
        
        self.path_planner = PathPlannerTransformer(
            env_feature_dim=env_feature_dim,
            pos_feature_dim=pos_feature_dim,
            hidden_dim=hidden_dim,
            num_actions=num_actions,
            use_end_token=use_end_token
        )
        
    def forward(self, voxel_data, positions, target_actions=None):
        """
        voxel_data: (batch_size, 3, D, H, W) - [obstacles, start_mask, goal_mask]
        positions: (batch_size, 2, 3) - [start_pos, goal_pos]
        target_actions: (batch_size, seq_len) - for training
        """
        # Encode environment
        env_features = self.voxel_encoder(voxel_data)
        
        # Encode positions
        pos_features = self.position_encoder(positions)
        
        # Generate path
        if target_actions is not None:
            action_logits, turn_penalties = self.path_planner(env_features, pos_features, target_actions)
            return action_logits, turn_penalties
        else:
            generated_path = self.path_planner(env_features, pos_features)
            return generated_path
    
    def check_collisions(self, voxel_data, positions, actions):
        """
        Check if actions lead to collisions with obstacles.
        
        voxel_data: (batch_size, 3, D, H, W)
        positions: (batch_size, 2, 3) - [start_pos, goal_pos]; only the start position is used
        actions: (batch_size, seq_len) - action sequences
        
        Returns: (batch_size, seq_len) collision mask
        """
        batch_size, seq_len = actions.shape
        device = actions.device
        
        # Extract obstacle channel
        obstacles = voxel_data[:, 0, :, :, :]  # (batch_size, D, H, W)
        
        # Action-to-direction mapping; each row offsets the (x, y, z) position tuple
        directions = torch.tensor([
            [1, 0, 0],   # Forward (x+)
            [-1, 0, 0],  # Back (x-)
            [0, 1, 0],   # Left (y+)
            [0, -1, 0],  # Right (y-)
            [0, 0, 1],   # Up (z+)
            [0, 0, -1]   # Down (z-)
        ], dtype=torch.long, device=device)
        
        collision_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        current_pos = positions[:, 0, :].clone().long()  # Start from start position; long so it can index the grid
        
        for t in range(seq_len):
            # Get actions for this timestep
            action_t = actions[:, t]
            
            # Only process valid actions (0-5), skip special tokens
            valid_actions = action_t < self.num_actions
            
            # Update positions based on actions
            for b in range(batch_size):
                if valid_actions[b]:
                    direction = directions[action_t[b]]
                    new_pos = current_pos[b] + direction
                    
                    # Check bounds
                    if (new_pos >= 0).all() and (new_pos[0] < self.voxel_dim[0]) and \
                       (new_pos[1] < self.voxel_dim[1]) and (new_pos[2] < self.voxel_dim[2]):
                        # Check collision
                        if obstacles[b, new_pos[0], new_pos[1], new_pos[2]] > 0:
                            collision_mask[b, t] = True
                        else:
                            current_pos[b] = new_pos
                    else:
                        # Out of bounds counts as collision
                        collision_mask[b, t] = True
        
        return collision_mask


class PathfindingLoss(nn.Module):
    """
    Custom loss function that balances path correctness and turn minimization.
    Handles special tokens (START=6, END=7) alongside action tokens (0-5).
    Turn loss is supervised: a turn occurs when consecutive valid actions differ.
    """
    def __init__(self, turn_penalty_weight=0.1, collision_penalty_weight=10.0, 
                 num_actions=6, use_end_token=True):
        super(PathfindingLoss, self).__init__()
        self.turn_penalty_weight = turn_penalty_weight
        self.collision_penalty_weight = collision_penalty_weight
        self.num_actions = num_actions
        self.use_end_token = use_end_token
        self.start_token_id = num_actions  # 6
        self.end_token_id = num_actions + 1 if use_end_token else None  # 7
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore padding
        # BCE with logits for supervised turn prediction
        self.turn_bce = nn.BCEWithLogitsLoss(reduction='sum')
        
    def forward(self, action_logits, turn_penalties, target_actions, collision_mask=None):
        """
        action_logits: (batch_size, seq_len, total_tokens) - includes all tokens (0-7)
        turn_penalties: (batch_size, seq_len, 1) - interpreted as turn logits
        target_actions: (batch_size, seq_len) - contains action IDs (0-5) and possibly END (7)
        collision_mask: (batch_size, seq_len) - 1 if collision, 0 if safe
        """
        batch_size, seq_len, total_tokens = action_logits.shape
        
        # Reshape for cross entropy loss
        action_logits_flat = action_logits.view(-1, total_tokens)
        target_actions_flat = target_actions.view(-1)
        
        # Path correctness loss over all token IDs (padding -1 is skipped via ignore_index)
        path_loss = self.ce_loss(action_logits_flat, target_actions_flat)
        
        # Supervised turn loss
        # Compute valid action mask (exclude special tokens)
        valid_actions_mask = (target_actions < self.num_actions)
        # Previous actions (pad first timestep with itself; will be masked out anyway)
        prev_actions = torch.cat([target_actions[:, :1], target_actions[:, :-1]], dim=1)
        prev_valid_mask = torch.cat([torch.zeros_like(valid_actions_mask[:, :1], dtype=torch.bool),
                                     valid_actions_mask[:, :-1]], dim=1)
        # A turn occurs if both current and previous are valid actions and they differ
        both_valid = valid_actions_mask & prev_valid_mask
        is_turn = ((target_actions != prev_actions) & both_valid).float()
        
        # Turn logits predicted by the model
        turn_logits = turn_penalties.squeeze(-1)
        
        # Compute BCE-with-logits only over valid (current, previous) action pairs
        if both_valid.any():
            bce_sum = self.turn_bce(turn_logits[both_valid], is_turn[both_valid])
            turn_loss = bce_sum / both_valid.sum().float()
        else:
            turn_loss = torch.tensor(0.0, device=action_logits.device)
        
        # Collision penalty - only apply to actual movement actions
        collision_loss = torch.tensor(0.0, device=action_logits.device)
        if collision_mask is not None:
            # Mask collisions to only count for actual movement actions
            masked_collisions = collision_mask.float() * valid_actions_mask.float()
            if valid_actions_mask.sum() > 0:
                collision_loss = (masked_collisions.sum() / valid_actions_mask.sum()) * self.collision_penalty_weight
            
        total_loss = path_loss + self.turn_penalty_weight * turn_loss + collision_loss
        
        return {
            'total_loss': total_loss,
            'path_loss': path_loss,
            'turn_loss': turn_loss,
            'collision_loss': collision_loss
        }
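
# The combined objective above is, in effect:
#   total = CE(action_logits, targets)                       (path_loss)
#         + turn_penalty_weight * BCE(turn_logits, turns)    (turn_loss)
#         + collision_penalty_weight * mean collision rate   (collision_loss)
# where padded targets (-1) are ignored by the cross-entropy term and the turn
# labels mark timesteps whose action differs from the previous valid action.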


# Utility functions for data preparation
def create_voxel_input(obstacles, start_pos, goal_pos, voxel_dim=(32, 32, 32)):
    """
    Create multi-channel voxel input.
    
    obstacles: (D, H, W) binary array
    start_pos: (x, y, z) tuple
    goal_pos: (x, y, z) tuple
    """
    # Channel 0: obstacles
    obstacle_channel = obstacles.astype(np.float32)
    
    # Channel 1: start position
    start_channel = np.zeros(voxel_dim, dtype=np.float32)
    start_channel[start_pos] = 1.0
    
    # Channel 2: goal position
    goal_channel = np.zeros(voxel_dim, dtype=np.float32)
    goal_channel[goal_pos] = 1.0
    
    # Stack channels
    voxel_input = np.stack([obstacle_channel, start_channel, goal_channel], axis=0)
    
    return voxel_input
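
# A minimal usage sketch for create_voxel_input (hypothetical values, kept as a
# comment so importing this module has no side effects):
#   obstacles = np.zeros((32, 32, 32), dtype=np.float32)
#   obstacles[10:14, 10:14, 10:14] = 1.0          # a small solid block
#   voxel_input = create_voxel_input(obstacles, (0, 0, 0), (31, 31, 31))
#   voxel_input.shape == (3, 32, 32, 32)          # obstacles, start mask, goal mask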


def prepare_training_targets(action_sequence, use_end_token=True, num_actions=6):
    """
    Prepare target action sequences for training.
    Ensures action IDs are in range [0, num_actions-1] and adds END token if needed.
    
    action_sequence: list or tensor of action IDs (0-5)
    use_end_token: whether to append END token
    num_actions: number of valid actions
    
    Returns: tensor with proper token IDs
    """
    if isinstance(action_sequence, list):
        action_sequence = torch.tensor(action_sequence)
    
    # Ensure actions are in valid range
    assert (action_sequence >= 0).all() and (action_sequence < num_actions).all(), \
        f"Actions must be in range [0, {num_actions-1}]"
    
    if use_end_token:
        # Append END token (ID = num_actions + 1 = 7)
        end_token = torch.tensor([num_actions + 1])
        target = torch.cat([action_sequence, end_token])
    else:
        target = action_sequence
    
    return target
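
# A sketch of batching variable-length targets (hypothetical sequences): padding
# with -1 lets CrossEntropyLoss(ignore_index=-1) skip padded positions, and the
# planner swaps -1 for its PAD token before embedding teacher-forced inputs.
#   seqs = [prepare_training_targets([0, 0, 4]), prepare_training_targets([1, 1])]
#   batch = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=-1)
#   batch.shape == (2, 4)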


# Example usage and testing
if __name__ == "__main__":
    # Define problem parameters
    voxel_dim = (32, 32, 32)
    batch_size = 4
    num_actions = 6  # Forward, Back, Left, Right, Up, Down
    
    # Create the complete pathfinding network
    pathfinding_net = PathfindingNetwork(
        voxel_dim=voxel_dim,
        input_channels=3,
        env_feature_dim=512,
        pos_feature_dim=64,
        hidden_dim=256,
        num_actions=num_actions,
        use_end_token=True
    )
    
    print("=== 3D Pathfinding Network Architecture ===")
    print(f"Total parameters: {sum(p.numel() for p in pathfinding_net.parameters()):,}")
    print(f"\nToken ID mapping:")
    print(f"  Actions: 0-5 (Forward, Back, Left, Right, Up, Down)")
    print(f"  START token: {pathfinding_net.path_planner.start_token_id}")
    print(f"  END token: {pathfinding_net.path_planner.end_token_id}")
    
    # Create dummy data
    dummy_voxel_data = torch.randn(batch_size, 3, *voxel_dim)
    dummy_positions = torch.randint(0, 32, (batch_size, 2, 3))  # start and goal positions
    
    # Create proper target actions with END token
    dummy_actions = torch.randint(0, num_actions, (batch_size, 19))  # 19 movement actions
    dummy_target_actions = torch.cat([
        dummy_actions, 
        torch.full((batch_size, 1), pathfinding_net.path_planner.end_token_id)
    ], dim=1)  # Add END token
    
    print(f"\n=== Testing Forward Pass ===")
    print(f"Input voxel shape: {dummy_voxel_data.shape}")
    print(f"Input positions shape: {dummy_positions.shape}")
    print(f"Target actions shape: {dummy_target_actions.shape}")
    print(f"Target action values range: [{dummy_target_actions.min().item()}, {dummy_target_actions.max().item()}]")
    
    # Training forward pass
    pathfinding_net.train()
    action_logits, turn_penalties = pathfinding_net(
        dummy_voxel_data, 
        dummy_positions, 
        dummy_target_actions
    )
    
    print(f"\nTraining mode outputs:")
    print(f"Action logits shape: {action_logits.shape} (should be {(batch_size, 20, 8)})")
    print(f"Turn logits shape: {turn_penalties.shape}")
    
    # Inference forward pass
    pathfinding_net.eval()
    with torch.no_grad():
        generated_paths = pathfinding_net(dummy_voxel_data, dummy_positions)
    
    print(f"\nInference mode outputs:")
    print(f"Generated paths shape: {generated_paths.shape}")
    if generated_paths.shape[1] > 0:
        print(f"Generated action values range: [{generated_paths.min().item()}, {generated_paths.max().item()}]")
    
    # Test collision checking
    test_actions = generated_paths if generated_paths.shape[1] > 0 else dummy_actions
    collision_mask = pathfinding_net.check_collisions(
        dummy_voxel_data, 
        dummy_positions, 
        test_actions
    )
    print(f"Collision mask shape: {collision_mask.shape}")
    
    # Test loss function with proper masking
    loss_fn = PathfindingLoss(
        turn_penalty_weight=0.1, 
        num_actions=num_actions,
        use_end_token=True
    )
    
    # Adjust collision mask to match target sequence length
    if collision_mask.shape[1] >= 20:
        collision_mask_adjusted = collision_mask[:, :20]
    else:
        # Pad with zeros if collision mask is shorter
        padding = torch.zeros(batch_size, 20 - collision_mask.shape[1], 
                             dtype=torch.bool, device=collision_mask.device)
        collision_mask_adjusted = torch.cat([collision_mask, padding], dim=1)
    
    loss_dict = loss_fn(action_logits, turn_penalties, dummy_target_actions, collision_mask_adjusted)
    
    print(f"\n=== Loss Components ===")
    for key, value in loss_dict.items():
        print(f"{key}: {value.item():.4f}")
    
    # Verify that the loss properly masks special tokens
    print(f"\n=== Verification Tests ===")
    
    # Test 1: Verify token ID assignments
    print(f"1. Token IDs are correctly assigned:")
    print(f"   - Movement actions use IDs 0-5: βœ“")
    print(f"   - START token uses ID {pathfinding_net.path_planner.start_token_id}: βœ“")
    print(f"   - END token uses ID {pathfinding_net.path_planner.end_token_id}: βœ“")
    
    # Test 2: Verify Conv-BN-ReLU order
    print(f"2. Conv-BN-ReLU order is standardized: βœ“")
    
    # Test 3: Verify supervised turn labels mask
    with torch.no_grad():
        # Create a sequence with mixed actions and END token
        test_sequence = torch.tensor([[0, 1, 2, 3, 4, 5, 7]])  # Actions 0-5 then END
        valid_mask = (test_sequence < num_actions)
        prev_seq = torch.cat([test_sequence[:, :1], test_sequence[:, :-1]], dim=1)
        prev_valid = torch.cat([torch.zeros_like(valid_mask[:, :1], dtype=torch.bool), valid_mask[:, :-1]], dim=1)
        both_valid = valid_mask & prev_valid
        is_turn = ((test_sequence != prev_seq) & both_valid).float()
        print(f"3. Supervised turn labels test:")
        print(f"   - Test sequence: {test_sequence.tolist()}")
        print(f"   - Valid mask: {valid_mask.tolist()}")
        print(f"   - Both valid mask: {both_valid.tolist()}")
        print(f"   - Turn labels: {is_turn.tolist()}")
    
    # Test 4: Verify generated paths contain only movement actions (plus END padding)
    print(f"4. Generated paths contain only valid action IDs (0-5) or END padding:")
    if generated_paths.shape[1] > 0:
        end_id = pathfinding_net.path_planner.end_token_id
        valid_tokens = (generated_paths < num_actions) | (generated_paths == end_id)
        contains_only_valid = bool((generated_paths >= 0).all() and valid_tokens.all())
        print(f"   - Generated tokens in valid range: {'βœ“' if contains_only_valid else 'βœ—'}")
    else:
        print(f"   - No actions generated (early END token)")
    
    print(f"\n=== Network Ready for Training ===")