File size: 3,612 Bytes
7f40cc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from pathlib import Path
import numpy as np
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from pathfinding_nn import PathfindingNetwork, create_voxel_input

# Human-readable names for action ids; decode() indexes this list by id.
ACTION_NAMES = ['FORWARD','BACK','LEFT','RIGHT','UP','DOWN']

def load_model():
  """Build the pathfinding network and load its trained weights.

  The checkpoint is taken from training_outputs/final_model.pth when that
  file exists. Otherwise it is downloaded from the Hugging Face Hub repo
  named by the MODEL_REPO_ID environment variable, using the file name in
  MODEL_FILENAME (default 'final_model.pth').

  Returns:
    A (model, device) tuple. The model is in eval mode and lives on the
    device; the device is CUDA when available, else CPU.

  Raises:
    FileNotFoundError: If there is no local checkpoint and MODEL_REPO_ID
      is unset or empty.
  """
  if torch.cuda.is_available():
    device = torch.device('cuda')
  else:
    device = torch.device('cpu')
  network = PathfindingNetwork().to(device).eval()

  # A checkpoint on disk wins over anything configured for the Hub.
  local_file = Path('training_outputs/final_model.pth')
  if local_file.exists():
    weights_path = str(local_file)
  else:
    hub_repo = os.getenv('MODEL_REPO_ID', '')  # e.g. "your-username/voxel-pathfinder"
    hub_file = os.getenv('MODEL_FILENAME', 'final_model.pth')
    weights_path = (
      hf_hub_download(repo_id=hub_repo, filename=hub_file) if hub_repo else None
    )

  if weights_path is None:
    raise FileNotFoundError("Model checkpoint not found. Upload to training_outputs/final_model.pth or set MODEL_REPO_ID+MODEL_FILENAME env vars.")

  # NOTE(review): torch.load may unpickle arbitrary objects depending on the
  # installed torch version's defaults - only point this at trusted files.
  checkpoint = torch.load(weights_path, map_location=device)
  # Accept either a training checkpoint wrapping the weights or a bare state dict.
  if 'model_state_dict' in checkpoint:
    weights = checkpoint['model_state_dict']
  else:
    weights = checkpoint
  network.load_state_dict(weights)
  return network, device

# Load the model once at import time; the inference functions below share it.
# load_model() raises FileNotFoundError here if no checkpoint can be located.
MODEL, DEVICE = load_model()

def decode(actions):
  """Translate action ids into their names, skipping ids that have no name.

  Args:
    actions: Iterable of integer action ids.

  Returns:
    A list with the ACTION_NAMES entry for every id that is a valid index
    into ACTION_NAMES, in input order. Out-of-range ids are dropped.
  """
  # Bound the filter by the table itself instead of a hard-coded 6, so the
  # two cannot drift out of sync if ACTION_NAMES ever changes.
  num_names = len(ACTION_NAMES)
  return [ACTION_NAMES[a] for a in actions if 0 <= a < num_names]

def infer_random(obstacle_prob=0.2, seed=None):
  """Run the model on a randomly generated obstacle grid.

  Args:
    obstacle_prob: Probability that any single voxel is an obstacle.
    seed: Optional seed. When given (it is converted with int()), the same
      grid, start and goal are generated on every call with that seed.

  Returns:
    On success, a dict with keys "start", "goal", "num_actions",
    "actions_ids" and "actions_decoded" (at most the first 50 names).
    When fewer than two free voxels were generated, a dict with a single
    "error" key instead.
  """
  # A private RandomState produces the same stream as seeding the global
  # np.random with this seed, but does not reseed process-wide state that
  # other callers (e.g. concurrent requests) also draw from.
  rng = np.random.RandomState(None if seed is None else int(seed))

  voxel_dim = MODEL.voxel_dim  # expected to be (32,32,32) - TODO confirm
  D, H, W = voxel_dim
  obstacles = (rng.rand(D, H, W) < float(obstacle_prob)).astype(np.float32)
  free = np.argwhere(obstacles == 0)
  if len(free) < 2:
    return {"error": "Not enough free cells; lower obstacle_prob."}

  # Two distinct free cells: one for the start, one for the goal.
  s_idx, g_idx = rng.choice(len(free), size=2, replace=False)
  # Plain Python ints keep the returned dict serializable by any JSON encoder.
  start = tuple(int(c) for c in free[s_idx])
  goal = tuple(int(c) for c in free[g_idx])

  voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
  # Leading batch dimension of 1 added for the model call.
  voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
  pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE)  # (1,2,3)

  with torch.no_grad():
    actions = MODEL(voxel, pos)[0].tolist()

  # decode() already drops ids without a name, so its length is the count
  # of valid actions; reuse it rather than repeating the range filter.
  decoded = decode(actions)
  return {
    "start": start,
    "goal": goal,
    "num_actions": len(decoded),
    "actions_ids": actions,
    "actions_decoded": decoded[:50]
  }

