// NOTE: the `device` and `dtype` pipeline options used below require
// Transformers.js v3, which ships as @huggingface/transformers
// (the older @xenova/transformers v2 package does not support WebGPU).
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';
// Skip local model checks since we are fetching from HF Hub
env.allowLocalModels = false;
// Enable caching
env.useBrowserCache = true;

const MODEL_ID = 'phiph/DA-2-WebGPU';
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

let depth_estimator = null;

const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');
// Initialize the Transformers.js pipeline
async function init() {
    try {
        statusElement.textContent = 'Loading model... (this may take a while)';
        depth_estimator = await pipeline('depth-estimation', MODEL_ID, {
            device: 'webgpu',
            dtype: 'fp32', // Important: the model weights are FP32
        });
        statusElement.textContent = 'Model loaded. Ready.';
        runBtn.disabled = false;
    } catch (e) {
        console.error(e);
        // Fall back to WASM if WebGPU fails
        try {
            statusElement.textContent = 'WebGPU failed, trying WASM...';
            depth_estimator = await pipeline('depth-estimation', MODEL_ID, {
                device: 'wasm',
                dtype: 'fp32',
            });
            statusElement.textContent = 'Model loaded (WASM). Ready.';
            runBtn.disabled = false;
        } catch (e2) {
            console.error(e2);
            statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
        }
    }
}
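// Optional sketch (not part of the original flow): instead of catching a
// WebGPU failure after the fact, the backend could be feature-detected up
// front via the standard navigator.gpu API, e.g.:
//
//   const device = ('gpu' in navigator) ? 'webgpu' : 'wasm';
//   depth_estimator = await pipeline('depth-estimation', MODEL_ID, { device, dtype: 'fp32' });
//
// Note that navigator.gpu existing does not guarantee an adapter can be
// acquired, so the try/catch fallback above remains a useful safety net.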
imageInput.addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (!file) return;
    const img = new Image();
    img.onload = () => {
        URL.revokeObjectURL(img.src); // free the blob URL once loaded
        inputCanvas.width = INPUT_WIDTH;
        inputCanvas.height = INPUT_HEIGHT;
        inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
        // Clear the output canvas
        outputCanvas.width = INPUT_WIDTH;
        outputCanvas.height = INPUT_HEIGHT;
        outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    };
    img.src = URL.createObjectURL(file);
});
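// Optional sketch: drawImage above stretches arbitrary images to the fixed
// 1092x546 (2:1) input, which distorts anything with a different aspect
// ratio. If that matters, an aspect-preserving letterbox could be drawn
// instead (hypothetical helper, not wired in):
//
//   function drawLetterboxed(img) {
//       const scale = Math.min(INPUT_WIDTH / img.width, INPUT_HEIGHT / img.height);
//       const w = img.width * scale, h = img.height * scale;
//       inputCtx.fillRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT); // pad with background
//       inputCtx.drawImage(img, (INPUT_WIDTH - w) / 2, (INPUT_HEIGHT - h) / 2, w, h);
//   }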
runBtn.addEventListener('click', async () => {
    if (!depth_estimator) return;
    statusElement.textContent = 'Running inference...';
    runBtn.disabled = true;
    try {
        // Use the canvas contents so inference matches what the user sees
        const url = inputCanvas.toDataURL();

        // Run inference. The pipeline handles preprocessing
        // (resize, rescale) automatically.
        const output = await depth_estimator(url);

        // For the standard depth-estimation pipeline, output.predicted_depth
        // is the raw tensor and output.depth is the visualized depth map
        // (a RawImage); custom models may return a different structure.
        if (output.depth) {
            // Visualize the raw data manually to be safe, using the output's
            // own dimensions when available
            const { data, width, height } = output.depth;
            visualize(data, width ?? INPUT_WIDTH, height ?? INPUT_HEIGHT);
            statusElement.textContent = 'Done.';
        } else {
            // Fallback if the structure is different
            console.log('Output structure:', output);
            statusElement.textContent = 'Done (check console for output structure).';
        }
    } catch (e) {
        console.error(e);
        statusElement.textContent = 'Error running inference: ' + e.message;
    } finally {
        runBtn.disabled = false;
    }
});
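// Optional sketch: to report inference latency, the estimator call above
// could be wrapped with the standard performance.now() timer, e.g.:
//
//   const t0 = performance.now();
//   const output = await depth_estimator(url);
//   statusElement.textContent = `Done in ${(performance.now() - t0).toFixed(0)} ms`;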
function visualize(data, width, height) {
    // Find min and max for normalization
    let min = Infinity;
    let max = -Infinity;
    for (let i = 0; i < data.length; i++) {
        if (data[i] < min) min = data[i];
        if (data[i] > max) max = data[i];
    }
    const range = max - min;

    const imageData = outputCtx.createImageData(width, height);
    for (let i = 0; i < data.length; i++) {
        // Normalize to 0-1
        const val = (data[i] - min) / (range || 1);
        // Grayscale rendering. The model predicts distance (closer = smaller
        // value), and inverted depth usually reads better, so map min (close)
        // to 255 (white) and max (far) to 0 (black). A Magma-like heatmap
        // would also work here.
        const pixelVal = Math.floor((1 - val) * 255);
        imageData.data[i * 4] = pixelVal;     // R
        imageData.data[i * 4 + 1] = pixelVal; // G
        imageData.data[i * 4 + 2] = pixelVal; // B
        imageData.data[i * 4 + 3] = 255;      // Alpha
    }
    outputCtx.putImageData(imageData, 0, 0);
}
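// Optional sketch: a rough Magma-like colormap, as mentioned in the comment
// above, can be approximated by interpolating a few anchor colors. This is
// an illustrative alternative to the grayscale mapping, not used by default;
// the anchor values below are approximations, not the exact Magma palette.
const MAGMA_STOPS = [
    [0, 0, 4],       // near-black
    [87, 16, 110],   // purple
    [188, 55, 84],   // red-violet
    [249, 142, 9],   // orange
    [252, 253, 191], // pale yellow
];
function magmaColor(t) { // t in [0, 1] -> [r, g, b]
    const x = Math.min(Math.max(t, 0), 1) * (MAGMA_STOPS.length - 1);
    const i = Math.min(Math.floor(x), MAGMA_STOPS.length - 2);
    const f = x - i; // fractional position between stops i and i+1
    return MAGMA_STOPS[i].map((c, k) => Math.round(c + f * (MAGMA_STOPS[i + 1][k] - c)));
}
// Usage inside the pixel loop: const [r, g, b] = magmaColor(1 - val);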
init();