# VoxelPathFinder / pathfinding_nn.py
# Uploaded by c1tr0n75.
# Commit ee6cc2b (verified): "adding pathfinding_nn.py alongside app.py"
# File size: 30.5 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Optional
class VoxelCNNEncoder(nn.Module):
    """
    3D CNN encoder for multi-channel voxel grids.

    The network applies three Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d ->
    Dropout3d stages, flattens the result and feeds it through two fully
    connected layers whose activations are summed (residual connection).

    Args:
        input_channels: Number of channels of the voxel grid (default 3:
            obstacles + start + goal).
        filters_1, filters_2, filters_3: Output channels of each conv stage.
        kernel_size_1, kernel_size_2, kernel_size_3: Conv kernel sizes; the
            padding is derived as ``(k - 1) // 2`` per axis.
        pool_size_1, pool_size_2, pool_size_3: Max-pooling window sizes.
        dense_units: Width of both fully connected layers and of the output.
        input_voxel_dim: Spatial size of the voxel grid fed to ``forward``.
        dropout_rate: Dropout probability used in every dropout layer.
    """

    def __init__(self,
                 input_channels=3,  # obstacles + start + goal
                 filters_1=32,
                 kernel_size_1=(3, 3, 3),
                 pool_size_1=(2, 2, 2),
                 filters_2=64,
                 kernel_size_2=(3, 3, 3),
                 pool_size_2=(2, 2, 2),
                 filters_3=128,
                 kernel_size_3=(3, 3, 3),
                 pool_size_3=(2, 2, 2),
                 dense_units=512,
                 input_voxel_dim=(32, 32, 32),
                 dropout_rate=0.2):
        super().__init__()
        self.input_voxel_dim = input_voxel_dim
        self.input_channels = input_channels

        # First 3D convolutional block (Conv-BN-ReLU)
        self.conv1 = nn.Conv3d(input_channels, filters_1, kernel_size_1,
                               padding=self._same_padding(kernel_size_1))
        self.bn1 = nn.BatchNorm3d(filters_1)
        self.pool1 = nn.MaxPool3d(pool_size_1)
        self.dropout1 = nn.Dropout3d(dropout_rate)

        # Second 3D convolutional block (Conv-BN-ReLU)
        self.conv2 = nn.Conv3d(filters_1, filters_2, kernel_size_2,
                               padding=self._same_padding(kernel_size_2))
        self.bn2 = nn.BatchNorm3d(filters_2)
        self.pool2 = nn.MaxPool3d(pool_size_2)
        self.dropout2 = nn.Dropout3d(dropout_rate)

        # Third 3D convolutional block (Conv-BN-ReLU)
        self.conv3 = nn.Conv3d(filters_2, filters_3, kernel_size_3,
                               padding=self._same_padding(kernel_size_3))
        self.bn3 = nn.BatchNorm3d(filters_3)
        self.pool3 = nn.MaxPool3d(pool_size_3)
        self.dropout3 = nn.Dropout3d(dropout_rate)

        # Size of the flattened conv output, needed to build the first FC layer
        self._to_linear_input_size = self._get_conv_output_size()

        # Dense layers with residual connection
        self.fc1 = nn.Linear(self._to_linear_input_size, dense_units)
        self.fc2 = nn.Linear(dense_units, dense_units)
        self.dropout_fc = nn.Dropout(dropout_rate)

    @staticmethod
    def _same_padding(kernel_size):
        """Return the per-axis padding ``(k - 1) // 2`` for a kernel size tuple."""
        return tuple((k - 1) // 2 for k in kernel_size)

    def _get_conv_output_size(self):
        """
        Return the number of features produced by the conv stack for one sample.

        Only the convolution and pooling layers are evaluated: batch norm, ReLU
        and dropout never change a tensor's shape, and running BatchNorm3d in
        training mode on a dummy tensor would overwrite its running statistics
        (and bump ``num_batches_tracked``) before any real data is seen.
        """
        with torch.no_grad():
            probe = torch.zeros(1, self.input_channels, *self.input_voxel_dim,
                                device=self.conv1.weight.device)
            for conv, pool in ((self.conv1, self.pool1),
                               (self.conv2, self.pool2),
                               (self.conv3, self.pool3)):
                probe = pool(conv(probe))
        return probe.numel()

    @staticmethod
    def _conv_block(x, conv, bn, pool, dropout):
        """Apply one Conv-BN-ReLU-Pool-Dropout stage."""
        return dropout(pool(F.relu(bn(conv(x)))))

    def forward(self, x):
        """
        Encode a voxel grid.

        x: (batch_size, input_channels, *input_voxel_dim)
        Returns: (batch_size, dense_units)
        """
        x = self._conv_block(x, self.conv1, self.bn1, self.pool1, self.dropout1)
        x = self._conv_block(x, self.conv2, self.bn2, self.pool2, self.dropout2)
        x = self._conv_block(x, self.conv3, self.bn3, self.pool3, self.dropout3)

        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)

        x1 = F.relu(self.fc1(x))
        x1 = self.dropout_fc(x1)
        x2 = F.relu(self.fc2(x1))

        # Residual connection
        return x1 + x2