def infer_npz(npz_file):
  """Run the model on a sample loaded from an uploaded .npz archive.

  Args:
    npz_file: The upload, either a path (str / os.PathLike) or an object
      whose `.name` attribute is the path of the file on disk. None means
      nothing was uploaded.

  Returns:
    On success, a dict with keys "num_actions", "actions_ids" and
    "actions_decoded" (at most the first 50 names). A dict with a single
    "error" key when no file was given or a required array is missing.
  """
  if npz_file is None:
    return {"error": "Please upload a .npz with keys 'voxel_data' and 'positions'."}

  # The upload may arrive as a plain path or as a temp-file wrapper exposing
  # `.name`, depending on how the Gradio File component is configured.
  if isinstance(npz_file, (str, os.PathLike)):
    path = npz_file
  else:
    path = npz_file.name

  required = ("voxel_data", "positions")
  # The context manager closes the archive's file handle once the arrays
  # have been read out.
  with np.load(path) as data:
    missing = [key for key in required if key not in data.files]
    if missing:
      return {"error": f"Uploaded .npz is missing required key(s): {', '.join(missing)}. Expected 'voxel_data' and 'positions'."}
    voxel_np = data['voxel_data']
    pos_np = data['positions']

  # Leading batch dimension of 1 added for the model call; the arrays are
  # presumably (3,32,32,32) and (2,3) - TODO confirm against the model.
  voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
  pos = torch.from_numpy(pos_np).long().unsqueeze(0).to(DEVICE)
  with torch.no_grad():
    actions = MODEL(voxel, pos)[0].tolist()

  # decode() already drops ids without a name, so its length is the count
  # of valid actions.
  decoded = decode(actions)
  return {
    "num_actions": len(decoded),
    "actions_ids": actions,
    "actions_decoded": decoded[:50]
  }

# Gradio UI: one tab per inference entry point, each showing its result as JSON.
with gr.Blocks(title="Voxel Path Finder") as demo:
  gr.Markdown("## 3D Voxel Path Finder — Inference")
  # Tab 1: run the model on a randomly generated obstacle grid.
  with gr.Tab("Random environment"):
    obstacle = gr.Slider(0.0, 0.9, value=0.2, step=0.05, label="Obstacle probability")
    # Defaults to None; infer_random only seeds the RNG when a value is given.
    seed = gr.Number(value=None, label="Seed (optional)")
    btn = gr.Button("Run inference")
    out = gr.JSON(label="Result")
    # Slider and seed values are passed positionally as (obstacle_prob, seed).
    btn.click(infer_random, inputs=[obstacle, seed], outputs=out)

  # Tab 2: run the model on a user-supplied .npz sample.
  with gr.Tab("Upload .npz sample"):
    file = gr.File(file_types=[".npz"], label="Upload sample (voxel_data, positions)")
    btn2 = gr.Button("Run inference")
    out2 = gr.JSON(label="Result")
    btn2.click(infer_npz, inputs=file, outputs=out2)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
  demo.launch()