<!DOCTYPE html>
<html lang="en">
  <head>
    <!--
      Single document prologue. The doctype has to be the first thing in the
      file so the page is rendered in standards mode; the title lives in this
      head because there is only one document.
    -->
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Candle T5</title>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");

      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
    </style>
    <style type="text/tailwindcss">
      .link {
        @apply underline hover:text-blue-500 hover:no-underline;
      }
    </style>
    <script src="https://cdn.tailwindcss.com"></script>
    <script type="module">
      import {
        getModelInfo,
        MODELS,
        extractEmbeddings,
        generateText,
      } from "./utils.js";

      // Both workers are loaded as ES modules, so they share one options object.
      const moduleWorkerOptions = { type: "module" };
      const t5ModelEncoderWorker = new Worker(
        "./T5ModelEncoderWorker.js",
        moduleWorkerOptions
      );
      const t5ModelConditionalGeneration = new Worker(
        "./T5ModelConditionalGeneration.js",
        moduleWorkerOptions
      );

      // Look up the page controls once; the handlers below reuse these handles.
      const queryOne = (selector) => document.querySelector(selector);
      const formEl = queryOne("#form");
      const modelEl = queryOne("#model");
      const promptEl = queryOne("#prompt");
      const temperatureEl = queryOne("#temperature");
      const toppEL = queryOne("#top-p");
      const repeatPenaltyEl = queryOne("#repeat_penalty");
      const seedEl = queryOne("#seed");
      const outputEl = queryOne("#output-generation");
      const tasksEl = queryOne("#tasks");

      // Key of the task radio that is currently selected for the chosen model.
      let selectedTaskID = "";

      // After the DOM is parsed: fill the model <select> from MODELS, render the
      // task radios for the initially selected model, and wire the two change
      // listeners.
      document.addEventListener("DOMContentLoaded", () => {
        Object.entries(MODELS).forEach(([modelKey, modelMeta]) => {
          const optionEl = document.createElement("option");
          optionEl.value = modelKey;
          optionEl.innerText = `${modelKey} (${modelMeta.size})`;
          modelEl.appendChild(optionEl);
        });

        populateTasks(modelEl.value);

        // Choosing another model re-renders the task list for that model.
        modelEl.addEventListener("change", ({ target }) => {
          populateTasks(target.value);
        });

        // Choosing a task copies its prefix into the prompt box, then records
        // the task as the current selection.
        tasksEl.addEventListener("change", ({ target }) => {
          const chosenTask = target.value;
          const taskTable = MODELS[modelEl.value].tasks;
          promptEl.value = taskTable[chosenTask].prefix;
          selectedTaskID = chosenTask;
        });
      });
      /**
       * Rebuild the task radio list inside #tasks for one model and select the
       * first task.
       *
       * Each task of MODELS[modelID].tasks becomes a radio input (id and value
       * are the task key) followed by a label showing the task's prefix. The
       * module-level selectedTaskID is set to the first task key, or to "" when
       * the model has no tasks.
       *
       * The elements are created through the DOM API rather than an HTML
       * string, so task keys and prefixes are always treated as plain text and
       * the first radio can be checked without a CSS id selector (which would
       * throw for keys that are not valid CSS identifiers).
       *
       * @param {string} modelID key into MODELS
       */
      function populateTasks(modelID) {
        const tasks = MODELS[modelID].tasks;
        const taskIDs = Object.keys(tasks);

        tasksEl.innerHTML = "";
        taskIDs.forEach((task, index) => {
          const input = document.createElement("input");
          input.type = "radio";
          input.name = "task";
          input.id = task;
          input.className = "font-light cursor-pointer";
          input.value = task;
          // Only the first task starts out selected.
          input.checked = index === 0;

          const label = document.createElement("label");
          label.htmlFor = task;
          label.className = "cursor-pointer";
          label.textContent = tasks[task].prefix;

          const div = document.createElement("div");
          // The text node keeps the visual gap between the radio and its label.
          div.append(input, " ", label);
          tasksEl.appendChild(div);
        });

        selectedTaskID = taskIDs[0] ?? "";
      }
      // Run text generation when the prompt form is submitted.
      // Reads the prompt, the selected model/task and the sampling controls,
      // hands them to generateText(), and mirrors progress, the final
      // generation, or a failure message into the output element.
      formEl.addEventListener("submit", (e) => {
        e.preventDefault();

        // BigInt() throws a SyntaxError for strings that are not integers
        // (e.g. "1.5" or "1e3"); report that instead of failing silently.
        let seed;
        try {
          seed = BigInt(seedEl.value);
        } catch {
          outputEl.innerText = "Seed must be a whole number.";
          return;
        }

        const promptText = promptEl.value;
        const modelID = modelEl.value;
        const { modelURL, configURL, tokenizerURL, maxLength } = getModelInfo(
          modelID,
          selectedTaskID
        );
        const params = {
          temperature: Number(temperatureEl.value),
          top_p: Number(toppEL.value),
          repetition_penalty: Number(repeatPenaltyEl.value),
          seed,
          max_length: maxLength,
        };
        generateText(
          t5ModelConditionalGeneration,
          modelURL,
          tokenizerURL,
          configURL,
          modelID,
          promptText,
          params,
          (status) => {
            if (status.status === "loading") {
              outputEl.innerText = "Loading model...";
            }
            if (status.status === "decoding") {
              outputEl.innerText = "Generating...";
            }
          }
        )
          .then(({ output }) => {
            outputEl.innerText = output.generation;
          })
          .catch((err) => {
            // Without this the page would stay on the last status text forever.
            console.error(err);
            outputEl.innerText = `Error: ${err?.message ?? err}`;
          });
      });
    </script>
  </head>

  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-8 relative">
      <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
      <div>
        <h1 class="text-5xl font-bold">Candle T5 Transformer</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          This demo showcases Text-To-Text Transfer Transformer (<a
            href="https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html"
            target="_blank"
            class="link"
            >T5</a
          >) models right in your browser, thanks to
          <a
            href="https://github.com/huggingface/candle/"
            target="_blank"
            class="link">
            Candle
          </a>
          ML framework and rust/wasm. You can choose from a range of available
          models, including
          <a
            href="https://huggingface.co/t5-small"
            target="_blank"
            class="link">
            t5-small</a
          >,
          <a href="https://huggingface.co/t5-base" target="_blank" class="link"
            >t5-base</a
          >,
          <a
            href="https://huggingface.co/google/flan-t5-small"
            target="_blank"
            class="link"
            >flan-t5-small</a
          >,
          several
          <a
            href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
            target="_blank"
            class="link">
            t5 quantized gguf models</a
          >, and also a quantized
          <a
            href="https://huggingface.co/jbochi/candle-coedit-quantized/tree/main"
            target="_blank"
            class="link">
            CoEdIT model for text rewrite</a
          >.
        </p>
      </div>

      <div>
        <label for="model" class="font-medium">Model Options: </label>
        <select
          id="model"
          class="border-2 border-gray-500 rounded-md font-light"></select>
      </div>

      <div>
        <h3 class="font-medium">Task Prefix:</h3>
        <form id="tasks" class="flex flex-col gap-1 my-2"></form>
      </div>
      <form
        id="form"
        class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
        <input type="submit" hidden />
        <input
          type="text"
          id="prompt"
          class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
          placeholder="Add prompt here, e.g. 'translate English to German: Today I'm going to eat Ice Cream'"
          value="translate English to German: Today I'm going to eat Ice Cream" />
        <button
          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
          Run
        </button>
      </form>
      <div class="grid grid-cols-3 max-w-md items-center gap-3">
        <label class="text-sm font-medium" for="temperature">Temperature</label>
        <input
          type="range"
          id="temperature"
          name="temperature"
          min="0"
          max="2"
          step="0.01"
          value="0.00"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
          0.00</output
        >
        <label class="text-sm font-medium" for="top-p">Top-p</label>
        <input
          type="range"
          id="top-p"
          name="top-p"
          min="0"
          max="1"
          step="0.01"
          value="1.00"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
          1.00</output
        >

        <label class="text-sm font-medium" for="repeat_penalty"
          >Repeat Penalty</label
        >

        <input
          type="range"
          id="repeat_penalty"
          name="repeat_penalty"
          min="1"
          max="2"
          step="0.01"
          value="1.10"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
          >1.10</output
        >
        <label class="text-sm font-medium" for="seed">Seed</label>
        <input
          type="number"
          id="seed"
          name="seed"
          value="299792458"
          class="font-light border border-gray-700 text-right rounded-md p-2" />
        <!--
          Fills the seed field with a random non-negative integer below 2 ** 64.
          The multiplication is floored directly: subtracting 1 afterwards
          would produce a negative seed whenever Math.random() returns 0.
        -->
        <button
          id="run"
          type="button"
          onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2 ** 64))"
          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
          Rand
        </button>
      </div>
      <div>
        <h3 class="font-medium">Generation:</h3>
        <div
          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2 text-lg">
          <p id="output-generation" class="grid-rows-2">No output yet</p>
        </div>
      </div>
    </main>
  </body>
</html>