class PositionEncoder(nn.Module):
    """
    Encodes start and goal positions with learned embeddings.

    Each coordinate axis has its own embedding table. The three per-axis
    embeddings are concatenated into a ``position_embed_dim`` vector per
    position, the start and goal vectors are concatenated, and a linear layer
    followed by ReLU maps them back to ``position_embed_dim`` features.

    Args:
        voxel_dim: Number of voxels along each axis (size of each table).
        position_embed_dim: Embedding width per position and output width.
    """

    def __init__(self, voxel_dim=(32, 32, 32), position_embed_dim=64):
        super().__init__()
        self.voxel_dim = voxel_dim
        self.position_embed_dim = position_embed_dim

        # Split position_embed_dim across the three axes; the remainder of the
        # integer division is handed to x first, then y, so the sum is exact.
        dim_per_axis, remainder = divmod(position_embed_dim, 3)
        x_dim = dim_per_axis + (1 if remainder > 0 else 0)
        y_dim = dim_per_axis + (1 if remainder > 1 else 0)
        z_dim = dim_per_axis

        # Learned position embeddings for each dimension
        self.x_embed = nn.Embedding(voxel_dim[0], x_dim)
        self.y_embed = nn.Embedding(voxel_dim[1], y_dim)
        self.z_embed = nn.Embedding(voxel_dim[2], z_dim)

        # Start and goal embeddings are concatenated, hence the factor 2
        self.fc = nn.Linear(2 * position_embed_dim, position_embed_dim)

    def forward(self, positions):
        """
        positions: (batch_size, 2, 3) - [start_pos, goal_pos] with (x, y, z)
        Returns: (batch_size, position_embed_dim)

        Coordinates are cast to integers and clamped to the valid index range
        of each axis so the embedding lookup can never go out of bounds.
        """
        batch_size = positions.size(0)

        # Out-of-place clamp: ``.long()`` returns the very same tensor when the
        # input is already int64, so an in-place ``clamp_`` would silently
        # overwrite the caller's ``positions``.
        x_coords = positions[:, :, 0].long().clamp(0, self.voxel_dim[0] - 1)  # (batch_size, 2)
        y_coords = positions[:, :, 1].long().clamp(0, self.voxel_dim[1] - 1)  # (batch_size, 2)
        z_coords = positions[:, :, 2].long().clamp(0, self.voxel_dim[2] - 1)  # (batch_size, 2)

        # Per-axis embeddings concatenated along the feature axis
        pos_emb = torch.cat(
            [self.x_embed(x_coords), self.y_embed(y_coords), self.z_embed(z_coords)],
            dim=-1,
        )  # (batch_size, 2, position_embed_dim)

        # Flatten start and goal embeddings into one vector per sample
        pos_emb = pos_emb.reshape(batch_size, -1)  # (batch_size, 2 * position_embed_dim)
        return F.relu(self.fc(pos_emb))
