Spaces:
Running
Running
| import { useEffect, useState, useRef } from "react"; | |
| import Chat from "./components/Chat"; | |
| import ArrowRightIcon from "./components/icons/ArrowRightIcon"; | |
| import StopIcon from "./components/icons/StopIcon"; | |
| import Progress from "./components/Progress"; | |
| import LightBulbIcon from "./components/icons/LightBulbIcon"; | |
// Whether the WebGPU entry point (`navigator.gpu`) is exposed by this environment.
// The `typeof` guard keeps the module importable where `navigator` does not exist
// at all (e.g. server-side rendering or test runners), instead of throwing.
const IS_WEBGPU_AVAILABLE =
  typeof navigator !== "undefined" && Boolean(navigator.gpu);

// Maximum distance (in pixels) from the bottom of the chat container at which
// new content still triggers an automatic scroll to the bottom.
const STICKY_SCROLL_THRESHOLD = 120;

// Example prompts offered as one-click starters while the conversation is empty.
const EXAMPLES = [
  "Solve the equation x^2 - 3x + 2 = 0",
  "How do you say 'I love you' in French, Spanish, and German? Respond in a table.",
  "Explain the concept of gravity in simple terms.",
];
| function App() { | |
| // Create a reference to the worker object. | |
| const worker = useRef(null); | |
| const textareaRef = useRef(null); | |
| const chatContainerRef = useRef(null); | |
| // Model loading and progress | |
| const [status, setStatus] = useState(null); | |
| const [error, setError] = useState(null); | |
| const [loadingMessage, setLoadingMessage] = useState(""); | |
| const [progressItems, setProgressItems] = useState([]); | |
| const [isRunning, setIsRunning] = useState(false); | |
| // Inputs and outputs | |
| const [input, setInput] = useState(""); | |
| const [messages, setMessages] = useState([]); | |
| const [tps, setTps] = useState(null); | |
| const [numTokens, setNumTokens] = useState(null); | |
| const [reasonEnabled, setReasonEnabled] = useState(false); | |
| function onEnter(message) { | |
| setMessages((prev) => [...prev, { role: "user", content: message }]); | |
| setTps(null); | |
| setIsRunning(true); | |
| setInput(""); | |
| } | |
| function onInterrupt() { | |
| // NOTE: We do not set isRunning to false here because the worker | |
| // will send a 'complete' message when it is done. | |
| worker.current.postMessage({ type: "interrupt" }); | |
| } | |
| useEffect(() => { | |
| resizeInput(); | |
| }, [input]); | |
| function resizeInput() { | |
| if (!textareaRef.current) return; | |
| const target = textareaRef.current; | |
| target.style.height = "auto"; | |
| const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200); | |
| target.style.height = `${newHeight}px`; | |
| } | |
| // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted. | |
| useEffect(() => { | |
| // Create the worker if it does not yet exist. | |
| if (!worker.current) { | |
| worker.current = new Worker(new URL("./worker.js", import.meta.url), { | |
| type: "module", | |
| }); | |
| worker.current.postMessage({ type: "check" }); // Do a feature check | |
| } | |
| // Create a callback function for messages from the worker thread. | |
| const onMessageReceived = (e) => { | |
| switch (e.data.status) { | |
| case "loading": | |
| // Model file start load: add a new progress item to the list. | |
| setStatus("loading"); | |
| setLoadingMessage(e.data.data); | |
| break; | |
| case "initiate": | |
| setProgressItems((prev) => [...prev, e.data]); | |
| break; | |
| case "progress": | |
| // Model file progress: update one of the progress items. | |
| setProgressItems((prev) => | |
| prev.map((item) => { | |
| if (item.file === e.data.file) { | |
| return { ...item, ...e.data }; | |
| } | |
| return item; | |
| }), | |
| ); | |
| break; | |
| case "done": | |
| // Model file loaded: remove the progress item from the list. | |
| setProgressItems((prev) => | |
| prev.filter((item) => item.file !== e.data.file), | |
| ); | |
| break; | |
| case "ready": | |
| // Pipeline ready: the worker is ready to accept messages. | |
| setStatus("ready"); | |
| break; | |
| case "start": | |
| { | |
| // Start generation | |
| setMessages((prev) => [ | |
| ...prev, | |
| { role: "assistant", content: "" }, | |
| ]); | |
| } | |
| break; | |
| case "update": | |
| { | |
| // Generation update: update the output text. | |
| // Parse messages | |
| const { output, tps, numTokens, state } = e.data; | |
| setTps(tps); | |
| setNumTokens(numTokens); | |
| setMessages((prev) => { | |
| const cloned = [...prev]; | |
| const last = cloned.at(-1); | |
| const data = { | |
| ...last, | |
| content: last.content + output, | |
| }; | |
| if (data.answerIndex === undefined && state === "answering") { | |
| // When state changes to answering, we set the answerIndex | |
| data.answerIndex = last.content.length; | |
| } | |
| cloned[cloned.length - 1] = data; | |
| return cloned; | |
| }); | |
| } | |
| break; | |
| case "complete": | |
| // Generation complete: re-enable the "Generate" button | |
| setIsRunning(false); | |
| break; | |
| case "error": | |
| setError(e.data.data); | |
| break; | |
| } | |
| }; | |
| const onErrorReceived = (e) => { | |
| console.error("Worker error:", e); | |
| }; | |
| // Attach the callback function as an event listener. | |
| worker.current.addEventListener("message", onMessageReceived); | |
| worker.current.addEventListener("error", onErrorReceived); | |
| // Define a cleanup function for when the component is unmounted. | |
| return () => { | |
| worker.current.removeEventListener("message", onMessageReceived); | |
| worker.current.removeEventListener("error", onErrorReceived); | |
| }; | |
| }, []); | |
| // Send the messages to the worker thread whenever the `messages` state changes. | |
| useEffect(() => { | |
| if (messages.filter((x) => x.role === "user").length === 0) { | |
| // No user messages yet: do nothing. | |
| return; | |
| } | |
| if (messages.at(-1).role === "assistant") { | |
| // Do not update if the last message is from the assistant | |
| return; | |
| } | |
| setTps(null); | |
| worker.current.postMessage({ | |
| type: "generate", | |
| data: { messages, reasonEnabled }, | |
| }); | |
| }, [messages, isRunning]); | |
| useEffect(() => { | |
| if (!chatContainerRef.current) return; | |
| const element = chatContainerRef.current; | |
| if ( | |
| element.scrollHeight - element.scrollTop - element.clientHeight < | |
| STICKY_SCROLL_THRESHOLD | |
| ) { | |
| element.scrollTop = element.scrollHeight; | |
| } | |
| }, [messages, isRunning]); | |
| return IS_WEBGPU_AVAILABLE ? ( | |
| <div className="flex flex-col h-screen mx-auto items justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900"> | |
| {status === null && messages.length === 0 && ( | |
| <div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative"> | |
| <div className="flex flex-col items-center mb-1 max-w-[360px] text-center"> | |
| <img | |
| src="logo.png" | |
| width="80%" | |
| height="auto" | |
| className="block drop-shadow-lg bg-transparent" | |
| ></img> | |
| <h1 className="text-4xl font-bold my-1">SmolLM3 WebGPU</h1> | |
| <h2 className="font-semibold"> | |
| A dual reasoning model that runs locally in <br /> | |
| your browser with WebGPU acceleration. | |
| </h2> | |
| </div> | |
| <div className="flex flex-col items-center px-4"> | |
| <p className="max-w-[480px] mb-4"> | |
| <br /> | |
| You are about to load{" "} | |
| <a | |
| href="https://huggingface.co/HuggingFaceTB/SmolLM3-3B-ONNX" | |
| target="_blank" | |
| rel="noreferrer" | |
| className="font-medium underline" | |
| > | |
| SmolLM3-3B | |
| </a> | |
| , a 3B parameter reasoning LLM optimized for in-browser inference. | |
| Everything runs entirely in your browser with{" "} | |
| <a | |
| href="https://huggingface.co/docs/transformers.js" | |
| target="_blank" | |
| rel="noreferrer" | |
| className="underline" | |
| > | |
| 🤗 Transformers.js | |
| </a>{" "} | |
| and ONNX Runtime Web, meaning no data is sent to a server. Once | |
| loaded, it can even be used offline. The source code for the demo | |
| is available on{" "} | |
| <a | |
| href="https://github.com/huggingface/transformers.js-examples/tree/main/smollm3-webgpu" | |
| target="_blank" | |
| rel="noreferrer" | |
| className="font-medium underline" | |
| > | |
| GitHub | |
| </a> | |
| . | |
| </p> | |
| {error && ( | |
| <div className="text-red-500 text-center mb-2"> | |
| <p className="mb-1"> | |
| Unable to load model due to the following error: | |
| </p> | |
| <p className="text-sm">{error}</p> | |
| </div> | |
| )} | |
| <button | |
| className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 cursor-pointer disabled:cursor-not-allowed select-none" | |
| onClick={() => { | |
| worker.current.postMessage({ type: "load" }); | |
| setStatus("loading"); | |
| }} | |
| disabled={status !== null || error !== null} | |
| > | |
| Load model | |
| </button> | |
| </div> | |
| </div> | |
| )} | |
| {status === "loading" && ( | |
| <> | |
| <div className="w-full max-w-[500px] text-left mx-auto p-4 bottom-0 mt-auto"> | |
| <p className="text-center mb-1">{loadingMessage}</p> | |
| {progressItems.map(({ file, progress, total }, i) => ( | |
| <Progress | |
| key={i} | |
| text={file} | |
| percentage={progress} | |
| total={total} | |
| /> | |
| ))} | |
| </div> | |
| </> | |
| )} | |
| {status === "ready" && ( | |
| <> | |
| <div | |
| ref={chatContainerRef} | |
| className="overflow-y-auto scrollbar-thin w-full flex flex-col items-center h-full" | |
| > | |
| <Chat messages={messages} /> | |
| {messages.length === 0 && ( | |
| <div> | |
| {EXAMPLES.map((msg, i) => ( | |
| <div | |
| key={i} | |
| className="m-1 border border-gray-300 dark:border-gray-600 rounded-md p-2 bg-gray-100 dark:bg-gray-700 cursor-pointer max-w-[500px]" | |
| onClick={() => onEnter(msg)} | |
| > | |
| {msg} | |
| </div> | |
| ))} | |
| </div> | |
| )} | |
| </div> | |
| <p className="text-center text-sm min-h-6 text-gray-500 dark:text-gray-300 mt-2 mb-1"> | |
| {tps && messages.length > 0 && ( | |
| <> | |
| {!isRunning && ( | |
| <span> | |
| Generated {numTokens} tokens in{" "} | |
| {(numTokens / tps).toFixed(2)} seconds ( | |
| </span> | |
| )} | |
| { | |
| <> | |
| <span className="font-medium text-center mr-1 text-black dark:text-white"> | |
| {tps.toFixed(2)} | |
| </span> | |
| <span className="text-gray-500 dark:text-gray-300"> | |
| tokens/second | |
| </span> | |
| </> | |
| } | |
| {!isRunning && ( | |
| <> | |
| <span className="mr-1">).</span> | |
| <span | |
| className="underline cursor-pointer" | |
| onClick={() => { | |
| worker.current.postMessage({ type: "reset" }); | |
| setMessages([]); | |
| }} | |
| > | |
| Reset | |
| </span> | |
| </> | |
| )} | |
| </> | |
| )} | |
| </p> | |
| </> | |
| )} | |
| <div className="w-[600px] max-w-[80%] mx-auto mt-2 mb-3"> | |
| <div className="border border-gray-300 dark:border-gray-500 dark:bg-gray-700 rounded-lg max-h-[200px] relative flex"> | |
| <textarea | |
| ref={textareaRef} | |
| className="scrollbar-thin w-[550px] px-3 py-4 rounded-lg bg-transparent border-none outline-hidden text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-300 disabled:placeholder-gray-200 dark:disabled:placeholder-gray-500 resize-none disabled:cursor-not-allowed" | |
| placeholder="Type your message..." | |
| type="text" | |
| rows={1} | |
| value={input} | |
| disabled={status !== "ready"} | |
| title={ | |
| status === "ready" ? "Model is ready" : "Model not loaded yet" | |
| } | |
| onKeyDown={(e) => { | |
| if ( | |
| input.length > 0 && | |
| !isRunning && | |
| e.key === "Enter" && | |
| !e.shiftKey | |
| ) { | |
| e.preventDefault(); // Prevent default behavior of Enter key | |
| onEnter(input); | |
| } | |
| }} | |
| onInput={(e) => setInput(e.target.value)} | |
| /> | |
| {isRunning ? ( | |
| <div className="cursor-pointer" onClick={onInterrupt}> | |
| <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" /> | |
| </div> | |
| ) : input.length > 0 ? ( | |
| <div className="cursor-pointer" onClick={() => onEnter(input)}> | |
| <ArrowRightIcon | |
| className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`} | |
| /> | |
| </div> | |
| ) : ( | |
| <div> | |
| <ArrowRightIcon | |
| className={`h-8 w-8 p-1 bg-gray-200 dark:bg-gray-600 text-gray-50 dark:text-gray-800 rounded-md absolute right-3 bottom-3`} | |
| /> | |
| </div> | |
| )} | |
| </div> | |
| <div className="flex justify-end"> | |
| <div | |
| className={`border mt-1 inline-flex items-center p-2 gap-1 rounded-xl text-sm cursor-pointer ${ | |
| reasonEnabled | |
| ? "border-blue-500 bg-blue-100 text-blue-500 dark:bg-blue-600 dark:text-gray-200" | |
| : "dark:border-gray-700 bg-gray-800 text-gray-200 dark:text-gray-400" | |
| } ${ | |
| messages.length === 0 | |
| ? "pointer-events-auto" | |
| : "pointer-events-none opacity-50" | |
| }`} | |
| onClick={() => setReasonEnabled((prev) => !prev)} | |
| > | |
| <LightBulbIcon | |
| className={`h-4 w-4 ${ | |
| reasonEnabled ? "" : "stroke-gray-600 dark:stroke-gray-400" | |
| }`} | |
| /> | |
| Reason | |
| </div> | |
| </div> | |
| </div> | |
| <p className="text-xs text-gray-400 text-center mb-3"> | |
| Disclaimer: Generated content may be inaccurate or false. | |
| </p> | |
| </div> | |
| ) : ( | |
| <div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center"> | |
| WebGPU is not supported | |
| <br /> | |
| by this browser :( | |
| </div> | |
| ); | |
| } | |
| export default App; | |