|
import os |
|
from pathlib import Path |
|
import numpy as np |
|
import torch |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from pathfinding_nn import PathfindingNetwork, create_voxel_input |
|
|
|
# Human-readable labels for the action ids the model emits; an id is used
# directly as an index into this list.
ACTION_NAMES = [
    'FORWARD',
    'BACK',
    'LEFT',
    'RIGHT',
    'UP',
    'DOWN',
]
|
|
|
def load_model():
    """Build the pathfinding network and load its trained weights.

    The checkpoint is looked up in two places, in order:

    1. ``training_outputs/final_model.pth`` relative to the working directory.
    2. The Hugging Face Hub repo named by the ``MODEL_REPO_ID`` environment
       variable, using the file named by ``MODEL_FILENAME`` (default
       ``final_model.pth``).

    Returns:
        A ``(model, device)`` tuple; the model is on ``device`` and in eval
        mode.

    Raises:
        FileNotFoundError: If there is no local checkpoint and
            ``MODEL_REPO_ID`` is unset or empty.
    """
    # Prefer the GPU when one is visible, otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    network = PathfindingNetwork()
    network = network.to(device)
    network = network.eval()

    bundled = Path('training_outputs/final_model.pth')
    if bundled.exists():
        checkpoint_file = str(bundled)
    else:
        # No local file: fall back to the Hub, but only if a repo is configured.
        hub_repo = os.getenv('MODEL_REPO_ID', '')
        hub_file = os.getenv('MODEL_FILENAME', 'final_model.pth')
        if hub_repo:
            checkpoint_file = hf_hub_download(repo_id=hub_repo, filename=hub_file)
        else:
            checkpoint_file = None

    if checkpoint_file is None:
        raise FileNotFoundError("Model checkpoint not found. Upload to training_outputs/final_model.pth or set MODEL_REPO_ID+MODEL_FILENAME env vars.")

    # NOTE(review): torch.load unpickles the file. For checkpoints fetched from
    # a configurable Hub repo, consider weights_only=True once it is confirmed
    # that the checkpoint holds only tensors.
    checkpoint = torch.load(checkpoint_file, map_location=device)

    # Accept both a training checkpoint that wraps the weights under
    # 'model_state_dict' and a bare state dict.
    if 'model_state_dict' in checkpoint:
        weights = checkpoint['model_state_dict']
    else:
        weights = checkpoint

    network.load_state_dict(weights)
    return network, device


# Loaded once at import time and shared by every inference handler below.
MODEL, DEVICE = load_model()
|
|
|
def decode(actions):
    """Translate action ids into their names from ``ACTION_NAMES``.

    Ids outside ``0 <= a < len(ACTION_NAMES)`` are dropped rather than
    raising (presumably padding / end-of-sequence markers — TODO confirm
    against the model's output convention).

    Args:
        actions: Iterable of integer action ids.

    Returns:
        List of action-name strings, in input order, for the valid ids only.
    """
    # Derive the bound from the table itself so the filter cannot drift out of
    # sync if an action is ever added or removed.
    num_actions = len(ACTION_NAMES)
    return [ACTION_NAMES[a] for a in actions if 0 <= a < num_actions]