class PathPlannerTransformer(nn.Module):
    """
    Transformer-based path planner that generates action sequences.

    Token IDs (for the default ``num_actions=6`` and ``use_end_token=True``):
    - Actions: 0-5 (Forward, Back, Left, Right, Up, Down)
    - START: 6
    - END: 7
    - PAD: 8 (used only for teacher forcing inputs; targets still use -1 for ignore)

    Without an END token, PAD takes the ID ``num_actions + 1``.

    Args:
        env_feature_dim: Width of the environment feature vector.
        pos_feature_dim: Width of the position feature vector.
        hidden_dim: Transformer model width.
        num_heads: Number of attention heads.
        num_layers: Number of decoder layers.
        max_sequence_length: Longest sequence that can be trained on or generated.
        num_actions: Number of movement actions.
        use_end_token: Whether an END token is part of the vocabulary.
    """

    def __init__(self,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_heads=8,
                 num_layers=4,
                 max_sequence_length=100,
                 num_actions=6,  # Forward, Back, Left, Right, Up, Down
                 use_end_token=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_sequence_length = max_sequence_length
        self.num_actions = num_actions
        self.use_end_token = use_end_token

        # Fixed token IDs to avoid collision
        self.start_token_id = num_actions
        self.end_token_id = num_actions + 1 if use_end_token else None
        # PAD is always the last ID; it only ever appears as a decoder input
        self.pad_token_id = (num_actions + 2) if use_end_token else (num_actions + 1)
        self.total_tokens = self.pad_token_id + 1

        # Feature fusion
        self.feature_fusion = nn.Linear(env_feature_dim + pos_feature_dim, hidden_dim)

        # Action embeddings
        self.action_embed = nn.Embedding(self.total_tokens, hidden_dim)

        # Positional encoding - registered as buffer so it follows .to(device)
        self.register_buffer(
            'pos_encoding',
            self._create_positional_encoding(max_sequence_length, hidden_dim),
        )

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output projection over the full vocabulary
        self.output_proj = nn.Linear(hidden_dim, self.total_tokens)

        # Turn head (supervised via BCE-with-logits against turn labels)
        self.turn_penalty_head = nn.Linear(hidden_dim, 1)

    def _create_positional_encoding(self, max_len, d_model):
        """Return sinusoidal positional encodings of shape (1, max_len, d_model)."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, env_features, pos_features, target_actions=None):
        """
        env_features: (batch_size, env_feature_dim)
        pos_features: (batch_size, pos_feature_dim)
        target_actions: (batch_size, seq_len) - for training (action IDs, END
            token and -1 padding)

        Returns:
            With ``target_actions``: ``(action_logits, turn_logits)`` of shapes
            (batch_size, seq_len, total_tokens) and (batch_size, seq_len, 1).
            Without: the generated paths, see ``_generate_path``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_sequence_length``.
        """
        batch_size = env_features.size(0)

        # Fuse environment and position features
        fused_features = self.feature_fusion(torch.cat([env_features, pos_features], dim=1))

        # Memory is the fused vector repeated along the sequence axis; expand
        # yields the same values as repeat without copying them.
        memory = fused_features.unsqueeze(1).expand(-1, self.max_sequence_length, -1)

        if target_actions is None:
            # Inference mode: generate sequence autoregressively
            return self._generate_path(memory, batch_size)

        # Training mode: teacher forcing
        seq_len = target_actions.size(1)
        if seq_len > self.max_sequence_length:
            raise ValueError(
                f"target_actions has sequence length {seq_len}, which exceeds "
                f"max_sequence_length={self.max_sequence_length}"
            )

        # Decoder input is START followed by the targets shifted right by one
        start_tokens = torch.full((batch_size, 1), self.start_token_id,
                                  dtype=torch.long, device=target_actions.device)
        input_seq = torch.cat([start_tokens, target_actions[:, :-1]], dim=1)
        # Replace padding (-1) with the PAD token id to avoid OOB in embedding
        input_seq = torch.where(input_seq < 0,
                                torch.full_like(input_seq, self.pad_token_id),
                                input_seq)

        # Embed actions and add positional encoding
        embedded = self.action_embed(input_seq) + self.pos_encoding[:, :seq_len, :]

        # Causal mask so a position cannot attend to later positions
        tgt_mask = self._generate_square_subsequent_mask(seq_len, device=embedded.device)

        output = self.transformer_decoder(
            tgt=embedded,
            memory=memory[:, :seq_len, :],
            tgt_mask=tgt_mask,
        )

        action_logits = self.output_proj(output)
        # Turn logits for supervised turn classification
        turn_logits = self.turn_penalty_head(output)
        return action_logits, turn_logits

    def _generate_square_subsequent_mask(self, sz, device=None):
        """
        Return an (sz, sz) float mask with -inf above the diagonal and 0 elsewhere.

        The mask is created directly on ``device`` (CPU when omitted).
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def _generate_path(self, memory, batch_size):
        """
        Generate path sequences autoregressively with greedy (argmax) decoding.

        Decoding runs for at most ``max_sequence_length`` steps and stops early
        once every sequence in the batch has produced an END token.

        Returns:
            (batch_size, max_len) long tensor. Each row holds the movement
            actions produced before that row's END token; non-action tokens are
            dropped. Shorter rows are padded with the END token id (or
            ``num_actions`` when no END token is used). The tensor has zero
            columns when no row contains a movement action.
        """
        device = memory.device

        # Start with START token
        input_seq = torch.full((batch_size, 1), self.start_token_id,
                               dtype=torch.long, device=device)
        # Keep track of sequences that have generated an END token
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(self.max_sequence_length):
            seq_len = input_seq.size(1)
            embedded = self.action_embed(input_seq) + self.pos_encoding[:, :seq_len, :]
            tgt_mask = self._generate_square_subsequent_mask(seq_len, device=device)

            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask,
            )

            # Greedy choice from the last position of the sequence
            next_action_logits = self.output_proj(output[:, -1, :])
            next_actions = torch.argmax(next_action_logits, dim=-1, keepdim=True)
            input_seq = torch.cat([input_seq, next_actions], dim=1)

            if self.use_end_token:
                is_finished |= (next_actions.squeeze(-1) == self.end_token_id)
                # If all sequences in the batch are finished, stop early
                if is_finished.all():
                    break

        # Drop the START token and move the tokens to the host in one transfer
        # instead of synchronising once per token.
        raw_paths = input_seq[:, 1:].tolist()

        clean_paths = []
        for tokens in raw_paths:
            path = []
            for token in tokens:
                # Stop decoding for this path once an END token is found
                if self.use_end_token and token == self.end_token_id:
                    break
                # Only keep movement actions (skips START / PAD)
                if token < self.num_actions:
                    path.append(token)
            clean_paths.append(path)

        max_len = max((len(path) for path in clean_paths), default=0)

        # Return an empty tensor if no valid actions were generated
        if max_len == 0:
            return torch.zeros(batch_size, 0, dtype=torch.long, device=device)

        # Pad with a non-action token so consumers that skip ids >= num_actions
        # ignore the padding.
        pad_value = self.end_token_id if self.use_end_token else self.num_actions
        padded_paths = torch.full((batch_size, max_len), pad_value,
                                  dtype=torch.long, device=device)
        for row, path in enumerate(clean_paths):
            if path:
                padded_paths[row, :len(path)] = torch.tensor(path, dtype=torch.long,
                                                             device=device)
        return padded_paths
