DA-2-WebGPU / script.js
phiph's picture
Update script.js
db5acd6 verified
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/[email protected]';
// Transformers.js environment: models are fetched from the HF Hub, so the
// local model lookup is switched off, and browser caching is switched on.
env.allowLocalModels = false;
env.useBrowserCache = true;

// Hub repository of the model and the fixed size every image is drawn at.
const MODEL_ID = 'phiph/DA-2-WebGPU';
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

// Holds the loaded pipeline; stays null until init() has assigned it.
let depth_estimator = null;

// Page controls.
const imageInput = document.getElementById('imageInput');
const runBtn = document.getElementById('runBtn');
const statusElement = document.getElementById('status');

// Canvases for the selected image and for the rendered depth map, together
// with their 2D drawing contexts.
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const [inputCtx, outputCtx] = [inputCanvas, outputCanvas].map((canvas) =>
  canvas.getContext('2d'),
);
/**
 * Loads the depth-estimation pipeline into `depth_estimator`.
 *
 * WebGPU is tried first; if that attempt throws, the same model is loaded on
 * the WASM backend instead. Progress and failures are written to the status
 * element, and the run button is enabled once either attempt succeeds.
 * Both failures are logged to the console so their stack traces are kept.
 *
 * @returns {Promise<void>} Resolves after the load attempts have finished.
 */
async function init() {
  // Both attempts share every option except the device.
  const loadOn = (device) =>
    pipeline('depth-estimation', MODEL_ID, {
      device,
      dtype: 'fp32', // Important: Model is FP32
      quantized: false,
    });

  statusElement.textContent = 'Loading model... (this may take a while)';

  try {
    depth_estimator = await loadOn('webgpu');
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
    return;
  } catch (webgpuError) {
    console.error(webgpuError);
    statusElement.textContent = 'WebGPU failed, trying WASM...';
  }

  try {
    depth_estimator = await loadOn('wasm');
    statusElement.textContent = 'Model loaded (WASM). Ready.';
    runBtn.disabled = false;
  } catch (wasmError) {
    // Log the fallback failure too; the status line only carries the message.
    console.error(wasmError);
    statusElement.textContent = `Error loading model (WASM): ${wasmError.message}`;
  }
}
// When a file is chosen, draw it stretched to INPUT_WIDTH x INPUT_HEIGHT on the
// input canvas and reset the output canvas to the same size.
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const objectUrl = URL.createObjectURL(file);
  const img = new Image();

  img.onload = () => {
    // The image is loaded, so the temporary blob URL can be released;
    // otherwise one URL would stay alive for every file the user picks.
    URL.revokeObjectURL(objectUrl);

    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);

    // Clear output
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };

  img.onerror = () => {
    // Release the URL on failure as well and tell the user, instead of
    // leaving the page unchanged with no explanation.
    URL.revokeObjectURL(objectUrl);
    statusElement.textContent = 'Could not load the selected image.';
  };

  img.src = objectUrl;
});
// Runs depth estimation on the current contents of the input canvas and draws
// the result on the output canvas. The button is disabled while a run is in
// flight and re-enabled afterwards, whether the run succeeded or failed.
runBtn.addEventListener('click', async () => {
  // Nothing to do until init() has loaded a pipeline.
  if (!depth_estimator) return;

  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;
  try {
    // Pass the canvas contents (as a data URL) so the model receives exactly
    // what the user sees.
    const url = inputCanvas.toDataURL();

    // Preprocessing (resize, rescale) is expected to happen inside the pipeline.
    const output = await depth_estimator(url);

    // NOTE(review): the Transformers.js docs describe `depth` as an image and
    // `predicted_depth` as the raw tensor - confirm what this model returns.
    if (output.depth) {
      visualize(output.depth.data, INPUT_WIDTH, INPUT_HEIGHT);
      statusElement.textContent = 'Done.';
    } else {
      // Unexpected result shape: dump it for inspection and keep the status
      // hint visible instead of overwriting it with a plain "Done.".
      console.log("Output structure:", output);
      statusElement.textContent = 'Done (Check console for output structure).';
    }
  } catch (e) {
    console.error(e);
    statusElement.textContent = `Error running inference: ${e.message}`;
  } finally {
    runBtn.disabled = false;
  }
});
/**
 * Draws a buffer of depth values on the output canvas as a grayscale image.
 *
 * Values are min-max normalised over the whole buffer and inverted, so the
 * smallest value is drawn white (255) and the largest black (0). When every
 * value is identical the divisor falls back to 1 and all pixels are white.
 *
 * @param {ArrayLike<number>} data - One value per pixel, written in order
 *   into the image buffer (row by row, as canvas image data is laid out).
 * @param {number} width - Width of the image to create, in pixels.
 * @param {number} height - Height of the image to create, in pixels.
 */
function visualize(data, width, height) {
  const count = data.length;

  // First pass: value range of the buffer.
  let lowest = Infinity;
  let highest = -Infinity;
  for (let k = 0; k < count; k++) {
    const sample = data[k];
    if (sample < lowest) lowest = sample;
    if (sample > highest) highest = sample;
  }
  const span = highest - lowest;
  // A zero range would divide by zero; use 1 so every pixel maps to one gray.
  const divisor = span || 1;

  // Second pass: write one opaque gray RGBA pixel per value.
  const frame = outputCtx.createImageData(width, height);
  const rgba = frame.data;
  let offset = 0;
  for (let k = 0; k < count; k++) {
    const normalized = (data[k] - lowest) / divisor;
    const gray = Math.floor((1 - normalized) * 255);
    rgba[offset] = gray; // R
    rgba[offset + 1] = gray; // G
    rgba[offset + 2] = gray; // B
    rgba[offset + 3] = 255; // Alpha
    offset += 4;
  }

  outputCtx.putImageData(frame, 0, 0);
}
// Start loading the model as soon as the module is evaluated. The returned
// promise is not awaited; init() reports load failures via the status element.
init();