Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,92 @@
|
|
1 |
-
|
2 |
-
import
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
if not os.path.exists(ckpt_path):
|
9 |
-
# fallback: find any .pth
|
10 |
-
pths = [p for p in glob.glob(os.path.join(local_dir, "**", "*.pth"), recursive=True)]
|
11 |
-
if not pths:
|
12 |
-
raise FileNotFoundError("No .pth found in repo snapshot.")
|
13 |
-
ckpt_path = pths[0]
|
14 |
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
state = ckpt.get("model_state_dict", ckpt)
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib.util
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
+
REPO_ID = "c1tr0n75/VoxelPathFinder"
|
9 |
|
10 |
+
# 1) Make sure torch is imported, then define device BEFORE using it anywhere
|
11 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# 2) Download model code and weights from your model repo
|
14 |
+
PY_PATH = hf_hub_download(repo_id=REPO_ID, filename="pathfinding_nn.py")
|
15 |
+
CKPT_PATH = hf_hub_download(repo_id=REPO_ID, filename="training_outputs/final_model.pth")
|
16 |
+
|
17 |
+
# 3) Dynamically import your model definitions
|
18 |
+
spec = importlib.util.spec_from_file_location("pathfinding_nn", PY_PATH)
|
19 |
+
mod = importlib.util.module_from_spec(spec)
|
20 |
+
spec.loader.exec_module(mod)
|
21 |
+
PathfindingNetwork = mod.PathfindingNetwork
|
22 |
+
create_voxel_input = mod.create_voxel_input
|
23 |
+
|
24 |
+
# 4) Build and load model
|
25 |
+
MODEL = PathfindingNetwork().to(DEVICE).eval()
|
26 |
+
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
|
27 |
state = ckpt.get("model_state_dict", ckpt)
|
28 |
+
MODEL.load_state_dict(state)
|
29 |
+
|
30 |
+
ACTION_NAMES = ['FORWARD','BACK','LEFT','RIGHT','UP','DOWN']
|
31 |
+
|
32 |
+
def decode(actions):
|
33 |
+
return [ACTION_NAMES[a] for a in actions if 0 <= a < 6]
|
34 |
+
|
35 |
+
def infer_random(obstacle_prob=0.2, seed=None):
|
36 |
+
if seed is not None:
|
37 |
+
np.random.seed(int(seed))
|
38 |
+
voxel_dim = MODEL.voxel_dim
|
39 |
+
D, H, W = voxel_dim
|
40 |
+
obstacles = (np.random.rand(D, H, W) < float(obstacle_prob)).astype(np.float32)
|
41 |
+
free = np.argwhere(obstacles == 0)
|
42 |
+
if len(free) < 2:
|
43 |
+
return {"error": "Not enough free cells; lower obstacle_prob."}
|
44 |
+
s_idx, g_idx = np.random.choice(len(free), size=2, replace=False)
|
45 |
+
start = tuple(free[s_idx]); goal = tuple(free[g_idx])
|
46 |
+
|
47 |
+
voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
|
48 |
+
voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
|
49 |
+
pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE)
|
50 |
+
|
51 |
+
with torch.no_grad():
|
52 |
+
actions = MODEL(voxel, pos)[0].tolist()
|
53 |
+
|
54 |
+
return {
|
55 |
+
"start": start,
|
56 |
+
"goal": goal,
|
57 |
+
"num_actions": len([a for a in actions if 0 <= a < 6]),
|
58 |
+
"actions_ids": actions,
|
59 |
+
"actions_decoded": decode(actions)[:50],
|
60 |
+
}
|
61 |
+
|
62 |
+
#def infer_npz(npz_file):
|
63 |
+
# if npz_file is None:
|
64 |
+
# return {"error": "Upload a .npz with 'voxel_data' and 'positions'."}
|
65 |
+
# data = np.load(npz_file.name)
|
66 |
+
# voxel = torch.from_numpy(data['voxel_data']).float().unsqueeze(0).to(DEVICE)
|
67 |
+
# pos = torch.from_numpy(data['positions']).long().unsqueeze(0).to(DEVICE)
|
68 |
+
# with torch.no_grad():
|
69 |
+
# actions = MODEL(voxel, pos)[0].tolist()
|
70 |
+
# return {
|
71 |
+
# "num_actions": len([a for a in actions if 0 <= a < 6]),
|
72 |
+
# "actions_ids": actions,
|
73 |
+
# "actions_decoded": decode(actions)[:50],
|
74 |
+
# }
|
75 |
+
|
76 |
+
with gr.Blocks(title="Voxel Path Finder") as demo:
|
77 |
+
gr.Markdown("## 3D Voxel Path Finder — Inference")
|
78 |
+
with gr.Tab("Random environment"):
|
79 |
+
obstacle = gr.Slider(0.0, 0.9, value=0.2, step=0.05, label="Obstacle probability")
|
80 |
+
seed = gr.Number(value=None, label="Seed (optional)")
|
81 |
+
btn = gr.Button("Run inference")
|
82 |
+
out = gr.JSON(label="Result")
|
83 |
+
btn.click(infer_random, inputs=[obstacle, seed], outputs=out)
|
84 |
+
|
85 |
+
with gr.Tab("Upload .npz sample"):
|
86 |
+
file = gr.File(file_types=[".npz"], label="Upload sample (voxel_data, positions)")
|
87 |
+
btn2 = gr.Button("Run inference")
|
88 |
+
out2 = gr.JSON(label="Result")
|
89 |
+
btn2.click(infer_npz, inputs=file, outputs=out2)
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
demo.launch()
|