class PathfindingNetwork(nn.Module):
    """
    Complete pathfinding network combining CNN encoder, position encoder, and transformer planner.

    Args:
        voxel_dim: Spatial size of the voxel grid.
        input_channels: Number of voxel channels fed to the CNN encoder.
        env_feature_dim: Width of the environment feature vector.
        pos_feature_dim: Width of the position feature vector.
        hidden_dim: Transformer model width of the planner.
        num_actions: Number of movement actions.
        use_end_token: Whether the planner vocabulary contains an END token.
    """

    def __init__(self,
                 voxel_dim=(32, 32, 32),
                 input_channels=3,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_actions=6,
                 use_end_token=True):
        super().__init__()
        self.voxel_dim = voxel_dim
        self.num_actions = num_actions

        self.voxel_encoder = VoxelCNNEncoder(
            input_channels=input_channels,
            dense_units=env_feature_dim,
            input_voxel_dim=voxel_dim,
        )
        self.position_encoder = PositionEncoder(
            voxel_dim=voxel_dim,
            position_embed_dim=pos_feature_dim,
        )
        self.path_planner = PathPlannerTransformer(
            env_feature_dim=env_feature_dim,
            pos_feature_dim=pos_feature_dim,
            hidden_dim=hidden_dim,
            num_actions=num_actions,
            use_end_token=use_end_token,
        )

    def forward(self, voxel_data, positions, target_actions=None):
        """
        voxel_data: (batch_size, 3, D, H, W) - [obstacles, start_mask, goal_mask]
        positions: (batch_size, 2, 3) - [start_pos, goal_pos]
        target_actions: (batch_size, seq_len) - for training

        Returns:
            ``(action_logits, turn_logits)`` when ``target_actions`` is given,
            otherwise the generated path tensor.
        """
        env_features = self.voxel_encoder(voxel_data)
        pos_features = self.position_encoder(positions)
        # The planner switches between teacher forcing and generation based on
        # whether target_actions is None, and returns the matching result.
        return self.path_planner(env_features, pos_features, target_actions)

    def check_collisions(self, voxel_data, positions, actions):
        """
        Check if actions lead to collisions with obstacles.

        The walk starts at ``positions[:, 0, :]``. A step that would leave the
        grid or enter a voxel whose obstacle value is > 0 is flagged and does
        not move the walker; every other movement action updates the position.
        Tokens outside ``[0, num_actions)`` (START / END / PAD and negative
        padding such as -1) are skipped and never flagged.

        voxel_data: (batch_size, 3, D, H, W); channel 0 is the obstacle grid
        positions: (batch_size, 2, 3) - only the start position is used
        actions: (batch_size, seq_len) - action sequences
        Returns: (batch_size, seq_len) boolean collision mask
        """
        batch_size, seq_len = actions.shape
        device = actions.device
        actions = actions.long()

        # Extract obstacle channel
        obstacles = voxel_data[:, 0, :, :, :]  # (batch_size, D, H, W)

        # Row i is the offset that action i adds to the position vector.
        # NOTE(review): the original labels called component 0 "z", component 1
        # "x" and component 2 "y", which does not match the (x, y, z) layout
        # documented for positions elsewhere in this file - confirm the
        # intended axis convention.
        directions = torch.tensor([
            [1, 0, 0],    # 0: Forward (+1 on component 0)
            [-1, 0, 0],   # 1: Back    (-1 on component 0)
            [0, 1, 0],    # 2: Left    (+1 on component 1)
            [0, -1, 0],   # 3: Right   (-1 on component 1)
            [0, 0, 1],    # 4: Up      (+1 on component 2)
            [0, 0, -1],   # 5: Down    (-1 on component 2)
        ], dtype=torch.long, device=device)
        num_moves = min(self.num_actions, directions.size(0))

        grid_size = torch.tensor(self.voxel_dim, dtype=torch.long, device=device)
        batch_index = torch.arange(batch_size, device=device)

        collision_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        # Integer copy of the start positions; never written in place, so the
        # caller's tensor is left untouched.
        current_pos = positions[:, 0, :].to(device=device, dtype=torch.long)

        for t in range(seq_len):
            action_t = actions[:, t]
            # Only movement actions count; negative ids must be rejected
            # explicitly because they would otherwise index the table from the
            # end and be treated as real moves.
            is_move = (action_t >= 0) & (action_t < num_moves)

            # Clamp so the lookup is always legal; results for non-moves are
            # discarded through is_move below.
            step = directions[action_t.clamp(0, directions.size(0) - 1)]
            candidate = current_pos + step

            in_bounds = ((candidate >= 0) & (candidate < grid_size)).all(dim=1)
            # Clamp before indexing so out-of-bounds candidates cannot raise;
            # their result is ignored because in_bounds is False for them.
            probe = torch.min(candidate.clamp(min=0), grid_size - 1)
            blocked = obstacles[batch_index, probe[:, 0], probe[:, 1], probe[:, 2]] > 0

            # Out of bounds counts as a collision
            collision_mask[:, t] = is_move & (~in_bounds | blocked)

            moved = is_move & in_bounds & ~blocked
            current_pos = torch.where(moved.unsqueeze(1), candidate, current_pos)

        return collision_mask
