c1tr0n75 commited on
Commit
7f40cc7
·
verified ·
1 Parent(s): e667b88

Create app.py

Browse files

Add an app.py that:
Loads PathfindingNetwork and your weights.
Lets users either:
Upload a .npz sample (voxel_data [1,3,32,32,32], positions [1,2,3]), or
Generate a random environment and run inference.
Displays decoded actions.

Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ from huggingface_hub import hf_hub_download
7
+ from pathfinding_nn import PathfindingNetwork, create_voxel_input
8
+
9
+ ACTION_NAMES = ['FORWARD','BACK','LEFT','RIGHT','UP','DOWN']
10
+
11
+ def load_model():
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ model = PathfindingNetwork().to(device).eval()
14
+
15
+ # Prefer local checkpoint
16
+ local_ckpt = Path('training_outputs/final_model.pth')
17
+ ckpt_path = None
18
+ if local_ckpt.exists():
19
+ ckpt_path = str(local_ckpt)
20
+ else:
21
+ # Fallback to Hub (configure your repo and filename)
22
+ repo_id = os.getenv('MODEL_REPO_ID', '') # e.g. "your-username/voxel-pathfinder"
23
+ filename = os.getenv('MODEL_FILENAME', 'final_model.pth')
24
+ if repo_id:
25
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
26
+
27
+ if ckpt_path is None:
28
+ raise FileNotFoundError("Model checkpoint not found. Upload to training_outputs/final_model.pth or set MODEL_REPO_ID+MODEL_FILENAME env vars.")
29
+
30
+ ckpt = torch.load(ckpt_path, map_location=device)
31
+ state = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
32
+ model.load_state_dict(state)
33
+ return model, device
34
+
35
+ MODEL, DEVICE = load_model()
36
+
37
+ def decode(actions):
38
+ return [ACTION_NAMES[a] for a in actions if 0 <= a < 6]
39
+
40
+ def infer_random(obstacle_prob=0.2, seed=None):
41
+ if seed is not None:
42
+ np.random.seed(int(seed))
43
+ voxel_dim = MODEL.voxel_dim # (32,32,32)
44
+ D,H,W = voxel_dim
45
+ obstacles = (np.random.rand(D,H,W) < float(obstacle_prob)).astype(np.float32)
46
+ free = np.argwhere(obstacles == 0)
47
+ if len(free) < 2:
48
+ return {"error": "Not enough free cells; lower obstacle_prob."}
49
+ s_idx, g_idx = np.random.choice(len(free), size=2, replace=False)
50
+ start = tuple(free[s_idx])
51
+ goal = tuple(free[g_idx])
52
+
53
+ voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
54
+ voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE) # (1,3,32,32,32)
55
+ pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE) # (1,2,3)
56
+
57
+ with torch.no_grad():
58
+ actions = MODEL(voxel, pos)[0].tolist()
59
+ return {
60
+ "start": start,
61
+ "goal": goal,
62
+ "num_actions": len([a for a in actions if 0 <= a < 6]),
63
+ "actions_ids": actions,
64
+ "actions_decoded": decode(actions)[:50]
65
+ }
66
+
67
+ def infer_npz(npz_file):
68
+ if npz_file is None:
69
+ return {"error": "Please upload a .npz with keys 'voxel_data' and 'positions'."}
70
+ data = np.load(npz_file.name)
71
+ voxel = torch.from_numpy(data['voxel_data']).float().unsqueeze(0).to(DEVICE) # (1,3,32,32,32)
72
+ pos = torch.from_numpy(data['positions']).long().unsqueeze(0).to(DEVICE) # (1,2,3)
73
+ with torch.no_grad():
74
+ actions = MODEL(voxel, pos)[0].tolist()
75
+ return {
76
+ "num_actions": len([a for a in actions if 0 <= a < 6]),
77
+ "actions_ids": actions,
78
+ "actions_decoded": decode(actions)[:50]
79
+ }
80
+
81
+ with gr.Blocks(title="Voxel Path Finder") as demo:
82
+ gr.Markdown("## 3D Voxel Path Finder — Inference")
83
+ with gr.Tab("Random environment"):
84
+ obstacle = gr.Slider(0.0, 0.9, value=0.2, step=0.05, label="Obstacle probability")
85
+ seed = gr.Number(value=None, label="Seed (optional)")
86
+ btn = gr.Button("Run inference")
87
+ out = gr.JSON(label="Result")
88
+ btn.click(infer_random, inputs=[obstacle, seed], outputs=out)
89
+
90
+ with gr.Tab("Upload .npz sample"):
91
+ file = gr.File(file_types=[".npz"], label="Upload sample (voxel_data, positions)")
92
+ btn2 = gr.Button("Run inference")
93
+ out2 = gr.JSON(label="Result")
94
+ btn2.click(infer_npz, inputs=file, outputs=out2)
95
+
96
+ if __name__ == "__main__":
97
+ demo.launch()