import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class VoxelCNNEncoder(nn.Module):
    """
    3D CNN encoder for voxelized obstruction data with multi-channel input.
    Processes environment obstacles, a start-position mask, and a goal-position mask.
    """

    def __init__(self,
                 input_channels=3,
                 filters_1=32,
                 kernel_size_1=(3, 3, 3),
                 pool_size_1=(2, 2, 2),
                 filters_2=64,
                 kernel_size_2=(3, 3, 3),
                 pool_size_2=(2, 2, 2),
                 filters_3=128,
                 kernel_size_3=(3, 3, 3),
                 pool_size_3=(2, 2, 2),
                 dense_units=512,
                 input_voxel_dim=(32, 32, 32),
                 dropout_rate=0.2):
        super().__init__()

        self.input_voxel_dim = input_voxel_dim
        self.input_channels = input_channels

        # Block 1: conv -> BN -> ReLU -> pool -> dropout ("same" padding for odd kernels)
        padding_1 = tuple((k - 1) // 2 for k in kernel_size_1)
        self.conv1 = nn.Conv3d(input_channels, filters_1, kernel_size_1, padding=padding_1)
        self.bn1 = nn.BatchNorm3d(filters_1)
        self.pool1 = nn.MaxPool3d(pool_size_1)
        self.dropout1 = nn.Dropout3d(dropout_rate)

        # Block 2
        padding_2 = tuple((k - 1) // 2 for k in kernel_size_2)
        self.conv2 = nn.Conv3d(filters_1, filters_2, kernel_size_2, padding=padding_2)
        self.bn2 = nn.BatchNorm3d(filters_2)
        self.pool2 = nn.MaxPool3d(pool_size_2)
        self.dropout2 = nn.Dropout3d(dropout_rate)

        # Block 3
        padding_3 = tuple((k - 1) // 2 for k in kernel_size_3)
        self.conv3 = nn.Conv3d(filters_2, filters_3, kernel_size_3, padding=padding_3)
        self.bn3 = nn.BatchNorm3d(filters_3)
        self.pool3 = nn.MaxPool3d(pool_size_3)
        self.dropout3 = nn.Dropout3d(dropout_rate)

        # Infer the flattened feature size from a dummy forward pass
        self._to_linear_input_size = self._get_conv_output_size()

        self.fc1 = nn.Linear(self._to_linear_input_size, dense_units)
        self.fc2 = nn.Linear(dense_units, dense_units)
        self.dropout_fc = nn.Dropout(dropout_rate)

    def _get_conv_output_size(self):
        # Only the conv and pool layers change the spatial shape, so the dummy
        # pass skips BN/ReLU/dropout; this also avoids updating BatchNorm
        # running statistics during construction.
        with torch.no_grad():
            x = torch.zeros(1, self.input_channels, *self.input_voxel_dim)
            x = self.pool1(self.conv1(x))
            x = self.pool2(self.conv2(x))
            x = self.pool3(self.conv3(x))
            return x.numel()

    def forward(self, x):
        # Three conv blocks, each conv -> BN -> ReLU -> pool -> dropout
        x = self.dropout1(self.pool1(F.relu(self.bn1(self.conv1(x)))))
        x = self.dropout2(self.pool2(F.relu(self.bn2(self.conv2(x)))))
        x = self.dropout3(self.pool3(F.relu(self.bn3(self.conv3(x)))))

        # Flatten, then two dense layers combined with a residual-style sum
        x = x.view(x.size(0), -1)
        x1 = F.relu(self.fc1(x))
        x1 = self.dropout_fc(x1)
        x2 = F.relu(self.fc2(x1))
        return x1 + x2
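

# A minimal sketch (not part of the encoder itself) of the shape arithmetic the
# dummy pass in _get_conv_output_size performs: with "same" padding each conv
# keeps the spatial size and each MaxPool3d(2) halves it, so a (32, 32, 32)
# input flattens to 128 * (32 // 2 // 2 // 2) ** 3 = 128 * 64 = 8192 features.
def expected_flat_size(voxel_dim=(32, 32, 32), filters_3=128, num_pools=3):
    """Illustrative helper: analytic feature size for the default encoder."""
    d, h, w = (s // (2 ** num_pools) for s in voxel_dim)
    return filters_3 * d * h * w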


class PositionEncoder(nn.Module):
    """
    Encodes start and goal positions with learned per-axis embeddings.
    """

    def __init__(self, voxel_dim=(32, 32, 32), position_embed_dim=64):
        super().__init__()
        self.voxel_dim = voxel_dim
        self.position_embed_dim = position_embed_dim

        # Split the embedding budget across the three axes, giving any
        # remainder dimensions to x first, then y
        dim_per_axis = position_embed_dim // 3
        remainder = position_embed_dim % 3
        x_dim = dim_per_axis + (1 if remainder > 0 else 0)
        y_dim = dim_per_axis + (1 if remainder > 1 else 0)
        z_dim = dim_per_axis

        self.x_embed = nn.Embedding(voxel_dim[0], x_dim)
        self.y_embed = nn.Embedding(voxel_dim[1], y_dim)
        self.z_embed = nn.Embedding(voxel_dim[2], z_dim)

        # Two positions (start, goal), each position_embed_dim wide
        self.fc = nn.Linear(2 * position_embed_dim, position_embed_dim)

    def forward(self, positions):
        """
        positions: (batch_size, 2, 3) - [start_pos, goal_pos], each (x, y, z)
        """
        batch_size = positions.size(0)

        # Clamp coordinates into the grid so out-of-range inputs cannot
        # index past the embedding tables
        x_coords = positions[:, :, 0].long().clamp_(0, self.voxel_dim[0] - 1)
        y_coords = positions[:, :, 1].long().clamp_(0, self.voxel_dim[1] - 1)
        z_coords = positions[:, :, 2].long().clamp_(0, self.voxel_dim[2] - 1)

        x_emb = self.x_embed(x_coords)
        y_emb = self.y_embed(y_coords)
        z_emb = self.z_embed(z_coords)

        # (batch, 2, position_embed_dim) -> (batch, 2 * position_embed_dim)
        pos_emb = torch.cat([x_emb, y_emb, z_emb], dim=-1)
        pos_emb = pos_emb.view(batch_size, -1)
        return F.relu(self.fc(pos_emb))
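

# Worked example of the axis split above (default position_embed_dim=64):
# 64 // 3 = 21 with remainder 1, so (x_dim, y_dim, z_dim) = (22, 21, 21),
# which sums back to 64; the start and goal embeddings are then concatenated
# into a 128-dim vector before the final linear layer.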


class PathPlannerTransformer(nn.Module):
    """
    Transformer-based path planner that generates action sequences.
    Token IDs are assigned without collision:
    - Actions: 0-5 (Forward, Back, Left, Right, Up, Down)
    - START: 6
    - END: 7
    - PAD: 8 (used only for teacher-forcing inputs; targets still use -1 for ignore)
    """

    def __init__(self,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_heads=8,
                 num_layers=4,
                 max_sequence_length=100,
                 num_actions=6,
                 use_end_token=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_sequence_length = max_sequence_length
        self.num_actions = num_actions
        self.use_end_token = use_end_token

        # Special-token IDs follow the action IDs: START, END (optional), PAD
        self.start_token_id = num_actions
        self.end_token_id = num_actions + 1 if use_end_token else None
        self.pad_token_id = (num_actions + 2) if use_end_token else (num_actions + 1)
        self.total_tokens = (num_actions + 3) if use_end_token else (num_actions + 2)

        # Fuse environment and position features into the decoder memory
        self.feature_fusion = nn.Linear(env_feature_dim + pos_feature_dim, hidden_dim)

        self.action_embed = nn.Embedding(self.total_tokens, hidden_dim)

        # Fixed sinusoidal positional encoding (non-trainable buffer)
        self.register_buffer('pos_encoding', self._create_positional_encoding(max_sequence_length, hidden_dim))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        self.output_proj = nn.Linear(hidden_dim, self.total_tokens)

        # Auxiliary head: per-step turn logits for the supervised turn loss
        self.turn_penalty_head = nn.Linear(hidden_dim, 1)

    def _create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)
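
    # The buffer above implements the standard sinusoidal scheme:
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # so nearby sequence positions receive similar, smoothly varying codes.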

    def forward(self, env_features, pos_features, target_actions=None):
        """
        env_features: (batch_size, env_feature_dim)
        pos_features: (batch_size, pos_feature_dim)
        target_actions: (batch_size, seq_len) - for training (action IDs 0-5 plus END token 7)
        """
        batch_size = env_features.size(0)

        fused_features = self.feature_fusion(torch.cat([env_features, pos_features], dim=1))

        # Broadcast the fused context across the sequence as decoder memory
        memory = fused_features.unsqueeze(1).repeat(1, self.max_sequence_length, 1)

        if target_actions is not None:
            # Teacher forcing: shift targets right and prepend START
            seq_len = target_actions.size(1)
            start_tokens = torch.full((batch_size, 1), self.start_token_id,
                                      dtype=torch.long, device=target_actions.device)
            input_seq = torch.cat([start_tokens, target_actions[:, :-1]], dim=1)

            # Replace ignore-index (-1) positions with PAD so they can be embedded
            input_seq = torch.where(input_seq < 0, torch.full_like(input_seq, self.pad_token_id), input_seq)

            embedded = self.action_embed(input_seq)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]

            # Causal mask so each step only attends to earlier steps
            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(embedded.device)

            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )

            action_logits = self.output_proj(output)
            turn_logits = self.turn_penalty_head(output)
            return action_logits, turn_logits
        else:
            # Inference: decode autoregressively
            return self._generate_path(memory, batch_size)

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
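
    # For sz = 3 the mask above is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # i.e. position t may only attend to positions <= t.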

    def _generate_path(self, memory, batch_size):
        """
        Generate a path autoregressively for a whole batch: decoding continues
        until every sequence has emitted END (or max length is reached), and
        special tokens are stripped from the returned paths.
        """
        device = memory.device

        # Every sequence starts with the START token
        input_seq = torch.full((batch_size, 1), self.start_token_id, dtype=torch.long, device=device)

        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(self.max_sequence_length):
            embedded = self.action_embed(input_seq)
            seq_len = embedded.size(1)
            embedded = embedded + self.pos_encoding[:, :seq_len, :]

            tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)

            output = self.transformer_decoder(
                tgt=embedded,
                memory=memory[:, :seq_len, :],
                tgt_mask=tgt_mask
            )

            # Greedy decoding: take the most likely next token
            next_action_logits = self.output_proj(output[:, -1, :])
            next_actions = torch.argmax(next_action_logits, dim=-1, keepdim=True)

            input_seq = torch.cat([input_seq, next_actions], dim=1)

            if self.use_end_token:
                is_finished |= (next_actions.squeeze(-1) == self.end_token_id)

            if is_finished.all():
                break

        # Strip the leading START token, then truncate each path at its END
        # token and drop any stray special tokens
        raw_paths = input_seq[:, 1:]

        clean_paths_list = []
        max_len = 0
        for i in range(batch_size):
            path = []
            for token_id in raw_paths[i]:
                if self.use_end_token and token_id.item() == self.end_token_id:
                    break
                if token_id.item() < self.num_actions:
                    path.append(token_id.item())
            clean_paths_list.append(path)
            max_len = max(max_len, len(path))

        if max_len == 0:
            return torch.zeros(batch_size, 0, dtype=torch.long, device=device)

        # Pad ragged paths; END (or START, when END is disabled) marks padding
        pad_value = self.end_token_id if self.use_end_token else self.num_actions
        padded_paths = torch.full((batch_size, max_len), pad_value, dtype=torch.long, device=device)
        for i, path in enumerate(clean_paths_list):
            if len(path) > 0:
                padded_paths[i, :len(path)] = torch.tensor(path, dtype=torch.long, device=device)

        return padded_paths
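
    # A minimal sketch (an assumption, not part of the class) of swapping the
    # greedy argmax in _generate_path for temperature sampling:
    #   probs = F.softmax(next_action_logits / temperature, dim=-1)
    #   next_actions = torch.multinomial(probs, num_samples=1)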


class PathfindingNetwork(nn.Module):
    """
    Complete pathfinding network combining the CNN encoder, position encoder,
    and transformer planner.
    """

    def __init__(self,
                 voxel_dim=(32, 32, 32),
                 input_channels=3,
                 env_feature_dim=512,
                 pos_feature_dim=64,
                 hidden_dim=256,
                 num_actions=6,
                 use_end_token=True):
        super().__init__()

        self.voxel_dim = voxel_dim
        self.num_actions = num_actions

        self.voxel_encoder = VoxelCNNEncoder(
            input_channels=input_channels,
            dense_units=env_feature_dim,
            input_voxel_dim=voxel_dim
        )

        self.position_encoder = PositionEncoder(
            voxel_dim=voxel_dim,
            position_embed_dim=pos_feature_dim
        )

        self.path_planner = PathPlannerTransformer(
            env_feature_dim=env_feature_dim,
            pos_feature_dim=pos_feature_dim,
            hidden_dim=hidden_dim,
            num_actions=num_actions,
            use_end_token=use_end_token
        )

    def forward(self, voxel_data, positions, target_actions=None):
        """
        voxel_data: (batch_size, 3, D, H, W) - [obstacles, start_mask, goal_mask]
        positions: (batch_size, 2, 3) - [start_pos, goal_pos]
        target_actions: (batch_size, seq_len) - for training
        """
        env_features = self.voxel_encoder(voxel_data)
        pos_features = self.position_encoder(positions)

        if target_actions is not None:
            action_logits, turn_penalties = self.path_planner(env_features, pos_features, target_actions)
            return action_logits, turn_penalties
        else:
            generated_path = self.path_planner(env_features, pos_features)
            return generated_path

    def check_collisions(self, voxel_data, positions, actions):
        """
        Check whether actions lead to collisions with obstacles.

        voxel_data: (batch_size, 3, D, H, W)
        positions: (batch_size, 2, 3) - start positions
        actions: (batch_size, seq_len) - action sequences

        Returns: (batch_size, seq_len) collision mask
        """
        batch_size, seq_len = actions.shape
        device = actions.device

        obstacles = voxel_data[:, 0, :, :, :]

        # Unit step per action: +x, -x, +y, -y, +z, -z
        directions = torch.tensor([
            [1, 0, 0],
            [-1, 0, 0],
            [0, 1, 0],
            [0, -1, 0],
            [0, 0, 1],
            [0, 0, -1]
        ], dtype=torch.long, device=device)

        collision_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        current_pos = positions[:, 0, :].clone()

        for t in range(seq_len):
            action_t = actions[:, t]

            # Special tokens (e.g. END padding) are not moves and cannot collide
            valid_actions = (action_t >= 0) & (action_t < self.num_actions)

            for b in range(batch_size):
                if valid_actions[b]:
                    direction = directions[action_t[b]]
                    new_pos = current_pos[b] + direction

                    if (new_pos >= 0).all() and (new_pos[0] < self.voxel_dim[0]) and \
                       (new_pos[1] < self.voxel_dim[1]) and (new_pos[2] < self.voxel_dim[2]):
                        if obstacles[b, new_pos[0], new_pos[1], new_pos[2]] > 0:
                            # Hitting an obstacle: flag it and stay in place
                            collision_mask[b, t] = True
                        else:
                            current_pos[b] = new_pos
                    else:
                        # Leaving the grid counts as a collision
                        collision_mask[b, t] = True

        return collision_mask
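

# A minimal sketch (an assumption, not used by PathfindingNetwork) of how one
# timestep of check_collisions could be vectorized across the batch; it assumes
# the same action encoding and directions table as the method above.
def _vectorized_collision_step(obstacles, current_pos, action_t, directions, num_actions=6):
    """
    obstacles: (B, D, H, W) float tensor; current_pos: (B, 3) long, in-bounds;
    action_t: (B,) long; directions: (num_actions, 3) long.
    Returns (new_pos, collided) for this single step.
    """
    valid = (action_t >= 0) & (action_t < num_actions)
    deltas = directions[action_t.clamp(min=0, max=num_actions - 1)]  # (B, 3)
    proposed = current_pos + deltas
    bounds = torch.tensor(obstacles.shape[1:], device=obstacles.device)
    in_bounds = ((proposed >= 0) & (proposed < bounds)).all(dim=-1)
    # Index only in-bounds cells; out-of-bounds rows fall back to current_pos
    # and are flagged via ~in_bounds below
    safe = torch.where(in_bounds.unsqueeze(-1), proposed, current_pos)
    hits = obstacles[torch.arange(obstacles.size(0), device=obstacles.device),
                     safe[:, 0], safe[:, 1], safe[:, 2]] > 0
    collided = valid & (~in_bounds | hits)
    moved = valid & in_bounds & ~hits
    new_pos = torch.where(moved.unsqueeze(-1), proposed, current_pos)
    return new_pos, collided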


class PathfindingLoss(nn.Module):
    """
    Custom loss that balances path correctness and turn minimization.
    Handles special tokens (START=6, END=7) alongside action tokens (0-5).
    The turn loss is supervised: a turn occurs when consecutive valid actions differ.
    """

    def __init__(self, turn_penalty_weight=0.1, collision_penalty_weight=10.0,
                 num_actions=6, use_end_token=True):
        super().__init__()
        self.turn_penalty_weight = turn_penalty_weight
        self.collision_penalty_weight = collision_penalty_weight
        self.num_actions = num_actions
        self.use_end_token = use_end_token
        self.start_token_id = num_actions
        self.end_token_id = num_actions + 1 if use_end_token else None
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-1)

        # Summed BCE, normalized later by the number of valid action pairs
        self.turn_bce = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, action_logits, turn_penalties, target_actions, collision_mask=None):
        """
        action_logits: (batch_size, seq_len, total_tokens)
        turn_penalties: (batch_size, seq_len, 1) - interpreted as turn logits
        target_actions: (batch_size, seq_len) - action IDs (0-5), possibly END (7), -1 for ignore
        collision_mask: (batch_size, seq_len) - 1 if collision, 0 if safe
        """
        batch_size, seq_len, total_tokens = action_logits.shape

        # Cross-entropy over all tokens; -1 targets are ignored
        action_logits_flat = action_logits.view(-1, total_tokens)
        target_actions_flat = target_actions.view(-1)
        path_loss = self.ce_loss(action_logits_flat, target_actions_flat)

        # A position holds a valid movement action only if its target is a
        # real action ID; -1 ignore positions and special tokens are excluded
        valid_actions_mask = (target_actions >= 0) & (target_actions < self.num_actions)

        # Supervised turn labels: compare each action with its predecessor.
        # Position 0 has no predecessor, so its pair is masked out.
        prev_actions = torch.cat([target_actions[:, :1], target_actions[:, :-1]], dim=1)
        prev_valid_mask = torch.cat([torch.zeros_like(valid_actions_mask[:, :1], dtype=torch.bool),
                                     valid_actions_mask[:, :-1]], dim=1)
        both_valid = valid_actions_mask & prev_valid_mask
        is_turn = ((target_actions != prev_actions) & both_valid).float()

        turn_logits = turn_penalties.squeeze(-1)

        # Normalize the summed BCE by the number of valid action pairs
        num_pairs = both_valid.sum()
        if num_pairs > 0:
            bce_sum = self.turn_bce(turn_logits[both_valid], is_turn[both_valid])
            turn_loss = bce_sum / num_pairs.float()
        else:
            turn_loss = torch.tensor(0.0, device=action_logits.device)

        # Penalize collisions only at positions holding valid actions
        collision_loss = torch.tensor(0.0, device=action_logits.device)
        if collision_mask is not None:
            masked_collisions = collision_mask.float() * valid_actions_mask.float()
            if valid_actions_mask.sum() > 0:
                collision_loss = (masked_collisions.sum() / valid_actions_mask.sum()) * self.collision_penalty_weight

        total_loss = path_loss + self.turn_penalty_weight * turn_loss + collision_loss

        return {
            'total_loss': total_loss,
            'path_loss': path_loss,
            'turn_loss': turn_loss,
            'collision_loss': collision_loss
        }


def create_voxel_input(obstacles, start_pos, goal_pos, voxel_dim=(32, 32, 32)):
    """
    Create a multi-channel voxel input.

    obstacles: (D, H, W) binary array
    start_pos: (x, y, z) tuple
    goal_pos: (x, y, z) tuple
    """
    obstacle_channel = obstacles.astype(np.float32)

    # One-hot masks marking the start and goal cells
    start_channel = np.zeros(voxel_dim, dtype=np.float32)
    start_channel[start_pos] = 1.0

    goal_channel = np.zeros(voxel_dim, dtype=np.float32)
    goal_channel[goal_pos] = 1.0

    # Stack into (3, D, H, W): [obstacles, start_mask, goal_mask]
    voxel_input = np.stack([obstacle_channel, start_channel, goal_channel], axis=0)
    return voxel_input
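

# Example usage (hypothetical values): a single obstacle block in an otherwise
# free 32^3 grid, start at one corner and goal at the opposite corner.
#   obstacles = np.zeros((32, 32, 32), dtype=np.float32)
#   obstacles[10:14, 10:14, 10:14] = 1.0
#   voxels = create_voxel_input(obstacles, (0, 0, 0), (31, 31, 31))
#   voxels.shape  # (3, 32, 32, 32)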


def prepare_training_targets(action_sequence, use_end_token=True, num_actions=6):
    """
    Prepare target action sequences for training.
    Ensures action IDs are in range [0, num_actions-1] and appends the END token if needed.

    action_sequence: list or tensor of action IDs (0-5)
    use_end_token: whether to append the END token
    num_actions: number of valid actions

    Returns: tensor with proper token IDs
    """
    if isinstance(action_sequence, list):
        action_sequence = torch.tensor(action_sequence)

    assert (action_sequence >= 0).all() and (action_sequence < num_actions).all(), \
        f"Actions must be in range [0, {num_actions - 1}]"

    if use_end_token:
        # END token ID is num_actions + 1 (START occupies num_actions)
        end_token = torch.tensor([num_actions + 1])
        target = torch.cat([action_sequence, end_token])
    else:
        target = action_sequence

    return target
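

# Example: prepare_training_targets([0, 0, 2]) returns tensor([0, 0, 2, 7]),
# i.e. the END token (num_actions + 1 = 7) is appended after the actions.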


if __name__ == "__main__":
    voxel_dim = (32, 32, 32)
    batch_size = 4
    num_actions = 6

    pathfinding_net = PathfindingNetwork(
        voxel_dim=voxel_dim,
        input_channels=3,
        env_feature_dim=512,
        pos_feature_dim=64,
        hidden_dim=256,
        num_actions=num_actions,
        use_end_token=True
    )

    print("=== 3D Pathfinding Network Architecture ===")
    print(f"Total parameters: {sum(p.numel() for p in pathfinding_net.parameters()):,}")
    print(f"\nToken ID mapping:")
    print(f"  Actions: 0-5 (Forward, Back, Left, Right, Up, Down)")
    print(f"  START token: {pathfinding_net.path_planner.start_token_id}")
    print(f"  END token: {pathfinding_net.path_planner.end_token_id}")

    # Dummy batch: random values stand in for occupancy grids and positions
    dummy_voxel_data = torch.randn(batch_size, 3, *voxel_dim)
    dummy_positions = torch.randint(0, 32, (batch_size, 2, 3))

    # 19 random actions plus the END token -> target sequence length 20
    dummy_actions = torch.randint(0, num_actions, (batch_size, 19))
    dummy_target_actions = torch.cat([
        dummy_actions,
        torch.full((batch_size, 1), pathfinding_net.path_planner.end_token_id)
    ], dim=1)

    print(f"\n=== Testing Forward Pass ===")
    print(f"Input voxel shape: {dummy_voxel_data.shape}")
    print(f"Input positions shape: {dummy_positions.shape}")
    print(f"Target actions shape: {dummy_target_actions.shape}")
    print(f"Target action values range: [{dummy_target_actions.min().item()}, {dummy_target_actions.max().item()}]")

    # Training mode: teacher forcing returns logits
    pathfinding_net.train()
    action_logits, turn_penalties = pathfinding_net(
        dummy_voxel_data,
        dummy_positions,
        dummy_target_actions
    )

    print(f"\nTraining mode outputs:")
    print(f"Action logits shape: {action_logits.shape} "
          f"(should be {(batch_size, 20, pathfinding_net.path_planner.total_tokens)})")
    print(f"Turn logits shape: {turn_penalties.shape}")

    # Inference mode: autoregressive generation
    pathfinding_net.eval()
    with torch.no_grad():
        generated_paths = pathfinding_net(dummy_voxel_data, dummy_positions)

    print(f"\nInference mode outputs:")
    print(f"Generated paths shape: {generated_paths.shape}")
    if generated_paths.shape[1] > 0:
        print(f"Generated action values range: [{generated_paths.min().item()}, {generated_paths.max().item()}]")

    # Collision checking on the generated (or fallback dummy) actions
    test_actions = generated_paths if generated_paths.shape[1] > 0 else dummy_actions
    collision_mask = pathfinding_net.check_collisions(
        dummy_voxel_data,
        dummy_positions,
        test_actions
    )
    print(f"Collision mask shape: {collision_mask.shape}")

    loss_fn = PathfindingLoss(
        turn_penalty_weight=0.1,
        num_actions=num_actions,
        use_end_token=True
    )

    # Pad or truncate the collision mask to the target sequence length
    target_len = dummy_target_actions.shape[1]
    if collision_mask.shape[1] >= target_len:
        collision_mask_adjusted = collision_mask[:, :target_len]
    else:
        padding = torch.zeros(batch_size, target_len - collision_mask.shape[1],
                              dtype=torch.bool, device=collision_mask.device)
        collision_mask_adjusted = torch.cat([collision_mask, padding], dim=1)

    loss_dict = loss_fn(action_logits, turn_penalties, dummy_target_actions, collision_mask_adjusted)

    print(f"\n=== Loss Components ===")
    for key, value in loss_dict.items():
        print(f"{key}: {value.item():.4f}")

    print(f"\n=== Verification Tests ===")

    print(f"1. Token IDs are correctly assigned:")
    print(f"   - Movement actions use IDs 0-5: ✓")
    print(f"   - START token uses ID {pathfinding_net.path_planner.start_token_id}: ✓")
    print(f"   - END token uses ID {pathfinding_net.path_planner.end_token_id}: ✓")

    print(f"2. Conv-BN-ReLU order is standardized: ✓")

    # Supervised turn labels on a hand-built sequence (END token at the end)
    with torch.no_grad():
        test_sequence = torch.tensor([[0, 1, 2, 3, 4, 5, 7]])
        valid_mask = (test_sequence >= 0) & (test_sequence < num_actions)
        prev_seq = torch.cat([test_sequence[:, :1], test_sequence[:, :-1]], dim=1)
        prev_valid = torch.cat([torch.zeros_like(valid_mask[:, :1], dtype=torch.bool), valid_mask[:, :-1]], dim=1)
        both_valid = valid_mask & prev_valid
        is_turn = ((test_sequence != prev_seq) & both_valid).float()
        print(f"3. Supervised turn labels test:")
        print(f"   - Test sequence: {test_sequence.tolist()}")
        print(f"   - Valid mask: {valid_mask.tolist()}")
        print(f"   - Both valid mask: {both_valid.tolist()}")
        print(f"   - Turn labels: {is_turn.tolist()}")

    # Generated paths may be padded with the END token after their true length
    print(f"4. Generated paths contain only valid action IDs (0-5) up to END padding:")
    if generated_paths.shape[1] > 0:
        end_id = pathfinding_net.path_planner.end_token_id
        contains_only_valid = (((generated_paths >= 0) & (generated_paths < num_actions))
                               | (generated_paths == end_id)).all()
        print(f"   - Generated actions in valid range: {'✓' if contains_only_valid else '✗'}")
    else:
        print(f"   - No actions generated (early END token)")

    print(f"\n=== Network Ready for Training ===")