class PathfindingLoss(nn.Module):
    """
    Custom loss function that balances path correctness and turn minimization.

    Components:
    - path loss: cross entropy over all tokens, ignoring targets equal to -1.
    - turn loss: BCE-with-logits on the turn head. A turn occurs when two
      consecutive targets are both movement actions and differ.
    - collision loss: fraction of movement-action positions flagged in
      ``collision_mask``, scaled by ``collision_penalty_weight``.

    Movement actions are the ids in ``[0, num_actions)``; special tokens
    (START, END) and the -1 padding are excluded from the turn and collision
    terms.

    Args:
        turn_penalty_weight: Weight of the turn loss in the total.
        collision_penalty_weight: Scale applied to the collision fraction.
        num_actions: Number of movement actions.
        use_end_token: Whether an END token id (``num_actions + 1``) exists.
    """

    def __init__(self, turn_penalty_weight=0.1, collision_penalty_weight=10.0,
                 num_actions=6, use_end_token=True):
        super().__init__()
        self.turn_penalty_weight = turn_penalty_weight
        self.collision_penalty_weight = collision_penalty_weight
        self.num_actions = num_actions
        self.use_end_token = use_end_token
        self.start_token_id = num_actions
        self.end_token_id = num_actions + 1 if use_end_token else None
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore padding
        # Summed BCE; forward() divides by the number of valid pairs itself
        self.turn_bce = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, action_logits, turn_penalties, target_actions, collision_mask=None):
        """
        action_logits: (batch_size, seq_len, total_tokens)
        turn_penalties: (batch_size, seq_len, 1) - interpreted as turn logits
        target_actions: (batch_size, seq_len) - action ids, special tokens, -1 padding
        collision_mask: (batch_size, seq_len) - 1 if collision, 0 if safe

        Returns:
            dict with 'total_loss', 'path_loss', 'turn_loss', 'collision_loss'.
        """
        total_tokens = action_logits.shape[-1]

        # Path correctness loss over every non-ignored position. reshape (not
        # view) so non-contiguous logits are accepted as well.
        path_loss = self.ce_loss(action_logits.reshape(-1, total_tokens),
                                 target_actions.reshape(-1))

        # Movement actions only: special tokens are >= num_actions and padding
        # is negative, so both bounds have to be checked.
        valid_actions_mask = (target_actions >= 0) & (target_actions < self.num_actions)

        # Previous target at each step; the first step has no predecessor and
        # is excluded through prev_valid_mask.
        prev_actions = torch.cat([target_actions[:, :1], target_actions[:, :-1]], dim=1)
        prev_valid_mask = torch.cat(
            [torch.zeros_like(valid_actions_mask[:, :1]), valid_actions_mask[:, :-1]],
            dim=1,
        )

        # A turn occurs if both current and previous are valid actions and they differ
        both_valid = valid_actions_mask & prev_valid_mask
        turn_logits = turn_penalties.squeeze(-1)
        is_turn = ((target_actions != prev_actions) & both_valid).to(turn_logits.dtype)

        # Mean BCE over valid pairs. With no valid pair the summed BCE over the
        # empty selection is 0 and the clamp prevents a division by zero.
        num_pairs = both_valid.sum().clamp_min(1).to(turn_logits.dtype)
        turn_loss = self.turn_bce(turn_logits[both_valid], is_turn[both_valid]) / num_pairs

        # Collision penalty - only counted on actual movement actions.
        # NOTE: this term is built from the masks alone, so it carries no
        # gradient with respect to the logits.
        collision_loss = torch.tensor(0.0, device=action_logits.device)
        if collision_mask is not None:
            num_valid = valid_actions_mask.sum()
            if num_valid > 0:
                masked_collisions = collision_mask.float() * valid_actions_mask.float()
                collision_loss = (masked_collisions.sum() / num_valid) * self.collision_penalty_weight

        total_loss = path_loss + self.turn_penalty_weight * turn_loss + collision_loss

        return {
            'total_loss': total_loss,
            'path_loss': path_loss,
            'turn_loss': turn_loss,
            'collision_loss': collision_loss,
        }
