<!DOCTYPE html>
<!-- Single document shell: the doctype must come first so the page renders in
     standards mode; the title from the former stray shell is kept here. -->
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Candle Segment Anything Model (SAM) Rust/WASM</title>
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
    </style>
    <script src="https://cdn.tailwindcss.com/3.4.3"></script>
    <script type="module">
      // Base URL that the model weight file names below are appended to
      const MODEL_BASEURL =
        "https://huggingface.co/lmz/candle-sam/resolve/main/";

      // Available models, keyed by the value of the matching <option> in the
      // #model select; `url` is the weights file name relative to MODEL_BASEURL
      const MODELS = {
        sam_mobile_tiny: {
          url: "mobile_sam-tiny-vitt.safetensors",
        },
        sam_base: {
          url: "sam_vit_b_01ec64.safetensors",
        },
      };
      // Module worker that segmentPoints() posts its requests to
      const samWorker = new Worker("./samWorker.js", { type: "module" });

      /**
       * Posts one request to the SAM worker and settles when the worker
       * reports a terminal message.
       *
       * Every message carrying a `status` field is forwarded to updateStatus.
       * The promise rejects when a message carries an `error` field, resolves
       * with no value on status "complete-embedding", and resolves with the
       * message's `output` on status "complete".
       *
       * @param modelURL URL to the weights file
       * @param modelID  model ID
       * @param imageURL URL to the image file
       * @param points   prompt points; omitted when only embeddings are wanted
       */
      async function segmentPoints(modelURL, modelID, imageURL, points) {
        return new Promise((resolve, reject) => {
          // Detach this request's listener once a terminal message arrived.
          const stopListening = () =>
            samWorker.removeEventListener("message", onWorkerMessage);

          function onWorkerMessage({ data }) {
            console.log(data);
            if ("status" in data) {
              updateStatus(data);
            }
            if ("error" in data) {
              stopListening();
              reject(new Error(data.error));
            }
            switch (data.status) {
              case "complete-embedding":
                stopListening();
                resolve();
                break;
              case "complete":
                stopListening();
                resolve(data.output);
                break;
            }
          }

          samWorker.addEventListener("message", onWorkerMessage);
          const request = { modelURL, modelID, imageURL, points };
          samWorker.postMessage(request);
        });
      }
      /**
       * Shows the worker's latest progress text in the status line.
       *
       * Reads the message from the argument it is given; the previous version
       * ignored the parameter and relied on the implicit global `event`.
       *
       * @param statusMessage worker message object; its `message` text is shown
       */
      function updateStatus(statusMessage) {
        // Fall back to an empty line instead of printing "undefined".
        statusOutput.innerText = statusMessage.message || "";
      }

      // URLs of the latest mask and of the current image; the download handler
      // loads both to build the cutout. Both start out as null.
      let copyMaskURL = null;
      let copyImageURL = null;
      // Toolbar and download buttons
      const clearBtn = document.querySelector("#clear-btn");
      const maskBtn = document.querySelector("#mask-btn");
      const undoBtn = document.querySelector("#undo-btn");
      const downloadBtn = document.querySelector("#download-btn");
      // Two stacked canvases: `canvas` holds the image, `mask` the overlay
      const canvas = document.querySelector("#canvas");
      const mask = document.querySelector("#mask");
      const ctxCanvas = canvas.getContext("2d");
      const ctxMask = mask.getContext("2d");
      // Image input: file picker, drop target and the example thumbnails
      const fileUpload = document.querySelector("#file-upload");
      const dropArea = document.querySelector("#drop-area");
      const dropButtons = document.querySelector("#drop-buttons");
      const imagesExamples = document.querySelector("#image-select");
      // Model <select> and the element that displays worker status text
      const modelSelection = document.querySelector("#model");
      const statusOutput = document.querySelector("#output-status");

      // Load the image chosen through the file input
      fileUpload.addEventListener("input", ({ target: picker }) => {
        if (picker.files.length === 0) {
          return;
        }
        // Expose the picked file through an object URL so it can be drawn
        // and handed to the worker like any other image URL.
        const objectURL = URL.createObjectURL(picker.files[0]);
        clearImageCanvas();
        copyImageURL = objectURL;
        drawImageCanvas(objectURL);
        setImageEmbeddings(objectURL);
        togglePointMode(false);
      });
      // Drag and drop: highlight the drop area while something hovers over it
      dropArea.addEventListener("dragenter", (dragEvent) => {
        dragEvent.preventDefault();
        dropArea.classList.add("border-blue-700");
      });
      dropArea.addEventListener("dragleave", (dragEvent) => {
        dragEvent.preventDefault();
        dropArea.classList.remove("border-blue-700");
      });
      // Cancelling dragover is what allows the drop event to fire
      dropArea.addEventListener("dragover", (dragEvent) => {
        dragEvent.preventDefault();
      });
      // Drop: a dropped file wins over a dropped URL; anything else is ignored
      dropArea.addEventListener("drop", (dropEvent) => {
        dropEvent.preventDefault();
        dropArea.classList.remove("border-blue-700");
        const droppedURL = dropEvent.dataTransfer.getData("text/uri-list");
        const droppedFiles = dropEvent.dataTransfer.files;

        let source;
        if (droppedFiles.length > 0) {
          source = URL.createObjectURL(droppedFiles[0]);
        } else if (droppedURL) {
          source = droppedURL;
        } else {
          return;
        }

        clearImageCanvas();
        copyImageURL = source;
        drawImageCanvas(source);
        setImageEmbeddings(source);
        togglePointMode(false);
      });

      // Interaction state shared by the handlers below
      let hasImage = false; // true once drawImageCanvas has painted an image
      let isSegmenting = false; // true while a mask request is awaited
      let isEmbedding = false; // true while image embeddings are awaited
      let currentImageURL = ""; // image URL sent with mask requests; set after embedding
      let pointArr = []; // prompt points as [x, y, isMaskPoint], x/y relative to the canvas box
      let bgPointMode = false; // true when new points are background points
      // Clicking an example thumbnail loads it, unless a request is in flight
      imagesExamples.addEventListener("click", ({ target: clicked }) => {
        const busy = isEmbedding || isSegmenting;
        // The listener sits on the container, so only react to the images.
        if (busy || clicked.nodeName !== "IMG") {
          return;
        }
        const exampleURL = clicked.src;
        clearImageCanvas();
        drawImageCanvas(exampleURL);
        setImageEmbeddings(exampleURL);
        copyImageURL = exampleURL;
      });
      // Mask button: flip between mask points and background points
      maskBtn.addEventListener("click", () => togglePointMode());
      // Clear button: remove the image and go back to mask-point mode
      clearBtn.addEventListener("click", () => {
        clearImageCanvas();
        togglePointMode(false);
        pointArr = [];
      });
      // Undo button: remove the most recently added point
      undoBtn.addEventListener("click", () => undoPoint());
      // Download button: cut the masked region out of the original image and
      // offer it as "cutout.png".
      downloadBtn.addEventListener("click", async () => {
        // Nothing to export until both an image and a mask are known.
        if (!copyImageURL || !copyMaskURL) {
          return;
        }

        // Load a URL into an Image element; rejects instead of hanging forever
        // when the image cannot be loaded.
        const loadImageAsync = (imageURL) =>
          new Promise((resolve, reject) => {
            const img = new Image();
            img.onload = () => resolve(img);
            img.onerror = () =>
              reject(new Error(`Failed to load image: ${imageURL}`));
            // Request with CORS so the canvas stays exportable.
            img.crossOrigin = "anonymous";
            img.src = imageURL;
          });

        try {
          const [originalImage, maskImage] = await Promise.all([
            loadImageAsync(copyImageURL),
            loadImageAsync(copyMaskURL),
          ]);

          // Off-screen board sized like the original image (named so it does
          // not shadow the on-page `canvas`).
          const cutoutCanvas = document.createElement("canvas");
          const ctx = cutoutCanvas.getContext("2d");
          cutoutCanvas.width = originalImage.width;
          cutoutCanvas.height = originalImage.height;

          // Draw the mask first, then keep only the image pixels it covers.
          ctx.drawImage(maskImage, 0, 0);
          ctx.globalCompositeOperation = "source-in";
          ctx.drawImage(originalImage, 0, 0);

          const blob = await new Promise((resolve) => {
            cutoutCanvas.toBlob(resolve);
          });
          if (!blob) {
            // toBlob passes null when the image could not be encoded.
            throw new Error("Could not encode the cutout image");
          }
          const resultURL = URL.createObjectURL(blob);

          // Trigger the download through a temporary link.
          const link = document.createElement("a");
          link.href = resultURL;
          link.download = "cutout.png";
          link.click();
          // Release the blob once the browser had time to start the download.
          setTimeout(() => URL.revokeObjectURL(resultURL), 1000);
        } catch (err) {
          console.error("Download failed:", err);
        }
      });
      // Canvas click: add a prompt point (or remove the points under the
      // cursor) and request a fresh mask for the resulting point list.
      canvas.addEventListener("click", async (event) => {
        if (!hasImage || isEmbedding || isSegmenting) {
          return;
        }
        // Holding Shift inverts the selected point mode for this click only.
        const backgroundMode = event.shiftKey ? !bgPointMode : bgPointMode;
        // Click position relative to the displayed canvas box, each in 0..1.
        const targetBox = event.target.getBoundingClientRect();
        const x = (event.clientX - targetBox.left) / targetBox.width;
        const y = (event.clientY - targetBox.top) / targetBox.height;

        // Points closer than this (in relative coordinates) count as "hit"
        // and are removed instead of adding a new point.
        const hitRadius = 6 / targetBox.width;
        const remaining = pointArr.filter(
          (pts) => Math.sqrt((pts[0] - x) ** 2 + (pts[1] - y) ** 2) >= hitRadius
        );
        if (remaining.length < pointArr.length) {
          pointArr = remaining;
        } else {
          pointArr = [...pointArr, [x, y, !backgroundMode]];
        }

        const hasPoints = pointArr.length > 0;
        undoBtn.disabled = !hasPoints;
        downloadBtn.disabled = !hasPoints;
        if (!hasPoints) {
          ctxMask.clearRect(0, 0, canvas.width, canvas.height);
          return;
        }

        isSegmenting = true;
        try {
          const { maskURL } = await getSegmentationMask(pointArr);
          copyMaskURL = maskURL;
          drawMask(maskURL, pointArr);
        } catch (err) {
          console.error("Segmentation failed:", err);
        } finally {
          // Always release the flag; otherwise one failed request would block
          // every later click and undo.
          isSegmenting = false;
        }
      });

      /**
       * Removes the most recently added prompt point and refreshes the mask.
       *
       * Does nothing while no image is shown, while a request is in flight, or
       * when there are no points. When the last point is removed the mask
       * overlay is cleared and both the undo and download buttons are
       * disabled, matching what the canvas click handler does for an empty
       * point list. Rejects if the mask request fails.
       */
      async function undoPoint() {
        if (!hasImage || isEmbedding || isSegmenting) {
          return;
        }
        if (pointArr.length === 0) {
          return;
        }
        pointArr.pop();
        if (pointArr.length === 0) {
          ctxMask.clearRect(0, 0, canvas.width, canvas.height);
          undoBtn.disabled = true;
          // No points means no mask, so there is nothing left to download.
          downloadBtn.disabled = true;
          return;
        }
        isSegmenting = true;
        try {
          const { maskURL } = await getSegmentationMask(pointArr);
          copyMaskURL = maskURL;
          drawMask(maskURL, pointArr);
        } finally {
          // Release the flag even on failure so the UI does not stay locked.
          isSegmenting = false;
        }
      }
      /**
       * Switches between mask-point and background-point mode and updates the
       * mask button's label and icon to match.
       *
       * @param mode true for background points, false for mask points;
       *             when omitted the current mode is flipped
       */
      function togglePointMode(mode) {
        if (mode === undefined) {
          bgPointMode = !bgPointMode;
        } else {
          bgPointMode = mode;
        }

        const label = maskBtn.querySelector("span");
        label.innerText = bgPointMode ? "Background Point" : "Mask Point";

        // Show exactly one of the two icon shapes for the active mode.
        const maskCircle = maskBtn.querySelector("#mask-circle");
        const unmaskCircle = maskBtn.querySelector("#unmask-circle");
        const [shown, concealed] = bgPointMode
          ? [unmaskCircle, maskCircle]
          : [maskCircle, unmaskCircle];
        concealed.setAttribute("hidden", "");
        shown.removeAttribute("hidden");
      }

      /**
       * Requests a mask for the given prompt points, using the model chosen in
       * the #model select and the image URL stored in currentImageURL.
       *
       * @param points prompt points forwarded unchanged to segmentPoints
       * @returns object with the `maskURL` taken from the worker's output
       */
      async function getSegmentationMask(points) {
        const selectedModel = modelSelection.value;
        const weightsURL = MODEL_BASEURL + MODELS[selectedModel].url;
        const workerOutput = await segmentPoints(
          weightsURL,
          selectedModel,
          currentImageURL,
          points
        );
        const { maskURL } = workerOutput;
        return { maskURL };
      }
      /**
       * Asks the worker to compute embeddings for an image and, on success,
       * records the image as the one later mask requests refer to.
       *
       * While the request runs the canvas shows a wait cursor and the clear
       * button is disabled. That state is restored whether the request
       * succeeds or fails, so a failed request no longer leaves the UI stuck
       * in "embedding" mode. Returns immediately if an embedding request is
       * already running. Rejects if the worker reports an error.
       *
       * @param imageURL URL of the image to embed
       */
      async function setImageEmbeddings(imageURL) {
        if (isEmbedding) {
          return;
        }
        canvas.classList.remove("cursor-pointer");
        canvas.classList.add("cursor-wait");
        clearBtn.disabled = true;
        const modelID = modelSelection.value;
        const modelURL = MODEL_BASEURL + MODELS[modelID].url;
        isEmbedding = true;
        try {
          // No points are passed: this request only prepares the embeddings.
          await segmentPoints(modelURL, modelID, imageURL);
          currentImageURL = imageURL;
        } finally {
          canvas.classList.remove("cursor-wait");
          canvas.classList.add("cursor-pointer");
          clearBtn.disabled = false;
          isEmbedding = false;
        }
      }

      /**
       * Resets the viewer to its empty state: wipes the image and mask layers,
       * forgets all prompt points and the current image, and shows the
       * upload prompt again.
       *
       * The undo and download buttons are disabled and the remembered mask URL
       * is dropped as well, since neither points nor a mask exist afterwards.
       */
      function clearImageCanvas() {
        ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
        // Clear the overlay using its own size.
        ctxMask.clearRect(0, 0, mask.width, mask.height);
        hasImage = false;
        isEmbedding = false;
        isSegmenting = false;
        currentImageURL = "";
        pointArr = [];
        copyMaskURL = null;
        clearBtn.disabled = true;
        undoBtn.disabled = true;
        downloadBtn.disabled = true;
        canvas.parentElement.style.height = "auto";
        dropButtons.classList.remove("invisible");
      }
      /**
       * Paints the mask overlay: the image tinted red wherever the mask image
       * is opaque, plus one dot per prompt point.
       *
       * @param maskURL URL of the mask image; must be non-empty
       * @param points  prompt points as [x, y, flag] with x/y relative to the
       *                canvas; a truthy flag is drawn cyan, otherwise yellow
       * @throws Error when no mask URL is given
       */
      function drawMask(maskURL, points) {
        if (!maskURL) {
          throw new Error("No mask URL provided");
        }

        const maskImage = new Image();
        maskImage.crossOrigin = "anonymous";

        maskImage.onload = () => {
          const { width, height } = canvas;
          // Match the overlay to the image canvas.
          mask.width = width;
          mask.height = height;
          ctxMask.save();

          // Copy the image, tint it red, then keep only the masked pixels.
          ctxMask.drawImage(canvas, 0, 0);
          ctxMask.globalCompositeOperation = "source-atop";
          ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
          ctxMask.fillRect(0, 0, width, height);
          ctxMask.globalCompositeOperation = "destination-in";
          ctxMask.drawImage(maskImage, 0, 0);

          // Draw the prompt points on top of the tinted mask.
          ctxMask.globalCompositeOperation = "source-over";
          points.forEach((point) => {
            ctxMask.fillStyle = point[2]
              ? "rgba(0, 255, 255, 1)"
              : "rgba(255, 255, 0, 1)";
            ctxMask.beginPath();
            ctxMask.arc(point[0] * width, point[1] * height, 3, 0, 2 * Math.PI);
            ctxMask.fill();
          });
          ctxMask.restore();
        };
        maskImage.src = maskURL;
      }
      /**
       * Loads an image and paints it on the main canvas.
       *
       * Both layers are cleared first. The original cleared the image layer
       * twice; the second call now clears the mask layer. Once the image has
       * loaded, the canvas takes the image's size, the drop area is resized to
       * the displayed canvas height, `hasImage` is set, the clear button is
       * enabled and the upload prompt is hidden. A failed load is logged.
       *
       * @param imgURL URL of the image to display; must be non-empty
       * @throws Error when no image URL is given
       */
      function drawImageCanvas(imgURL) {
        if (!imgURL) {
          throw new Error("No image URL provided");
        }

        ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
        ctxMask.clearRect(0, 0, mask.width, mask.height);

        const img = new Image();
        // Request with CORS so the canvas can be read back later.
        img.crossOrigin = "anonymous";

        img.onload = () => {
          canvas.width = img.width;
          canvas.height = img.height;
          ctxCanvas.drawImage(img, 0, 0);
          canvas.parentElement.style.height = canvas.offsetHeight + "px";
          hasImage = true;
          clearBtn.disabled = false;
          dropButtons.classList.add("invisible");
        };
        img.onerror = () => {
          console.error(`Failed to load image: ${imgURL}`);
        };
        img.src = imgURL;
      }

      // Whenever the canvas is resized, give its parent the same height as the
      // displayed canvas.
      const observer = new ResizeObserver((entries) => {
        entries
          .filter(({ target }) => target === canvas)
          .forEach(() => {
            canvas.parentElement.style.height = `${canvas.offsetHeight}px`;
          });
      });
      observer.observe(canvas);
    </script>
  </head>
  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-8 relative">
      <!-- Page header; the candle emoji is decorative and hidden from assistive technology -->
      <span class="absolute text-5xl -ml-[1em]" aria-hidden="true">🕯️</span>
      <div>
        <h1 class="text-5xl font-bold">Candle Segment Anything</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          Zero-shot image segmentation with
          <a
            href="https://segment-anything.com"
            class="underline hover:text-blue-500 hover:no-underline"
            target="_blank"
            rel="noopener noreferrer"
            >Segment Anything Model (SAM)</a
          >
          and
          <a
            href="https://github.com/ChaoningZhang/MobileSAM"
            class="underline hover:text-blue-500 hover:no-underline"
            target="_blank"
            rel="noopener noreferrer"
            >MobileSAM </a
          >. It runs in the browser with a WASM runtime built with
          <a
            href="https://github.com/huggingface/candle/"
            target="_blank"
            rel="noopener noreferrer"
            class="underline hover:text-blue-500 hover:no-underline"
            >Candle
          </a>
        </p>
      </div>
      <!-- Model picker: the script uses the selected option's value as a key
           into its MODELS table, so each value must match a MODELS key -->
      <div>
        <label for="model" class="font-medium">Models Options: </label>
        <select
          id="model"
          class="border-2 border-gray-500 rounded-md font-light">
          <option value="sam_mobile_tiny" selected>
            Mobile SAM Tiny (40.6 MB)
          </option>
          <option value="sam_base">SAM Base (375 MB)</option>
        </select>
      </div>
      <div>
        <p class="text-xs italic max-w-lg">
          <b>Note:</b>
          The model's first run may take a few seconds as it loads and caches
          the model in the browser, and then creates the image embeddings. Any
          subsequent clicks on points will be significantly faster.
        </p>
      </div>
      <div class="relative max-w-2xl">
        <!-- Status line plus toolbar. The script looks up #output-status,
             #mask-btn (its first <span>, #mask-circle, #unmask-circle),
             #undo-btn and #clear-btn, so those hooks must stay as they are. -->
        <div class="flex justify-between items-center">
          <div class="px-2 rounded-md inline text-xs">
            <!-- role="status" makes worker progress updates a polite live region -->
            <span id="output-status" class="m-auto font-light" role="status"></span>
          </div>
          <div class="flex gap-2">
            <button
              id="mask-btn"
              type="button"
              title="Toggle Mask Point and Background Point"
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
              <span>Mask Point</span>
              <svg
                xmlns="http://www.w3.org/2000/svg"
                height="1em"
                viewBox="0 0 512 512"
                aria-hidden="true">
                <path
                  id="mask-circle"
                  d="M256 512a256 256 0 1 0 0-512 256 256 0 1 0 0 512z" />
                <path
                  id="unmask-circle"
                  hidden
                  d="M464 256a208 208 0 1 0-416 0 208 208 0 1 0 416 0zM0 256a256 256 0 1 1 512 0 256 256 0 1 1-512 0z" />
              </svg>
            </button>
            <button
              id="undo-btn"
              type="button"
              disabled
              title="Undo Last Point"
              aria-label="Undo Last Point"
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
              <svg
                xmlns="http://www.w3.org/2000/svg"
                height="1em"
                viewBox="0 0 512 512"
                aria-hidden="true">
                <path
                  d="M48.5 224H40a24 24 0 0 1-24-24V72a24 24 0 0 1 41-17l41.6 41.6a224 224 0 1 1-1 317.8 32 32 0 0 1 45.3-45.3 160 160 0 1 0 1-227.3L185 183a24 24 0 0 1-17 41H48.5z" />
              </svg>
            </button>

            <button
              id="clear-btn"
              type="button"
              disabled
              title="Clear Image"
              aria-label="Clear Image"
              class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center">
              <svg
                xmlns="http://www.w3.org/2000/svg"
                viewBox="0 0 13 12"
                height="1em"
                aria-hidden="true">
                <path
                  d="M1.6.7 12 11.1M12 .7 1.6 11.1"
                  stroke="#2E3036"
                  stroke-width="2" />
              </svg>
            </button>
          </div>
        </div>
        <!-- Drop target with the upload prompt and the two stacked canvases:
             #canvas shows the image and receives clicks, #mask is the overlay
             drawn on top and ignores pointer events. -->
        <div
          id="drop-area"
          class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden">
          <div
            id="drop-buttons"
            class="flex flex-col items-center justify-center space-y-1 text-center relative z-10">
            <svg
              width="25"
              height="25"
              viewBox="0 0 25 25"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
              aria-hidden="true">
              <path
                d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
                fill="#000" />
            </svg>
            <div class="flex text-sm text-gray-600">
              <label
                for="file-upload"
                class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
                <span>Drag and drop your image here</span>
                <span class="block text-xs">or</span>
                <span class="block text-xs">Click to upload</span>
              </label>
            </div>
            <!-- accept limits the file picker to images -->
            <input
              id="file-upload"
              name="file-upload"
              type="file"
              accept="image/*"
              class="sr-only" />
          </div>
          <canvas id="canvas" class="absolute w-full"></canvas>
          <canvas
            id="mask"
            class="pointer-events-none absolute w-full"></canvas>
        </div>
        <!-- Result actions. #share-btn is kept invisible; #download-btn is
             enabled by the script once prompt points exist. -->
        <div class="text-right py-2">
          <button
            id="share-btn"
            type="button"
            class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible">
            <img
              src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
              alt="Share to community" />
          </button>

          <button
            id="download-btn"
            type="button"
            title="Download result (.png)"
            disabled
            class="p-1 px-2 text-xs font-medium bg-white rounded-2xl outline outline-gray-200 hover:outline-orange-200 disabled:opacity-50"
          >
            Download the result (png file)
          </button>
        </div>
      </div>
      <!-- Example thumbnails. The script listens for clicks on #image-select
           and loads the src of whichever <img> was clicked, so the examples
           must remain <img> elements inside this container. -->
      <div>
        <div
          class="flex gap-3 items-center overflow-x-scroll"
          id="image-select">
          <h3 class="font-medium">Examples:</h3>

          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
            alt="Load example image 1"
            class="cursor-pointer w-24 h-24 object-cover" />
          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
            alt="Load example image 2"
            class="cursor-pointer w-24 h-24 object-cover" />
          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
            alt="Load example image 3"
            class="cursor-pointer w-24 h-24 object-cover" />
        </div>
      </div>
    </main>
  </body>
</html>