<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Candle Segment Anything Model (SAM) Rust/WASM</title>
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
    </style>
    <script src="https://cdn.tailwindcss.com/3.4.3"></script>
    <script type="module">
      // Base URL that the model weight file names below are appended to.
      const MODEL_BASEURL =
        "https://huggingface.co/lmz/candle-sam/resolve/main/";

      // Weight file for each model ID offered in the #model dropdown.
      const MODELS = {
        sam_mobile_tiny: {
          url: "mobile_sam-tiny-vitt.safetensors",
        },
        sam_base: {
          url: "sam_vit_b_01ec64.safetensors",
        },
      };

      const samWorker = new Worker("./samWorker.js", { type: "module" });

      // ---- DOM references -------------------------------------------------
      const clearBtn = document.querySelector("#clear-btn");
      const maskBtn = document.querySelector("#mask-btn");
      const undoBtn = document.querySelector("#undo-btn");
      const downloadBtn = document.querySelector("#download-btn");
      const canvas = document.querySelector("#canvas");
      const mask = document.querySelector("#mask");
      const ctxCanvas = canvas.getContext("2d");
      const ctxMask = mask.getContext("2d");
      const fileUpload = document.querySelector("#file-upload");
      const dropArea = document.querySelector("#drop-area");
      const dropButtons = document.querySelector("#drop-buttons");
      const imagesExamples = document.querySelector("#image-select");
      const modelSelection = document.querySelector("#model");
      const statusOutput = document.querySelector("#output-status");

      // ---- Mutable UI state -----------------------------------------------
      let copyMaskURL = null; // URL of the latest mask, used by the download
      let copyImageURL = null; // URL of the displayed image, used by the download
      let hasImage = false; // true once the image has been drawn on #canvas
      let isSegmenting = false; // a mask request is in flight
      let isEmbedding = false; // an embedding request is in flight
      let currentImageURL = ""; // image whose embeddings were last computed
      let pointArr = []; // prompt points as [x, y, isMaskPoint], x/y in 0..1
      let bgPointMode = false; // when true, new points mark background

      /**
       * Sends one request to the SAM worker and waits for its terminal message.
       *
       * @param {string} modelURL URL of the weights file.
       * @param {string} modelID  Key of the model in MODELS.
       * @param {string} imageURL URL of the image to process.
       * @param {Array}  [points] Prompt points; omitted when only the image
       *                          embeddings are requested.
       * @returns {Promise<*>} Resolves with `event.data.output` on a
       *   "complete" status, or with `undefined` on "complete-embedding".
       *   Rejects with an Error when the worker posts an `error` field.
       */
      async function segmentPoints(modelURL, modelID, imageURL, points) {
        return new Promise((resolve, reject) => {
          function messageHandler(event) {
            const data = event.data;
            console.log(data);
            if ("status" in data) {
              updateStatus(data);
            }
            const failed = "error" in data;
            const finished =
              failed ||
              data.status === "complete-embedding" ||
              data.status === "complete";
            if (!finished) {
              return; // progress message: keep listening
            }
            samWorker.removeEventListener("message", messageHandler);
            if (failed) {
              reject(new Error(data.error));
            } else {
              resolve(data.status === "complete" ? data.output : undefined);
            }
          }
          samWorker.addEventListener("message", messageHandler);
          samWorker.postMessage({
            modelURL,
            modelID,
            imageURL,
            points,
          });
        });
      }

      /**
       * Shows the worker's progress text in the status line.
       *
       * @param {{message: string}} statusMessage Worker message payload.
       */
      function updateStatus(statusMessage) {
        // Read the argument, not the deprecated implicit global `event`.
        statusOutput.innerText = statusMessage.message;
      }

      /**
       * Replaces the displayed image: clears the canvases and prompt points,
       * draws the new image and requests its embeddings.
       *
       * @param {string} href URL (remote or blob:) of the image to show.
       */
      function loadNewImage(href) {
        clearImageCanvas();
        copyImageURL = href;
        drawImageCanvas(href);
        setImageEmbeddings(href);
      }

      /**
       * Loads a URL into an Image element.
       *
       * @param {string} imageURL URL of the image.
       * @returns {Promise<HTMLImageElement>} Resolves once the image loaded,
       *   rejects if it cannot be loaded (instead of never settling).
       */
      function loadImageAsync(imageURL) {
        return new Promise((resolve, reject) => {
          const img = new Image();
          img.onload = () => resolve(img);
          img.onerror = () =>
            reject(new Error("Failed to load image: " + imageURL));
          img.crossOrigin = "anonymous";
          img.src = imageURL;
        });
      }

      // File picker: show the chosen file.
      fileUpload.addEventListener("input", (e) => {
        const target = e.target;
        if (target.files.length > 0) {
          loadNewImage(URL.createObjectURL(target.files[0]));
          togglePointMode(false);
        }
      });

      // Drop area: highlight while dragging, accept dropped files or URLs.
      dropArea.addEventListener("dragenter", (e) => {
        e.preventDefault();
        dropArea.classList.add("border-blue-700");
      });
      dropArea.addEventListener("dragleave", (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
      });
      dropArea.addEventListener("dragover", (e) => {
        e.preventDefault();
      });
      dropArea.addEventListener("drop", (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
        const url = e.dataTransfer.getData("text/uri-list");
        const files = e.dataTransfer.files;
        if (files.length > 0) {
          loadNewImage(URL.createObjectURL(files[0]));
          togglePointMode(false);
        } else if (url) {
          loadNewImage(url);
          togglePointMode(false);
        }
      });

      // Example thumbnails: clicking one loads it (ignored while busy).
      imagesExamples.addEventListener("click", (e) => {
        if (isEmbedding || isSegmenting) {
          return;
        }
        const target = e.target;
        if (target.nodeName === "IMG") {
          loadNewImage(target.src);
        }
      });

      // Mask button: toggle between mask points and background points.
      maskBtn.addEventListener("click", () => {
        togglePointMode();
      });

      // Clear button: drop the image, the mask and all points.
      clearBtn.addEventListener("click", () => {
        clearImageCanvas();
        togglePointMode(false);
      });

      // Undo button: remove the most recent point.
      undoBtn.addEventListener("click", () => {
        undoPoint();
      });

      // Download button: save the masked region of the image as a PNG.
      downloadBtn.addEventListener("click", async () => {
        // The button is enabled as soon as a point exists, which can be
        // before the first mask has arrived.
        if (!copyImageURL || !copyMaskURL) {
          return;
        }
        try {
          const [originalImage, maskImage] = await Promise.all([
            loadImageAsync(copyImageURL),
            loadImageAsync(copyMaskURL),
          ]);

          // Off-screen canvas; named so it does not shadow the page canvas.
          const cutoutCanvas = document.createElement("canvas");
          cutoutCanvas.width = originalImage.width;
          cutoutCanvas.height = originalImage.height;
          const ctx = cutoutCanvas.getContext("2d");

          // Draw the mask, then keep only the image pixels that overlap it.
          ctx.drawImage(maskImage, 0, 0);
          ctx.globalCompositeOperation = "source-in";
          ctx.drawImage(originalImage, 0, 0);

          const blob = await new Promise((resolve) => {
            cutoutCanvas.toBlob(resolve);
          });
          if (!blob) {
            throw new Error("Could not encode the cutout");
          }
          const resultURL = URL.createObjectURL(blob);

          const link = document.createElement("a");
          link.href = resultURL;
          link.download = "cutout.png";
          link.click();
          // Release the blob once the download has had time to start.
          setTimeout(() => URL.revokeObjectURL(resultURL), 1000);
        } catch (err) {
          console.error(err);
        }
      });

      // Canvas click: add a prompt point, or remove the points under the
      // cursor, then refresh the mask. Shift inverts the current point mode.
      canvas.addEventListener("click", async (event) => {
        if (!hasImage || isEmbedding || isSegmenting) {
          return;
        }
        const backgroundMode = event.shiftKey ? !bgPointMode : bgPointMode;
        const targetBox = event.target.getBoundingClientRect();
        // Coordinates are stored relative to the displayed canvas size.
        const x = (event.clientX - targetBox.left) / targetBox.width;
        const y = (event.clientY - targetBox.top) / targetBox.height;
        const hitRadius = 6 / targetBox.width;
        const isHit = (pt) =>
          Math.sqrt((pt[0] - x) ** 2 + (pt[1] - y) ** 2) < hitRadius;

        if (pointArr.some(isHit)) {
          pointArr = pointArr.filter((pt) => !isHit(pt));
        } else {
          pointArr = [...pointArr, [x, y, !backgroundMode]];
        }
        if (!syncPointControls()) {
          return;
        }
        await refreshMask();
      });

      /**
       * Removes the most recently added point and refreshes the mask.
       * Does nothing while busy, without an image, or without points.
       */
      async function undoPoint() {
        if (!hasImage || isEmbedding || isSegmenting) {
          return;
        }
        if (pointArr.length === 0) {
          return;
        }
        // New array rather than pop(): refreshMask() detects replaced point
        // sets by identity.
        pointArr = pointArr.slice(0, -1);
        if (!syncPointControls()) {
          return;
        }
        await refreshMask();
      }

      /**
       * Brings the undo/download buttons and the mask overlay in line with
       * the current point list.
       *
       * @returns {boolean} true when at least one point remains.
       */
      function syncPointControls() {
        const hasPoints = pointArr.length > 0;
        undoBtn.disabled = !hasPoints;
        downloadBtn.disabled = !hasPoints;
        if (!hasPoints) {
          ctxMask.clearRect(0, 0, canvas.width, canvas.height);
          copyMaskURL = null;
        }
        return hasPoints;
      }

      /**
       * Requests a mask for the current points and draws it. The busy flag
       * is always released, even if the worker reports an error.
       */
      async function refreshMask() {
        const points = pointArr;
        isSegmenting = true;
        try {
          const { maskURL } = await getSegmentationMask(points);
          if (points !== pointArr) {
            // The points were cleared or replaced while waiting; the mask
            // is stale and must not be drawn.
            return;
          }
          copyMaskURL = maskURL;
          drawMask(maskURL, points);
        } catch (err) {
          console.error(err);
        } finally {
          isSegmenting = false;
        }
      }

      /**
       * Sets the point mode and updates the mask button label and icon.
       *
       * @param {boolean} [mode] true for background points, false for mask
       *   points; omitted to toggle the current mode.
       */
      function togglePointMode(mode) {
        bgPointMode = mode === undefined ? !bgPointMode : mode;

        maskBtn.querySelector("span").innerText = bgPointMode
          ? "Background Point"
          : "Mask Point";
        if (bgPointMode) {
          maskBtn.querySelector("#mask-circle").setAttribute("hidden", "");
          maskBtn.querySelector("#unmask-circle").removeAttribute("hidden");
        } else {
          maskBtn.querySelector("#mask-circle").removeAttribute("hidden");
          maskBtn.querySelector("#unmask-circle").setAttribute("hidden", "");
        }
      }

      /**
       * Asks the worker for a mask of the current image, using the model
       * selected in the dropdown.
       *
       * @param {Array} points Prompt points as [x, y, isMaskPoint].
       * @returns {Promise<{maskURL: string}>}
       */
      async function getSegmentationMask(points) {
        const modelID = modelSelection.value;
        const modelURL = MODEL_BASEURL + MODELS[modelID].url;
        const imageURL = currentImageURL;
        const { maskURL } = await segmentPoints(
          modelURL,
          modelID,
          imageURL,
          points
        );
        return { maskURL };
      }

      /**
       * Asks the worker to compute the embeddings of an image, showing a
       * wait cursor and disabling the clear button meanwhile. The UI is
       * restored even if the worker reports an error.
       *
       * @param {string} imageURL URL of the image to embed.
       */
      async function setImageEmbeddings(imageURL) {
        if (isEmbedding) {
          return;
        }
        canvas.classList.remove("cursor-pointer");
        canvas.classList.add("cursor-wait");
        clearBtn.disabled = true;
        const modelID = modelSelection.value;
        const modelURL = MODEL_BASEURL + MODELS[modelID].url;
        isEmbedding = true;
        try {
          await segmentPoints(modelURL, modelID, imageURL);
          currentImageURL = imageURL;
        } catch (err) {
          console.error(err);
        } finally {
          canvas.classList.remove("cursor-wait");
          canvas.classList.add("cursor-pointer");
          clearBtn.disabled = false;
          isEmbedding = false;
        }
      }

      /**
       * Resets the view to the empty state: clears both canvases, forgets
       * the points and the mask, disables the point-dependent buttons and
       * shows the upload prompt again.
       */
      function clearImageCanvas() {
        ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
        ctxMask.clearRect(0, 0, canvas.width, canvas.height);
        hasImage = false;
        isEmbedding = false;
        isSegmenting = false;
        currentImageURL = "";
        copyMaskURL = null;
        pointArr = [];
        clearBtn.disabled = true;
        undoBtn.disabled = true;
        downloadBtn.disabled = true;
        canvas.parentElement.style.height = "auto";
        dropButtons.classList.remove("invisible");
      }

      /**
       * Draws the mask overlay: the image tinted red where the mask is
       * opaque, plus one dot per prompt point.
       *
       * @param {string} maskURL URL of the mask image.
       * @param {Array}  points  Prompt points as [x, y, isMaskPoint].
       * @throws {Error} When no mask URL is given.
       */
      function drawMask(maskURL, points) {
        if (!maskURL) {
          throw new Error("No mask URL provided");
        }

        const img = new Image();
        img.crossOrigin = "anonymous";

        img.onload = () => {
          mask.width = canvas.width;
          mask.height = canvas.height;
          ctxMask.save();
          // Copy the image, tint it red, then keep only the masked area.
          ctxMask.drawImage(canvas, 0, 0);
          ctxMask.globalCompositeOperation = "source-atop";
          ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
          ctxMask.fillRect(0, 0, canvas.width, canvas.height);
          ctxMask.globalCompositeOperation = "destination-in";
          ctxMask.drawImage(img, 0, 0);
          ctxMask.globalCompositeOperation = "source-over";
          for (const pt of points) {
            // Cyan for mask points, yellow for background points.
            ctxMask.fillStyle = pt[2]
              ? "rgba(0, 255, 255, 1)"
              : "rgba(255, 255, 0, 1)";
            ctxMask.beginPath();
            ctxMask.arc(
              pt[0] * canvas.width,
              pt[1] * canvas.height,
              3,
              0,
              2 * Math.PI
            );
            ctxMask.fill();
          }
          ctxMask.restore();
        };
        img.src = maskURL;
      }

      /**
       * Loads an image and draws it on the main canvas, resizing the canvas
       * to the image and hiding the upload prompt once it is shown.
       *
       * @param {string} imgURL URL of the image.
       * @throws {Error} When no image URL is given.
       */
      function drawImageCanvas(imgURL) {
        if (!imgURL) {
          throw new Error("No image URL provided");
        }

        ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);

        const img = new Image();
        img.crossOrigin = "anonymous";

        img.onload = () => {
          canvas.width = img.width;
          canvas.height = img.height;
          ctxCanvas.drawImage(img, 0, 0);

          canvas.parentElement.style.height = canvas.offsetHeight + "px";
          hasImage = true;
          clearBtn.disabled = false;
          dropButtons.classList.add("invisible");
        };
        img.src = imgURL;
      }

      // The canvases are absolutely positioned, so the drop area's height
      // has to follow the canvas height by hand.
      const observer = new ResizeObserver((entries) => {
        for (const entry of entries) {
          if (entry.target === canvas) {
            canvas.parentElement.style.height = canvas.offsetHeight + "px";
          }
        }
      });
      observer.observe(canvas);
    </script>
  </head>
  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-8 relative">
      <span class="absolute text-5xl -ml-[1em]" aria-hidden="true">🕯️</span>
      <div>
        <h1 class="text-5xl font-bold">Candle Segment Anything</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          Zero-shot image segmentation with
          <a
            class="underline hover:text-blue-500 hover:no-underline"
            href="https://segment-anything.com"
            target="_blank"
            rel="noopener noreferrer"
            >Segment Anything Model (SAM)</a
          >
          and
          <a
            class="underline hover:text-blue-500 hover:no-underline"
            href="https://github.com/ChaoningZhang/MobileSAM"
            target="_blank"
            rel="noopener noreferrer"
            >MobileSAM</a
          >. It runs in the browser with a WASM runtime built with
          <a
            class="underline hover:text-blue-500 hover:no-underline"
            href="https://github.com/huggingface/candle/"
            target="_blank"
            rel="noopener noreferrer"
            >Candle</a
          >.
        </p>
      </div>
      <div>
        <label for="model" class="font-medium">Model Options: </label>
        <select
          class="border-2 border-gray-500 rounded-md font-light"
          id="model">
          <option value="sam_mobile_tiny" selected>
            Mobile SAM Tiny (40.6 MB)
          </option>
          <option value="sam_base">SAM Base (375 MB)</option>
        </select>
      </div>
      <div>
        <p class="text-xs italic max-w-lg">
          <b>Note:</b>
          The model's first run may take a few seconds as it loads and caches
          the model in the browser, and then creates the image embeddings. Any
          subsequent clicks on points will be significantly faster.
        </p>
      </div>
      <div class="relative max-w-2xl">
        <div class="flex justify-between items-center">
          <div class="px-2 rounded-md inline text-xs">
            <span
              class="m-auto font-light"
              id="output-status"
              role="status"></span>
          </div>
          <div class="flex gap-2">
            <button
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center"
              id="mask-btn"
              type="button"
              title="Toggle Mask Point and Background Point">
              <span>Mask Point</span>
              <svg
                xmlns="http://www.w3.org/2000/svg"
                height="1em"
                viewBox="0 0 512 512"
                aria-hidden="true">
                <path
                  id="mask-circle"
                  d="M256 512a256 256 0 1 0 0-512 256 256 0 1 0 0 512z" />
                <path
                  id="unmask-circle"
                  hidden
                  d="M464 256a208 208 0 1 0-416 0 208 208 0 1 0 416 0zM0 256a256 256 0 1 1 512 0 256 256 0 1 1-512 0z" />
              </svg>
            </button>
            <button
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center"
              id="undo-btn"
              type="button"
              title="Undo Last Point"
              aria-label="Undo Last Point"
              disabled>
              <svg
                xmlns="http://www.w3.org/2000/svg"
                height="1em"
                viewBox="0 0 512 512"
                aria-hidden="true">
                <path
                  d="M48.5 224H40a24 24 0 0 1-24-24V72a24 24 0 0 1 41-17l41.6 41.6a224 224 0 1 1-1 317.8 32 32 0 0 1 45.3-45.3 160 160 0 1 0 1-227.3L185 183a24 24 0 0 1-17 41H48.5z" />
              </svg>
            </button>
            <button
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center"
              id="clear-btn"
              type="button"
              title="Clear Image"
              aria-label="Clear Image"
              disabled>
              <svg
                xmlns="http://www.w3.org/2000/svg"
                viewBox="0 0 13 12"
                height="1em"
                aria-hidden="true">
                <path
                  d="M1.6.7 12 11.1M12 .7 1.6 11.1"
                  stroke="#2E3036"
                  stroke-width="2" />
              </svg>
            </button>
          </div>
        </div>
        <div
          class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden"
          id="drop-area">
          <div
            class="flex flex-col items-center justify-center space-y-1 text-center relative z-10"
            id="drop-buttons">
            <svg
              width="25"
              height="25"
              viewBox="0 0 25 25"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
              aria-hidden="true">
              <path
                d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
                fill="#000" />
            </svg>
            <div class="flex text-sm text-gray-600">
              <label
                class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
                for="file-upload">
                <span>Drag and drop your image here</span>
                <span class="block text-xs">or</span>
                <span class="block text-xs">Click to upload</span>
              </label>
            </div>
            <input
              class="sr-only"
              id="file-upload"
              name="file-upload"
              type="file"
              accept="image/*">
          </div>
          <canvas class="absolute w-full" id="canvas"></canvas>
          <canvas
            class="pointer-events-none absolute w-full"
            id="mask"></canvas>
        </div>
        <div class="text-right py-2">
          <button
            class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible"
            id="share-btn"
            type="button">
            <img
              src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
              alt="Share to community">
          </button>
          <button
            class="p-1 px-2 text-xs font-medium bg-white rounded-2xl outline outline-gray-200 hover:outline-orange-200 disabled:opacity-50"
            id="download-btn"
            type="button"
            title="Download result (.png)"
            disabled>
            Download the result (png file)
          </button>
        </div>
      </div>
      <div>
        <div class="flex gap-3 items-center overflow-x-scroll" id="image-select">
          <h3 class="font-medium">Examples:</h3>

          <img
            class="cursor-pointer w-24 h-24 object-cover"
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
            alt="Example image 1">
          <img
            class="cursor-pointer w-24 h-24 object-cover"
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
            alt="Example image 2">
          <img
            class="cursor-pointer w-24 h-24 object-cover"
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
            alt="Example image 3">
        </div>
      </div>
    </main>
  </body>
</html>