# Utility functions for data preparation
def create_voxel_input(obstacles, start_pos, goal_pos, voxel_dim=(32, 32, 32)):
    """
    Create multi-channel voxel input.

    obstacles: (D, H, W) binary array matching ``voxel_dim``
    start_pos: (x, y, z) coordinates (tuple, list or array)
    goal_pos: (x, y, z) coordinates (tuple, list or array)
    voxel_dim: shape of every channel

    Returns: float32 array of shape (3, *voxel_dim) holding the obstacle
    grid, a one-hot start mask and a one-hot goal mask.

    Raises:
        ValueError: if ``obstacles`` does not have shape ``voxel_dim`` or a
            position does not have one coordinate per axis.
        IndexError: if a position lies outside the grid.
    """
    grid_shape = tuple(voxel_dim)

    # Channel 0: obstacles
    obstacle_channel = np.asarray(obstacles, dtype=np.float32)
    if obstacle_channel.shape != grid_shape:
        raise ValueError(
            f"obstacles has shape {obstacle_channel.shape}, expected {grid_shape}"
        )

    def one_hot(position, label):
        # Convert to a tuple of ints: indexing with a list or array would be
        # NumPy fancy indexing and mark whole slabs instead of one voxel.
        index = tuple(int(c) for c in position)
        if len(index) != len(grid_shape):
            raise ValueError(
                f"{label} must have {len(grid_shape)} coordinates, got {len(index)}"
            )
        # Reject negatives too; NumPy would otherwise wrap them around silently.
        if any(c < 0 or c >= size for c, size in zip(index, grid_shape)):
            raise IndexError(f"{label} {index} is outside the voxel grid {grid_shape}")
        channel = np.zeros(grid_shape, dtype=np.float32)
        channel[index] = 1.0
        return channel

    # Channel 1: start position, channel 2: goal position
    start_channel = one_hot(start_pos, "start_pos")
    goal_channel = one_hot(goal_pos, "goal_pos")

    # Stack channels
    return np.stack([obstacle_channel, start_channel, goal_channel], axis=0)