|
|
|
def infer_random(obstacle_prob=0.2, seed=None):
    """Run the model on a randomly generated voxel environment.

    Every cell of a grid shaped ``MODEL.voxel_dim`` independently becomes an
    obstacle with probability ``obstacle_prob``. Two distinct free cells are
    then drawn as start and goal, and the model is queried once.

    Args:
        obstacle_prob: Per-cell obstacle probability; converted with
            ``float()``.
        seed: Optional seed. When given it is converted with ``int()`` and
            makes the generated environment reproducible.

    Returns:
        A dict with ``start``, ``goal``, ``num_actions`` (count of valid
        action ids), ``actions_ids`` (raw model output as a list) and
        ``actions_decoded`` (at most the first 50 action names). On invalid
        input the dict holds a single ``error`` message instead.
    """
    # Use a private generator instead of np.random.seed(): reseeding the
    # process-global state would let concurrent requests disturb each other.
    # RandomState (not default_rng) keeps the stream identical to the legacy
    # seed()/rand()/choice() calls, so a given seed yields the same environment.
    try:
        rng = np.random.RandomState(None if seed is None else int(seed))
    except (ValueError, OverflowError):
        # Negative, too-large or non-finite seeds are rejected by int() or
        # RandomState; report that instead of crashing the handler.
        return {"error": "Seed must be an integer between 0 and 4294967295."}

    voxel_dim = MODEL.voxel_dim
    D, H, W = voxel_dim
    obstacles = (rng.rand(D, H, W) < float(obstacle_prob)).astype(np.float32)

    free = np.argwhere(obstacles == 0)
    if len(free) < 2:
        return {"error": "Not enough free cells; lower obstacle_prob."}

    # replace=False guarantees that start and goal are different cells.
    s_idx, g_idx = rng.choice(len(free), size=2, replace=False)
    # .tolist() yields built-in ints, so the result dict does not carry NumPy
    # scalars into the JSON output.
    start = tuple(free[s_idx].tolist())
    goal = tuple(free[g_idx].tolist())

    voxel_np = create_voxel_input(obstacles, start, goal, voxel_dim=voxel_dim)
    # unsqueeze(0) adds a leading batch dimension of size 1.
    voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
    pos = torch.tensor([[start, goal]], dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        # [0] selects the single batch entry created above.
        actions = MODEL(voxel, pos)[0].tolist()

    decoded = decode(actions)
    return {
        "start": start,
        "goal": goal,
        "num_actions": len(decoded),
        "actions_ids": actions,
        "actions_decoded": decoded[:50],
    }
|
|
|
def infer_npz(npz_file):
    """Run the model on a sample stored in an uploaded ``.npz`` archive.

    The archive must contain the arrays ``voxel_data`` and ``positions``.
    Each is given a leading batch dimension of size 1 before being passed to
    the model.

    Args:
        npz_file: The upload from the Gradio file widget: either an object
            exposing the file path as ``.name`` or a plain path string, or
            ``None`` when nothing was uploaded.

    Returns:
        A dict with ``num_actions`` (count of valid action ids),
        ``actions_ids`` (raw model output as a list) and ``actions_decoded``
        (at most the first 50 action names). When no file was given or a
        required key is missing, the dict holds a single ``error`` message.
    """
    if npz_file is None:
        return {"error": "Please upload a .npz with keys 'voxel_data' and 'positions'."}

    # Depending on Gradio version/configuration the upload arrives as a
    # tempfile-like object with `.name` or directly as a path; accept both.
    npz_path = getattr(npz_file, "name", npz_file)

    # The context manager closes the archive's file handle; the arrays are
    # fully read on access, so they stay valid after the block exits.
    with np.load(npz_path) as data:
        missing = [key for key in ("voxel_data", "positions") if key not in data.files]
        if missing:
            return {"error": f"Uploaded .npz is missing required key(s): {', '.join(missing)}."}
        voxel_np = data["voxel_data"]
        pos_np = data["positions"]

    voxel = torch.from_numpy(voxel_np).float().unsqueeze(0).to(DEVICE)
    pos = torch.from_numpy(pos_np).long().unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # [0] selects the single batch entry created above.
        actions = MODEL(voxel, pos)[0].tolist()

    decoded = decode(actions)
    return {
        "num_actions": len(decoded),
        "actions_ids": actions,
        "actions_decoded": decoded[:50],
    }
|
|
|
# Gradio front end: one tab per inference entry point, each with a button that
# sends the tab's inputs to its handler and shows the returned dict as JSON.
with gr.Blocks(title="Voxel Path Finder") as demo:
    gr.Markdown(value="## 3D Voxel Path Finder — Inference")

    # Tab 1: generate a random environment and run the model on it.
    with gr.Tab(label="Random environment"):
        obstacle = gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.2,
            step=0.05,
            label="Obstacle probability",
        )
        seed = gr.Number(value=None, label="Seed (optional)")
        btn = gr.Button(value="Run inference")
        out = gr.JSON(label="Result")
        btn.click(fn=infer_random, inputs=[obstacle, seed], outputs=out)

    # Tab 2: run the model on a user-supplied sample archive.
    with gr.Tab(label="Upload .npz sample"):
        file = gr.File(
            file_types=[".npz"],
            label="Upload sample (voxel_data, positions)",
        )
        btn2 = gr.Button(value="Run inference")
        out2 = gr.JSON(label="Result")
        btn2.click(fn=infer_npz, inputs=file, outputs=out2)


# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()