c1tr0n75 committed
Commit ee6cc2b · verified · 1 Parent(s): 7f40cc7

adding pathfinding_nn.py alongside app.py

Files changed (1)
  1. pathfinding_nn.py +742 -0
pathfinding_nn.py ADDED
@@ -0,0 +1,742 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Optional


class VoxelCNNEncoder(nn.Module):
    """
    Enhanced 3D CNN encoder for voxelized obstruction data with multi-channel support.
    Processes environment obstacles, start position, and goal position.
    """
    def __init__(self,
                 input_channels=3,  # obstacles + start + goal
                 filters_1=32,
                 kernel_size_1=(3, 3, 3),
                 pool_size_1=(2, 2, 2),
                 filters_2=64,
                 kernel_size_2=(3, 3, 3),
                 pool_size_2=(2, 2, 2),
                 filters_3=128,
                 kernel_size_3=(3, 3, 3),
                 pool_size_3=(2, 2, 2),
                 dense_units=512,
                 input_voxel_dim=(32, 32, 32),
                 dropout_rate=0.2):
        super(VoxelCNNEncoder, self).__init__()

        self.input_voxel_dim = input_voxel_dim
        self.input_channels = input_channels

        # First 3D convolutional block (Conv-BN-ReLU)
        padding_1 = tuple((k - 1) // 2 for k in kernel_size_1)
        self.conv1 = nn.Conv3d(input_channels, filters_1, kernel_size_1, padding=padding_1)
        self.bn1 = nn.BatchNorm3d(filters_1)
        self.pool1 = nn.MaxPool3d(pool_size_1)
        self.dropout1 = nn.Dropout3d(dropout_rate)

        # Second 3D convolutional block (Conv-BN-ReLU)
        padding_2 = tuple((k - 1) // 2 for k in kernel_size_2)
        self.conv2 = nn.Conv3d(filters_1, filters_2, kernel_size_2, padding=padding_2)
        self.bn2 = nn.BatchNorm3d(filters_2)
        self.pool2 = nn.MaxPool3d(pool_size_2)
        self.dropout2 = nn.Dropout3d(dropout_rate)

        # Third 3D convolutional block (Conv-BN-ReLU)
        padding_3 = tuple((k - 1) // 2 for k in kernel_size_3)
        self.conv3 = nn.Conv3d(filters_2, filters_3, kernel_size_3, padding=padding_3)
        self.bn3 = nn.BatchNorm3d(filters_3)
        self.pool3 = nn.MaxPool3d(pool_size_3)
        self.dropout3 = nn.Dropout3d(dropout_rate)

        # Calculate flattened size
        self._to_linear_input_size = self._get_conv_output_size()

        # Dense layers with residual connection
        self.fc1 = nn.Linear(self._to_linear_input_size, dense_units)
        self.fc2 = nn.Linear(dense_units, dense_units)
        self.dropout_fc = nn.Dropout(dropout_rate)

    def _forward_features(self, x):
        # Standardized Conv-BN-ReLU order in each block, followed by pooling and dropout
        x = self.dropout1(self.pool1(F.relu(self.bn1(self.conv1(x)))))
        x = self.dropout2(self.pool2(F.relu(self.bn2(self.conv2(x)))))
        x = self.dropout3(self.pool3(F.relu(self.bn3(self.conv3(x)))))
        return x

    def _get_conv_output_size(self):
        # Size the first dense layer by pushing a dummy batch through the conv stack.
        # eval() keeps the dummy pass from updating BatchNorm running statistics.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.input_channels, *self.input_voxel_dim)
            output_size = self._forward_features(dummy_input).numel()
        self.train(was_training)
        return output_size

    def forward(self, x):
        x = self._forward_features(x)

        # Flatten and dense layers
        x = x.view(x.size(0), -1)
        x1 = F.relu(self.fc1(x))
        x1 = self.dropout_fc(x1)
        x2 = F.relu(self.fc2(x1))

        # Residual connection
        return x1 + x2
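

# Worked shape example with the default arguments (a sketch for orientation):
# an input of (1, 3, 32, 32, 32) keeps its spatial size through each padded
# 3x3x3 convolution and is halved by each 2x2x2 max-pool, 32 -> 16 -> 8 -> 4,
# so the conv stack emits (1, 128, 4, 4, 4) and the flattened size is
# 128 * 4 * 4 * 4 = 8192, which _get_conv_output_size derives dynamically.
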
class PositionEncoder(nn.Module):
    """
    Encodes start and goal positions with learned embeddings.
    """
    def __init__(self, voxel_dim=(32, 32, 32), position_embed_dim=64):
        super(PositionEncoder, self).__init__()
        self.voxel_dim = voxel_dim
        self.position_embed_dim = position_embed_dim

        # Split position_embed_dim across the three axes so the parts sum exactly
        dim_per_axis = position_embed_dim // 3
        remainder = position_embed_dim % 3

        x_dim = dim_per_axis + (1 if remainder > 0 else 0)
        y_dim = dim_per_axis + (1 if remainder > 1 else 0)
        z_dim = dim_per_axis

        # Learned position embeddings for each axis
        self.x_embed = nn.Embedding(voxel_dim[0], x_dim)
        self.y_embed = nn.Embedding(voxel_dim[1], y_dim)
        self.z_embed = nn.Embedding(voxel_dim[2], z_dim)

        # Additional processing - fixed input dimension (start + goal concatenated)
        self.fc = nn.Linear(2 * position_embed_dim, position_embed_dim)

    def forward(self, positions):
        """
        positions: (batch_size, 2, 3) - [start_pos, goal_pos] with (x, y, z)
        """
        batch_size = positions.size(0)

        # Extract coordinates, clamped defensively to valid index ranges
        # to avoid out-of-bounds embedding lookups
        x_coords = positions[:, :, 0].long().clamp_(0, self.voxel_dim[0] - 1)  # (batch_size, 2)
        y_coords = positions[:, :, 1].long().clamp_(0, self.voxel_dim[1] - 1)  # (batch_size, 2)
        z_coords = positions[:, :, 2].long().clamp_(0, self.voxel_dim[2] - 1)  # (batch_size, 2)

        # Get embeddings
        x_emb = self.x_embed(x_coords)  # (batch_size, 2, x_dim)
        y_emb = self.y_embed(y_coords)  # (batch_size, 2, y_dim)
        z_emb = self.z_embed(z_coords)  # (batch_size, 2, z_dim)

        # Concatenate embeddings
        pos_emb = torch.cat([x_emb, y_emb, z_emb], dim=-1)  # (batch_size, 2, position_embed_dim)

        # Flatten start and goal embeddings
        pos_emb = pos_emb.view(batch_size, -1)  # (batch_size, 2 * position_embed_dim)

        return F.relu(self.fc(pos_emb))
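

# Worked split example (sketch): with position_embed_dim=64, dim_per_axis is
# 64 // 3 = 21 with remainder 1, so x_dim=22, y_dim=21, z_dim=21, and the
# concatenated per-position embedding is 22 + 21 + 21 = 64. Start and goal
# embeddings are flattened together, matching the fc layer's 2 * 64 input.
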
class PathPlannerTransformer(nn.Module):
    """
    Transformer-based path planner that generates action sequences.
    Fixed token IDs to avoid collision:
    - Actions: 0-5 (Forward, Back, Left, Right, Up, Down)
    - START: 6
    - END: 7
    - PAD: 8 (used only for teacher-forcing inputs; targets still use -1 for ignore)
    """
    def __init__(self,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_heads=8,
                 num_layers=4,
                 max_sequence_length=100,
                 num_actions=6,  # Forward, Back, Left, Right, Up, Down
                 use_end_token=True):
        super(PathPlannerTransformer, self).__init__()

        self.hidden_dim = hidden_dim
        self.max_sequence_length = max_sequence_length
        self.num_actions = num_actions
        self.use_end_token = use_end_token

        # Fixed token IDs to avoid collision
        self.start_token_id = num_actions  # 6
        self.end_token_id = num_actions + 1 if use_end_token else None  # 7
        # Reserve a PAD token for embedding inputs during teacher forcing
        self.pad_token_id = (num_actions + 2) if use_end_token else (num_actions + 1)
        # Total tokens include PAD
        self.total_tokens = (num_actions + 3) if use_end_token else (num_actions + 2)

        # Feature fusion
        self.feature_fusion = nn.Linear(env_feature_dim + pos_feature_dim, hidden_dim)

        # Action embeddings
        self.action_embed = nn.Embedding(self.total_tokens, hidden_dim)

        # Positional encoding - registered as a buffer for proper device handling
        self.register_buffer('pos_encoding', self._create_positional_encoding(max_sequence_length, hidden_dim))

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, self.total_tokens)

        # Turn head (supervised via BCE-with-logits against turn labels)
        self.turn_penalty_head = nn.Linear(hidden_dim, 1)

    def _create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)
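
    # The buffer built above is the standard sinusoidal encoding (Vaswani et
    # al., 2017), written with exp/log for numerical convenience:
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
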
    def forward(self, env_features, pos_features, target_actions=None):
        """
        env_features: (batch_size, env_feature_dim)
        pos_features: (batch_size, pos_feature_dim)
        target_actions: (batch_size, seq_len) - for training (contains action IDs 0-5 and END token 7)
        """
        batch_size = env_features.size(0)

        # Fuse environment and position features
        fused_features = self.feature_fusion(torch.cat([env_features, pos_features], dim=1))

        # Create memory (encoder output) by repeating fused features
        memory = fused_features.unsqueeze(1).repeat(1, self.max_sequence_length, 1)

        if target_actions is not None:
            # Training mode: use teacher forcing
            seq_len = target_actions.size(1)

            # Create input sequence (START token + target_actions[:-1])
            start_tokens = torch.full((batch_size, 1), self.start_token_id,
                                      dtype=torch.long, device=target_actions.device)
            input_seq = torch.cat([start_tokens, target_actions[:, :-1]], dim=1)
            # Replace padding (-1) in teacher-forced inputs with the PAD token ID
            # to avoid out-of-bounds embedding lookups
            input_seq = torch.where(input_seq < 0, torch.full_like(input_seq, self.pad_token_id), input_seq)

            # Embed actions and add positional encoding
            embedded = self.action_embed(input_seq)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]

            # Generate attention mask (causal mask)
            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(embedded.device)

            # Transformer decoder forward pass
            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )

            # Output projections
            action_logits = self.output_proj(output)
            # Turn logits for supervised turn classification
            turn_logits = self.turn_penalty_head(output)

            return action_logits, turn_logits
        else:
            # Inference mode: generate the sequence autoregressively
            return self._generate_path(memory, batch_size)

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
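
    # Illustrative mask for sz=3 (a sketch): each row may attend only to itself
    # and to earlier positions; the -inf entries vanish in attention's softmax.
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
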
    def _generate_path(self, memory, batch_size):
        """
        Generate the path sequence autoregressively, handling batches correctly.
        Fixes bugs related to premature stopping and inclusion of special tokens.
        """
        device = memory.device

        # Start with the START token
        input_seq = torch.full((batch_size, 1), self.start_token_id, dtype=torch.long, device=device)

        # Keep track of sequences that have generated an END token
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(self.max_sequence_length):
            # Embed the current sequence
            embedded = self.action_embed(input_seq)
            seq_len = embedded.size(1)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]

            # Generate causal mask
            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)

            # Forward pass
            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )

            # Get next-action logits from the last token in the sequence
            next_action_logits = self.output_proj(output[:, -1, :])
            next_actions = torch.argmax(next_action_logits, dim=-1, keepdim=True)

            # Append the predicted actions to the sequence
            input_seq = torch.cat([input_seq, next_actions], dim=1)

            # Update the finished mask for any sequence that just produced an END token
            if self.use_end_token:
                is_finished |= (next_actions.squeeze(-1) == self.end_token_id)

            # If all sequences in the batch are finished, we can stop early
            if is_finished.all():
                break

        # Post-processing to create a clean, dense tensor of valid actions.
        # Remove the initial START token from all sequences.
        raw_paths = input_seq[:, 1:]

        clean_paths_list = []
        max_len = 0

        for i in range(batch_size):
            path = []
            for token_id in raw_paths[i]:
                # Stop decoding this path once an END token is found
                if self.use_end_token and token_id.item() == self.end_token_id:
                    break
                # Only include valid movement actions in the final path
                if token_id.item() < self.num_actions:
                    path.append(token_id.item())

            clean_paths_list.append(path)
            if len(path) > max_len:
                max_len = len(path)

        # Return an empty tensor if no valid actions were generated
        if max_len == 0:
            return torch.zeros(batch_size, 0, dtype=torch.long, device=device)

        # Pad all paths to the length of the longest path in the batch.
        # The END token ID is used for padding, since downstream functions like
        # check_collisions are designed to ignore non-action tokens.
        pad_value = self.end_token_id if self.use_end_token else self.num_actions
        padded_paths = torch.full((batch_size, max_len), pad_value, dtype=torch.long, device=device)

        for i, path in enumerate(clean_paths_list):
            if len(path) > 0:
                padded_paths[i, :len(path)] = torch.tensor(path, dtype=torch.long, device=device)

        return padded_paths
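

# Decoding note: generation above is greedy (argmax at each step). Swapping in
# sampling or beam search would only change how `next_actions` is selected.
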
class PathfindingNetwork(nn.Module):
    """
    Complete pathfinding network combining CNN encoder, position encoder, and transformer planner.
    """
    def __init__(self,
                 voxel_dim=(32, 32, 32),
                 input_channels=3,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_actions=6,
                 use_end_token=True):
        super(PathfindingNetwork, self).__init__()

        self.voxel_dim = voxel_dim
        self.num_actions = num_actions

        self.voxel_encoder = VoxelCNNEncoder(
            input_channels=input_channels,
            dense_units=env_feature_dim,
            input_voxel_dim=voxel_dim
        )

        self.position_encoder = PositionEncoder(
            voxel_dim=voxel_dim,
            position_embed_dim=pos_feature_dim
        )

        self.path_planner = PathPlannerTransformer(
            env_feature_dim=env_feature_dim,
            pos_feature_dim=pos_feature_dim,
            hidden_dim=hidden_dim,
            num_actions=num_actions,
            use_end_token=use_end_token
        )

    def forward(self, voxel_data, positions, target_actions=None):
        """
        voxel_data: (batch_size, 3, D, H, W) - [obstacles, start_mask, goal_mask]
        positions: (batch_size, 2, 3) - [start_pos, goal_pos]
        target_actions: (batch_size, seq_len) - for training
        """
        # Encode environment
        env_features = self.voxel_encoder(voxel_data)

        # Encode positions
        pos_features = self.position_encoder(positions)

        # Generate path
        if target_actions is not None:
            action_logits, turn_logits = self.path_planner(env_features, pos_features, target_actions)
            return action_logits, turn_logits
        else:
            generated_path = self.path_planner(env_features, pos_features)
            return generated_path

    def check_collisions(self, voxel_data, positions, actions):
        """
        Check whether actions lead to collisions with obstacles.

        voxel_data: (batch_size, 3, D, H, W)
        positions: (batch_size, 2, 3) - start positions
        actions: (batch_size, seq_len) - action sequences

        Returns: (batch_size, seq_len) collision mask
        """
        batch_size, seq_len = actions.shape
        device = actions.device

        # Extract obstacle channel
        obstacles = voxel_data[:, 0, :, :, :]  # (batch_size, D, H, W)

        # Action-to-direction mapping, labeled with the (x, y, z) axis order
        # used throughout this module
        directions = torch.tensor([
            [1, 0, 0],   # Forward (x+)
            [-1, 0, 0],  # Back (x-)
            [0, 1, 0],   # Left (y+)
            [0, -1, 0],  # Right (y-)
            [0, 0, 1],   # Up (z+)
            [0, 0, -1]   # Down (z-)
        ], dtype=torch.long, device=device)

        collision_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        current_pos = positions[:, 0, :].long().clone()  # Start from the start position

        for t in range(seq_len):
            # Get actions for this timestep
            action_t = actions[:, t]

            # Only process valid actions (0-5), skipping special tokens
            valid_actions = action_t < self.num_actions

            # Update positions based on actions
            for b in range(batch_size):
                if valid_actions[b]:
                    direction = directions[action_t[b]]
                    new_pos = current_pos[b] + direction

                    # Check bounds
                    if (new_pos >= 0).all() and (new_pos[0] < self.voxel_dim[0]) and \
                       (new_pos[1] < self.voxel_dim[1]) and (new_pos[2] < self.voxel_dim[2]):
                        # Check collision
                        if obstacles[b, new_pos[0], new_pos[1], new_pos[2]] > 0:
                            collision_mask[b, t] = True
                        else:
                            current_pos[b] = new_pos
                    else:
                        # Out of bounds counts as a collision
                        collision_mask[b, t] = True

        return collision_mask
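

# Worked collision example (sketch): from position (0, 0, 0), action 0
# (Forward, x+) proposes (1, 0, 0); if that voxel is occupied, the step is
# flagged and the agent stays put, while action 1 (Back, x-) would propose
# (-1, 0, 0) and be flagged as out of bounds.
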
class PathfindingLoss(nn.Module):
    """
    Custom loss function that balances path correctness and turn minimization.
    Properly handles special tokens (START=6, END=7) and action tokens (0-5).
    Turn loss is supervised: a turn occurs when consecutive valid actions differ.
    """
    def __init__(self, turn_penalty_weight=0.1, collision_penalty_weight=10.0,
                 num_actions=6, use_end_token=True):
        super(PathfindingLoss, self).__init__()
        self.turn_penalty_weight = turn_penalty_weight
        self.collision_penalty_weight = collision_penalty_weight
        self.num_actions = num_actions
        self.use_end_token = use_end_token
        self.start_token_id = num_actions  # 6
        self.end_token_id = num_actions + 1 if use_end_token else None  # 7
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore padding
        # BCE with logits for supervised turn prediction
        self.turn_bce = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, action_logits, turn_logits, target_actions, collision_mask=None):
        """
        action_logits: (batch_size, seq_len, total_tokens) - includes all tokens (0-8)
        turn_logits: (batch_size, seq_len, 1) - logits for the turn classifier
        target_actions: (batch_size, seq_len) - contains action IDs (0-5) and possibly END (7)
        collision_mask: (batch_size, seq_len) - 1 if collision, 0 if safe
        """
        batch_size, seq_len, total_tokens = action_logits.shape

        # Reshape for cross-entropy loss
        action_logits_flat = action_logits.view(-1, total_tokens)
        target_actions_flat = target_actions.view(-1)

        # Path correctness loss - properly handles all token IDs
        path_loss = self.ce_loss(action_logits_flat, target_actions_flat)

        # Supervised turn loss.
        # Compute a valid-action mask (excluding special tokens).
        valid_actions_mask = (target_actions < self.num_actions)
        # Previous actions (first timestep padded with itself; masked out below)
        prev_actions = torch.cat([target_actions[:, :1], target_actions[:, :-1]], dim=1)
        prev_valid_mask = torch.cat([torch.zeros_like(valid_actions_mask[:, :1], dtype=torch.bool),
                                     valid_actions_mask[:, :-1]], dim=1)
        # A turn occurs if both the current and previous actions are valid and they differ
        both_valid = valid_actions_mask & prev_valid_mask
        is_turn = ((target_actions != prev_actions) & both_valid).float()

        # Turn logits predicted by the model
        turn_logits = turn_logits.squeeze(-1)

        # Compute BCE-with-logits only over valid consecutive-action pairs
        num_pairs = both_valid.sum()
        if num_pairs > 0:
            bce_sum = self.turn_bce(turn_logits[both_valid], is_turn[both_valid])
            turn_loss = bce_sum / num_pairs.float()
        else:
            turn_loss = torch.tensor(0.0, device=action_logits.device)

        # Collision penalty - only applied to actual movement actions
        collision_loss = torch.tensor(0.0, device=action_logits.device)
        if collision_mask is not None:
            # Mask collisions so they only count for actual movement actions
            masked_collisions = collision_mask.float() * valid_actions_mask.float()
            if valid_actions_mask.sum() > 0:
                collision_loss = (masked_collisions.sum() / valid_actions_mask.sum()) * self.collision_penalty_weight

        total_loss = path_loss + self.turn_penalty_weight * turn_loss + collision_loss

        return {
            'total_loss': total_loss,
            'path_loss': path_loss,
            'turn_loss': turn_loss,
            'collision_loss': collision_loss
        }
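

# Combined objective, as implemented above (a summary, not new behavior):
#   total = path_CE + turn_penalty_weight * turn_BCE
#         + collision_penalty_weight * (colliding movement steps / movement steps)
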
# Utility functions for data preparation
def create_voxel_input(obstacles, start_pos, goal_pos, voxel_dim=(32, 32, 32)):
    """
    Create a multi-channel voxel input.

    obstacles: (D, H, W) binary array
    start_pos: (x, y, z) tuple
    goal_pos: (x, y, z) tuple
    """
    # Channel 0: obstacles
    obstacle_channel = obstacles.astype(np.float32)

    # Channel 1: start position
    start_channel = np.zeros(voxel_dim, dtype=np.float32)
    start_channel[start_pos] = 1.0

    # Channel 2: goal position
    goal_channel = np.zeros(voxel_dim, dtype=np.float32)
    goal_channel[goal_pos] = 1.0

    # Stack channels
    voxel_input = np.stack([obstacle_channel, start_channel, goal_channel], axis=0)

    return voxel_input
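

# Example usage (a sketch; `obstacles` here stands for any (32, 32, 32) binary
# numpy array):
#   voxels = create_voxel_input(obstacles, start_pos=(0, 0, 0), goal_pos=(31, 31, 31))
#   voxels.shape  # (3, 32, 32, 32); torch.from_numpy(voxels).unsqueeze(0) adds the batch dim
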
def prepare_training_targets(action_sequence, use_end_token=True, num_actions=6):
    """
    Prepare target action sequences for training.
    Ensures action IDs are in range [0, num_actions-1] and adds an END token if needed.

    action_sequence: list or tensor of action IDs (0-5)
    use_end_token: whether to append the END token
    num_actions: number of valid actions

    Returns: tensor with proper token IDs
    """
    if isinstance(action_sequence, list):
        action_sequence = torch.tensor(action_sequence)

    # Ensure actions are in the valid range
    assert (action_sequence >= 0).all() and (action_sequence < num_actions).all(), \
        f"Actions must be in range [0, {num_actions-1}]"

    if use_end_token:
        # Append the END token (ID = num_actions + 1 = 7)
        end_token = torch.tensor([num_actions + 1])
        target = torch.cat([action_sequence, end_token])
    else:
        target = action_sequence

    return target
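

# Example (sketch): prepare_training_targets([0, 0, 1]) returns
# tensor([0, 0, 1, 7]) -- three movement actions followed by END (ID 7).
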
# Example usage and testing
if __name__ == "__main__":
    # Define problem parameters
    voxel_dim = (32, 32, 32)
    batch_size = 4
    num_actions = 6  # Forward, Back, Left, Right, Up, Down

    # Create the complete pathfinding network
    pathfinding_net = PathfindingNetwork(
        voxel_dim=voxel_dim,
        input_channels=3,
        env_feature_dim=512,
        pos_feature_dim=64,
        hidden_dim=256,
        num_actions=num_actions,
        use_end_token=True
    )

    print("=== 3D Pathfinding Network Architecture ===")
    print(f"Total parameters: {sum(p.numel() for p in pathfinding_net.parameters()):,}")
    print("\nToken ID mapping:")
    print("  Actions: 0-5 (Forward, Back, Left, Right, Up, Down)")
    print(f"  START token: {pathfinding_net.path_planner.start_token_id}")
    print(f"  END token: {pathfinding_net.path_planner.end_token_id}")
    print(f"  PAD token: {pathfinding_net.path_planner.pad_token_id}")

    # Create dummy data
    dummy_voxel_data = torch.randn(batch_size, 3, *voxel_dim)
    dummy_positions = torch.randint(0, 32, (batch_size, 2, 3))  # start and goal positions

    # Create proper target actions with an END token
    dummy_actions = torch.randint(0, num_actions, (batch_size, 19))  # 19 movement actions
    dummy_target_actions = torch.cat([
        dummy_actions,
        torch.full((batch_size, 1), pathfinding_net.path_planner.end_token_id)
    ], dim=1)  # Add END token

    print("\n=== Testing Forward Pass ===")
    print(f"Input voxel shape: {dummy_voxel_data.shape}")
    print(f"Input positions shape: {dummy_positions.shape}")
    print(f"Target actions shape: {dummy_target_actions.shape}")
    print(f"Target action values range: [{dummy_target_actions.min().item()}, {dummy_target_actions.max().item()}]")

    # Training forward pass
    pathfinding_net.train()
    action_logits, turn_penalties = pathfinding_net(
        dummy_voxel_data,
        dummy_positions,
        dummy_target_actions
    )

    print("\nTraining mode outputs:")
    # total_tokens = 6 actions + START + END + PAD = 9
    print(f"Action logits shape: {action_logits.shape} (should be {(batch_size, 20, 9)})")
    print(f"Turn logits shape: {turn_penalties.shape}")

    # Inference forward pass
    pathfinding_net.eval()
    with torch.no_grad():
        generated_paths = pathfinding_net(dummy_voxel_data, dummy_positions)

    print("\nInference mode outputs:")
    print(f"Generated paths shape: {generated_paths.shape}")
    if generated_paths.shape[1] > 0:
        print(f"Generated action values range: [{generated_paths.min().item()}, {generated_paths.max().item()}]")

    # Test collision checking
    test_actions = generated_paths if generated_paths.shape[1] > 0 else dummy_actions
    collision_mask = pathfinding_net.check_collisions(
        dummy_voxel_data,
        dummy_positions,
        test_actions
    )
    print(f"Collision mask shape: {collision_mask.shape}")

    # Test the loss function with proper masking
    loss_fn = PathfindingLoss(
        turn_penalty_weight=0.1,
        num_actions=num_actions,
        use_end_token=True
    )

    # Adjust the collision mask to match the target sequence length
    if collision_mask.shape[1] >= 20:
        collision_mask_adjusted = collision_mask[:, :20]
    else:
        # Pad with zeros if the collision mask is shorter
        padding = torch.zeros(batch_size, 20 - collision_mask.shape[1],
                              dtype=torch.bool, device=collision_mask.device)
        collision_mask_adjusted = torch.cat([collision_mask, padding], dim=1)

    loss_dict = loss_fn(action_logits, turn_penalties, dummy_target_actions, collision_mask_adjusted)

    print("\n=== Loss Components ===")
    for key, value in loss_dict.items():
        print(f"{key}: {value.item():.4f}")

    # Verify that the loss properly masks special tokens
    print("\n=== Verification Tests ===")

    # Test 1: Verify token ID assignments
    print("1. Token IDs are correctly assigned:")
    print("   - Movement actions use IDs 0-5: ✓")
    print(f"   - START token uses ID {pathfinding_net.path_planner.start_token_id}: ✓")
    print(f"   - END token uses ID {pathfinding_net.path_planner.end_token_id}: ✓")

    # Test 2: Verify Conv-BN-ReLU order
    print("2. Conv-BN-ReLU order is standardized: ✓")

    # Test 3: Verify the supervised turn-label mask
    with torch.no_grad():
        # Create a sequence with mixed actions and an END token
        test_sequence = torch.tensor([[0, 1, 2, 3, 4, 5, 7]])  # Actions 0-5 then END
        valid_mask = (test_sequence < num_actions)
        prev_seq = torch.cat([test_sequence[:, :1], test_sequence[:, :-1]], dim=1)
        prev_valid = torch.cat([torch.zeros_like(valid_mask[:, :1], dtype=torch.bool), valid_mask[:, :-1]], dim=1)
        both_valid = valid_mask & prev_valid
        is_turn = ((test_sequence != prev_seq) & both_valid).float()
        print("3. Supervised turn labels test:")
        print(f"   - Test sequence: {test_sequence.tolist()}")
        print(f"   - Valid mask: {valid_mask.tolist()}")
        print(f"   - Both valid mask: {both_valid.tolist()}")
        print(f"   - Turn labels: {is_turn.tolist()}")

    # Test 4: Verify generation outputs only action IDs plus END-token padding
    print("4. Generated paths contain only valid action IDs (0-5) or END padding:")
    if generated_paths.shape[1] > 0:
        end_id = pathfinding_net.path_planner.end_token_id
        # _generate_path pads ragged batches with the END token ID, so both
        # movement actions and END are acceptable here
        contains_only_valid = ((generated_paths < num_actions) | (generated_paths == end_id)).all() and (generated_paths >= 0).all()
        print(f"   - Generated tokens in valid range: {'✓' if contains_only_valid else '✗'}")
    else:
        print("   - No actions generated (early END token)")

    print("\n=== Network Ready for Training ===")