c1tr0n75 committed · verified
Commit 8539646 · 1 Parent(s): 17c84c4

Update app.py: fetch the model code and checkpoint with hf_hub_download, define the torch device before it is used, and wrap inference in a Gradio UI.

Files changed (1): app.py (+88, -13)
app.py CHANGED
@@ -1,17 +1,92 @@
-from huggingface_hub import snapshot_download
-import os, glob, torch
+import os
+import importlib.util
+import numpy as np
+import torch
+import gradio as gr
+from huggingface_hub import hf_hub_download
 
-local_dir = snapshot_download(repo_id="c1tr0n75/VoxelPathFinder")
+REPO_ID = "c1tr0n75/VoxelPathFinder"
 
-# preferred path in subfolder
-ckpt_path = os.path.join(local_dir, "training_outputs", "final_model.pth")
-if not os.path.exists(ckpt_path):
-    # fallback: find any .pth
-    pths = [p for p in glob.glob(os.path.join(local_dir, "**", "*.pth"), recursive=True)]
-    if not pths:
-        raise FileNotFoundError("No .pth found in repo snapshot.")
-    ckpt_path = pths[0]
+# 1) Make sure torch is imported, then define device BEFORE using it anywhere
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-ckpt = torch.load(ckpt_path, map_location=device)
+# 2) Download model code and weights from your model repo
+PY_PATH = hf_hub_download(repo_id=REPO_ID, filename="pathfinding_nn.py")
+CKPT_PATH = hf_hub_download(repo_id=REPO_ID, filename="training_outputs/final_model.pth")
+
+# 3) Dynamically import your model definitions
+spec = importlib.util.spec_from_file_location("pathfinding_nn", PY_PATH)
+mod = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(mod)
+PathfindingNetwork = mod.PathfindingNetwork
+create_voxel_input = mod.create_voxel_input
+
+# 4) Build and load model
+MODEL = PathfindingNetwork().to(DEVICE).eval()
+ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
 state = ckpt.get("model_state_dict", ckpt)
-model.load_state_dict(state)
+MODEL.load_state_dict(state)
+
+ACTION_NAMES = ['FORWARD','BACK','LEFT','RIGHT','UP','DOWN']
+
+def decode(actions):
+    return [ACTION_NAMES[a] for a in actions if 0 <= a < 6]
+
+def infer_random(obstacle_prob=0.2, seed=None):
+    if seed is not None:
+        np.random.seed(int(seed))
+    voxel_dim = MODEL.voxel_dim
+    D, H, W = voxel_dim
+    obstacles = (np.random.rand(D, H, W) < float(obstacle_prob)).astype(np.float32)
+    free = np.argwhere(obstacles == 0)
+    if len(free) < 2:
+        return {"error": "Not enough free cells; lower obstacle_prob."}
+    s_idx, g_idx = np.random.choice(len(free), size=2, replace=False)
+    start = tuple(free[s_idx]); goal = tuple(free[g_idx])
+
+    voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
+    voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
+    pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE)
+
+    with torch.no_grad():
+        actions = MODEL(voxel, pos)[0].tolist()
+
+    return {
+        "start": start,
+        "goal": goal,
+        "num_actions": len([a for a in actions if 0 <= a < 6]),
+        "actions_ids": actions,
+        "actions_decoded": decode(actions)[:50],
+    }
+
+def infer_npz(npz_file):
+    if npz_file is None:
+        return {"error": "Upload a .npz with 'voxel_data' and 'positions'."}
+    data = np.load(npz_file.name)
+    voxel = torch.from_numpy(data['voxel_data']).float().unsqueeze(0).to(DEVICE)
+    pos = torch.from_numpy(data['positions']).long().unsqueeze(0).to(DEVICE)
+    with torch.no_grad():
+        actions = MODEL(voxel, pos)[0].tolist()
+    return {
+        "num_actions": len([a for a in actions if 0 <= a < 6]),
+        "actions_ids": actions,
+        "actions_decoded": decode(actions)[:50],
+    }
+
+with gr.Blocks(title="Voxel Path Finder") as demo:
+    gr.Markdown("## 3D Voxel Path Finder — Inference")
+    with gr.Tab("Random environment"):
+        obstacle = gr.Slider(0.0, 0.9, value=0.2, step=0.05, label="Obstacle probability")
+        seed = gr.Number(value=None, label="Seed (optional)")
+        btn = gr.Button("Run inference")
+        out = gr.JSON(label="Result")
+        btn.click(infer_random, inputs=[obstacle, seed], outputs=out)
+
+    with gr.Tab("Upload .npz sample"):
+        file = gr.File(file_types=[".npz"], label="Upload sample (voxel_data, positions)")
+        btn2 = gr.Button("Run inference")
+        out2 = gr.JSON(label="Result")
+        btn2.click(infer_npz, inputs=file, outputs=out2)
+
+if __name__ == "__main__":
+    demo.launch()
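
Note: the "Upload .npz sample" tab calls infer_npz, which reads two arrays from the uploaded archive, 'voxel_data' and 'positions'. The sketch below shows one way such a file could be produced offline; the grid size and the (2, 3) start/goal layout are assumptions inferred from infer_random, and the real 'voxel_data' should come from the repo's create_voxel_input helper rather than the bare occupancy grid used here.

# make_sample.py -- hypothetical helper, not part of this commit.
# Writes a .npz with the two keys infer_npz expects: 'voxel_data' and 'positions'.
# Assumed shapes: a (D, H, W) grid matching MODEL.voxel_dim, and a (2, 3) integer
# array holding the start and goal cells; check both against pathfinding_nn.py.
import numpy as np

D = H = W = 16                                       # assumed grid size
obstacles = (np.random.rand(D, H, W) < 0.2).astype(np.float32)
start, goal = (0, 0, 0), (D - 1, H - 1, W - 1)       # example start and goal cells

voxel_data = obstacles                               # stand-in for create_voxel_input(...)
positions = np.array([start, goal], dtype=np.int64)  # row 0 = start, row 1 = goal

np.savez("sample.npz", voxel_data=voxel_data, positions=positions)

infer_npz adds the batch dimension itself via unsqueeze(0), so the arrays are saved here without one.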