def prepare_training_targets(action_sequence, use_end_token=True, num_actions=6):
    """
    Prepare target action sequences for training.

    Ensures action IDs are in range [0, num_actions-1] and adds END token if needed.

    action_sequence: sequence or tensor of action IDs
    use_end_token: whether to append END token (ID = num_actions + 1)
    num_actions: number of valid actions

    Returns: int64 tensor of token IDs on the same device as the input.

    Raises:
        ValueError: if any action ID lies outside [0, num_actions-1].
    """
    actions = torch.as_tensor(action_sequence)

    # Explicit check instead of assert so validation survives ``python -O``
    if bool((actions < 0).any()) or bool((actions >= num_actions).any()):
        raise ValueError(f"Actions must be in range [0, {num_actions-1}]")

    # Token ids are always int64. This also covers the empty sequence, which
    # torch would otherwise create as float32 and promote the result to float.
    actions = actions.to(torch.long)

    if not use_end_token:
        return actions

    # Build END with the input's device so torch.cat never mixes devices
    end_token = torch.tensor([num_actions + 1], dtype=torch.long, device=actions.device)
    return torch.cat([actions, end_token])
# Example usage and testing
def _run_smoke_test():
    """
    Build the network on random data and print shapes, losses and sanity checks.

    Exercises the training forward pass, autoregressive inference, collision
    checking and the loss function. Nothing is trained.
    """
    # Define problem parameters
    voxel_dim = (32, 32, 32)
    batch_size = 4
    num_actions = 6  # Forward, Back, Left, Right, Up, Down
    num_moves = 19  # movement actions per dummy target sequence
    target_len = num_moves + 1  # movement actions followed by one END token

    # Create the complete pathfinding network
    pathfinding_net = PathfindingNetwork(
        voxel_dim=voxel_dim,
        input_channels=3,
        env_feature_dim=512,
        pos_feature_dim=64,
        hidden_dim=256,
        num_actions=num_actions,
        use_end_token=True
    )
    planner = pathfinding_net.path_planner

    print("=== 3D Pathfinding Network Architecture ===")
    print(f"Total parameters: {sum(p.numel() for p in pathfinding_net.parameters()):,}")
    print("\nToken ID mapping:")
    print("  Actions: 0-5 (Forward, Back, Left, Right, Up, Down)")
    print(f"  START token: {planner.start_token_id}")
    print(f"  END token: {planner.end_token_id}")

    # Create dummy data; coordinates are drawn per axis so they stay inside
    # the grid even when voxel_dim is not cubic.
    dummy_voxel_data = torch.randn(batch_size, 3, *voxel_dim)
    dummy_positions = torch.stack(
        [torch.randint(0, size, (batch_size, 2)) for size in voxel_dim], dim=-1
    )  # start and goal positions

    # Create proper target actions with END token
    dummy_actions = torch.randint(0, num_actions, (batch_size, num_moves))
    dummy_target_actions = torch.cat([
        dummy_actions,
        torch.full((batch_size, 1), planner.end_token_id)
    ], dim=1)  # Add END token

    print("\n=== Testing Forward Pass ===")
    print(f"Input voxel shape: {dummy_voxel_data.shape}")
    print(f"Input positions shape: {dummy_positions.shape}")
    print(f"Target actions shape: {dummy_target_actions.shape}")
    print(f"Target action values range: [{dummy_target_actions.min().item()}, {dummy_target_actions.max().item()}]")

    # Training forward pass
    pathfinding_net.train()
    action_logits, turn_penalties = pathfinding_net(
        dummy_voxel_data,
        dummy_positions,
        dummy_target_actions
    )
    print("\nTraining mode outputs:")
    # The vocabulary includes the PAD token, so use total_tokens rather than a
    # hard-coded count.
    print(f"Action logits shape: {action_logits.shape} (should be {(batch_size, target_len, planner.total_tokens)})")
    print(f"Turn logits shape: {turn_penalties.shape}")

    # Inference forward pass
    pathfinding_net.eval()
    with torch.no_grad():
        generated_paths = pathfinding_net(dummy_voxel_data, dummy_positions)
    print("\nInference mode outputs:")
    print(f"Generated paths shape: {generated_paths.shape}")
    if generated_paths.shape[1] > 0:
        print(f"Generated action values range: [{generated_paths.min().item()}, {generated_paths.max().item()}]")

    # Test collision checking
    test_actions = generated_paths if generated_paths.shape[1] > 0 else dummy_actions
    collision_mask = pathfinding_net.check_collisions(
        dummy_voxel_data,
        dummy_positions,
        test_actions
    )
    print(f"Collision mask shape: {collision_mask.shape}")

    # Test loss function with proper masking
    loss_fn = PathfindingLoss(
        turn_penalty_weight=0.1,
        num_actions=num_actions,
        use_end_token=True
    )

    # Adjust collision mask to match target sequence length
    if collision_mask.shape[1] >= target_len:
        collision_mask_adjusted = collision_mask[:, :target_len]
    else:
        # Pad with zeros if collision mask is shorter
        padding = torch.zeros(batch_size, target_len - collision_mask.shape[1],
                              dtype=torch.bool, device=collision_mask.device)
        collision_mask_adjusted = torch.cat([collision_mask, padding], dim=1)

    loss_dict = loss_fn(action_logits, turn_penalties, dummy_target_actions, collision_mask_adjusted)
    print("\n=== Loss Components ===")
    for key, value in loss_dict.items():
        print(f"{key}: {value.item():.4f}")

    # Verify that the loss properly masks special tokens
    print("\n=== Verification Tests ===")

    def mark(ok):
        return '✓' if ok else '✗'

    # Test 1: Verify token ID assignments
    print("1. Token IDs are correctly assigned:")
    print(f"   - Movement actions use IDs 0-5: {mark(num_actions == 6)}")
    print(f"   - START token uses ID {planner.start_token_id}: {mark(planner.start_token_id == num_actions)}")
    print(f"   - END token uses ID {planner.end_token_id}: {mark(planner.end_token_id == num_actions + 1)}")

    # Test 2: Conv-BN-ReLU order (structural property of the encoder; not
    # something that can be measured from here)
    print("2. Conv-BN-ReLU order is standardized: ✓")

    # Test 3: Verify supervised turn labels mask
    with torch.no_grad():
        # Create a sequence with mixed actions and END token
        test_sequence = torch.tensor([[0, 1, 2, 3, 4, 5, 7]])  # Actions 0-5 then END
        valid_mask = (test_sequence < num_actions)
        prev_seq = torch.cat([test_sequence[:, :1], test_sequence[:, :-1]], dim=1)
        prev_valid = torch.cat([torch.zeros_like(valid_mask[:, :1], dtype=torch.bool), valid_mask[:, :-1]], dim=1)
        both_valid = valid_mask & prev_valid
        is_turn = ((test_sequence != prev_seq) & both_valid).float()
    print("3. Supervised turn labels test:")
    print(f"   - Test sequence: {test_sequence.tolist()}")
    print(f"   - Valid mask: {valid_mask.tolist()}")
    print(f"   - Both valid mask: {both_valid.tolist()}")
    print(f"   - Turn labels: {is_turn.tolist()}")

    # Test 4: Generated paths hold movement actions, with END used only to pad
    # shorter rows up to the longest path in the batch.
    print("4. Generated paths contain only valid action IDs (0-5), END-padded:")
    if generated_paths.shape[1] > 0:
        is_action = (generated_paths >= 0) & (generated_paths < num_actions)
        is_padding = generated_paths == planner.end_token_id
        contains_only_valid = bool((is_action | is_padding).all())
        print(f"   - Generated actions in valid range: {mark(contains_only_valid)}")
    else:
        print("   - No actions generated (early END token)")

    print("\n=== Network Ready for Training ===")


if __name__ == "__main__":
    _run_smoke_test()