<!-- Source: Hugging Face repository "SNAC-24khz-decoder-onnx", file "SNAC-Decoder-in-browser.html".
     Uploaded by ChristophSchuhmann; commit 206ee95 ("Upload 4 files"), verified. -->
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>SNAC 24k — Click-free Streaming (Robust scheduler + Cache)</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
:root { color-scheme: dark light; }
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, Cantarell, "Helvetica Neue", Arial; margin:0 }
header { padding:16px 20px; background:#111827; color:#f9fafb }
main { padding:16px; display:grid; gap:16px; grid-template-columns:1fr 380px }
section { border:1px solid #e5e7eb20; border-radius:12px; padding:12px 14px; background:#0b1220; color:#e5e7eb }
h1{ margin:0 0 6px 0; font-size:20px } h2{ margin:8px 0; font-size:16px }
.row{ display:flex; gap:8px; align-items:center; flex-wrap:wrap }
.btn{ padding:8px 12px; border-radius:10px; border:1px solid #475569; background:#1f2937; color:#e5e7eb; cursor:pointer }
.btn:disabled{ opacity:.5; cursor:not-allowed }
.mono{ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace }
textarea{ width:100%; min-height:160px; border-radius:10px; border:1px solid #334155; background:#0b1220; color:#e5e7eb; padding:10px }
input[type="number"],input[type="text"],select{ padding:6px 8px; border-radius:8px; border:1px solid #334155; background:#0b1220; color:#e5e7eb }
.grid{ display:grid; gap:10px; grid-template-columns:repeat(2,1fr) }
.log{ height:180px; overflow:auto; background:#0a0f1c; border-radius:8px; padding:8px; border:1px solid #1f2937 }
.small{ font-size:12px; opacity:.9 } .muted{ opacity:.7 }
.pill{ display:inline-block; padding:4px 8px; border-radius:999px; background:#0b132b; border:1px solid #334155; margin:2px }
.progress{ width:100%; height:8px; background:#111827; border-radius:999px; overflow:hidden; border:1px solid #374151 }
.progress>div{ height:100%; background:#22c55e; width:0% }
audio{ width:100%; margin-top:8px }
</style>
</head>
<body>
<header>
<h1>SNAC 24&nbsp;kHz — Click-free Streaming (Robust scheduler + Cache)</h1>
<div class="small muted">Streaming uses 48-frame windows, default hop 40, center-keep & equal-power crossfade. Preloads model into IndexedDB cache.</div>
</header>
<main>
<section>
<h2>1) Inputs</h2>
<div class="grid">
<div>
<div class="small muted">Model URL (int→wav ONNX)</div>
<input id="modelUrl" class="mono" type="text" style="width:100%"
value="https://huggingface.co/laion/SNAC-24khz-decoder-onnx/resolve/main/snac24_int2wav_static.onnx">
</div>
<div>
<div class="small muted">Codes URL (flattened JSON)</div>
<input id="codesUrl" class="mono" type="text" style="width:100%"
value="https://huggingface.co/laion/SNAC-24khz-decoder-onnx/resolve/main/snac_flattened_stream.txt">
</div>
</div>
<div class="row" style="margin-top:8px;gap:12px;">
<button id="preloadBtn" class="btn">Preload model (and cache)</button>
<button id="loadCodesBtn" class="btn">Load codes into textbox</button>
<button id="clearCacheBtn" class="btn">Clear cache</button>
<span id="preloadStatus" class="small pill">idle</span>
</div>
<div class="progress" style="margin:8px 0;"><div id="dlBar"></div></div>
<div class="small" id="dlText"></div>
</section>
<section>
<h2>2) Decode options</h2>
<div class="grid">
<div>
<div class="small muted">Execution Provider</div>
<select id="providerSel">
<option value="webgpu">webgpu (if available)</option>
<option value="wasm">wasm</option>
</select>
</div>
<div>
<div class="small muted">Streaming mode</div>
<select id="modeSel">
<option value="stream">Streaming (center-keep)</option>
<option value="whole">Whole file (assemble then play)</option>
<option value="fixed">Fixed windows (butt-join; for comparison)</option>
</select>
</div>
<div>
<div class="small muted">Hop (L2 frames)</div>
<input id="hopFrames" type="number" min="8" max="48" step="4" value="40">
</div>
<div>
<div class="small muted">Crossfade (ms)</div>
<input id="xfadeMs" type="number" min="0" max="40" step="2" value="12">
</div>
<div>
<div class="small muted">Keep center (L2 frames)</div>
<input id="keepFrames" type="number" min="8" max="48" step="4" value="40">
</div>
<div>
<div class="small muted">Window (L2 frames)</div>
<input id="winFrames" type="number" min="48" max="48" value="48" disabled>
</div>
</div>
<div class="row" style="margin-top:8px; gap:12px;">
<label class="small"><input id="sequentialChk" type="checkbox"> Sequential playback (no overlap)</label>
<span class="small muted">Sample rate 24,000 Hz</span>
</div>
<div class="row" style="margin-top:8px; gap:12px;">
<button id="generateBtn" class="btn">Generate</button>
</div>
</section>
<section style="grid-column:1 / span 2;">
<h2>3) Flattened SNAC JSON</h2>
<textarea id="snacIn" class="mono" spellcheck="false"
placeholder='Paste the single-line JSON (with "flattened","lengths", optional "streaming") here…'></textarea>
</section>
<section>
<h2>4) Output</h2>
<div class="small">Player</div>
<audio id="player" controls></audio>
<div class="small" style="margin-top:8px;">Metrics</div>
<pre id="metrics" class="log mono"></pre>
</section>
<section>
<h2>5) Logs & Info</h2>
<pre id="log" class="log mono"></pre>
<div class="small muted">
<p><b>WebGPU</b> runs operations on your GPU (fast when supported). <b>WASM</b> runs on CPU; SIMD is automatic; multithreading needs cross-origin isolation (COOP/COEP). Without COI, WASM uses 1 thread.</p>
<p>Streaming: 48-frame windows, hop 40, keep center 40. Equal-power crossfade removes seams. “Sequential” forces no overlap and starts each chunk only after the previous ended (sets crossfade to 0).</p>
</div>
</section>
</main>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
(async () => {
// Shorthand for document.getElementById.
const el = id => document.getElementById(id);
// Handles to the controls and output elements declared in the markup above.
const modelUrl = el('modelUrl'), codesUrl = el('codesUrl');
const preloadBtn = el('preloadBtn'), loadCodesBtn = el('loadCodesBtn'), clearCacheBtn = el('clearCacheBtn');
const preloadStatus = el('preloadStatus'), dlBar = el('dlBar'), dlText = el('dlText');
const providerSel = el('providerSel'), modeSel = el('modeSel');
const hopFrames = el('hopFrames'), keepFrames = el('keepFrames'), xfadeMs = el('xfadeMs'), winFrames = el('winFrames');
const sequentialChk = el('sequentialChk');
const snacIn = el('snacIn'), player = el('player'), metrics = el('metrics'), logbox = el('log');
const genBtn = el('generateBtn');
/**
 * Writes a message to the console and appends it as one line to the on-page log,
 * keeping the log scrolled to the bottom.
 * Plain objects are serialized as JSON so they do not show up as "[object Object]";
 * Errors, strings and numbers keep their usual string form, and null/undefined
 * render as an empty string (the same result Array.prototype.join gives).
 */
const log = (...a)=>{
  console.log(...a);
  const toText = v => {
    if(v==null) return '';
    if(typeof v==='object' && !(v instanceof Error)){
      // JSON.stringify throws on circular structures; fall back to String() then.
      try{ return JSON.stringify(v) ?? String(v); }catch{ return String(v); }
    }
    return String(v);
  };
  logbox.textContent += a.map(toText).join(' ') + '\n';
  logbox.scrollTop = logbox.scrollHeight;
};
// Pretty-prints a value as 2-space-indented JSON (used for the metrics panel).
const fmt = o => JSON.stringify(o, null, 2);
// ===== IndexedDB tiny cache =====
// Database name and the single object store used by the cache helpers below.
const DB='snac-cache', STORE='files';
/**
 * Opens version 1 of the cache database, creating the object store during the
 * upgrade step. Resolves with the open IDBDatabase, rejects with the open
 * request's error.
 */
function idb(){
  return new Promise((resolve, reject)=>{
    const openReq = indexedDB.open(DB, 1);
    openReq.onupgradeneeded = ()=>openReq.result.createObjectStore(STORE);
    openReq.onsuccess = ()=>resolve(openReq.result);
    openReq.onerror = ()=>reject(openReq.error);
  });
}
/**
 * Reads the value stored under key `k`.
 * Resolves with the stored value, or null when nothing (or a falsy value) is stored.
 * Rejects if the request fails or the transaction aborts.
 * The database connection is closed afterwards so repeated calls do not
 * accumulate open connections.
 */
async function idbGet(k){
  const db=await idb();
  try{
    return await new Promise((res,rej)=>{
      const tx=db.transaction(STORE,'readonly');
      const rq=tx.objectStore(STORE).get(k);
      rq.onsuccess=()=>res(rq.result||null);
      rq.onerror=()=>rej(rq.error);
      // An abort does not necessarily deliver an error event to the request.
      tx.onabort=()=>rej(tx.error||rq.error);
    });
  }finally{
    db.close(); // takes effect once the transaction above has finished
  }
}
/**
 * Stores value `v` under key `k`.
 * Resolves with true once the transaction has completed.
 * Rejects if the put fails or the transaction aborts; without the abort handler
 * an aborted write would leave the returned promise pending forever.
 * The database connection is closed afterwards.
 */
async function idbSet(k,v){
  const db=await idb();
  try{
    return await new Promise((res,rej)=>{
      const tx=db.transaction(STORE,'readwrite');
      const rq=tx.objectStore(STORE).put(v,k);
      tx.oncomplete=()=>res(true);
      tx.onerror=()=>rej(tx.error||rq.error);
      tx.onabort=()=>rej(tx.error||rq.error);
    });
  }finally{
    db.close(); // takes effect once the transaction above has finished
  }
}
/**
 * Deletes the entry stored under key `k`.
 * Resolves with true once the transaction has completed.
 * Rejects if the delete fails or the transaction aborts (an abort would
 * otherwise leave the promise pending forever).
 * The database connection is closed afterwards.
 */
async function idbDel(k){
  const db=await idb();
  try{
    return await new Promise((res,rej)=>{
      const tx=db.transaction(STORE,'readwrite');
      const rq=tx.objectStore(STORE).delete(k);
      tx.oncomplete=()=>res(true);
      tx.onerror=()=>rej(tx.error||rq.error);
      tx.onabort=()=>rej(tx.error||rq.error);
    });
  }finally{
    db.close(); // takes effect once the transaction above has finished
  }
}
// "Clear cache" button: removes the cached model for the URL currently in the
// Model URL field. Failures are logged instead of becoming unhandled rejections.
clearCacheBtn.onclick = async () => {
  const url = modelUrl.value.trim();
  try{
    await idbDel(url);
    log('Cache cleared for', url);
  }catch(e){
    log('Clear cache error:', e);
  }
};
// ===== network with progress =====
/**
 * Downloads `url` and returns its body as an ArrayBuffer.
 * `onProg(receivedBytes, totalBytes)` is optional; totalBytes is 0 when the
 * Content-Length header is missing or not numeric. When the body can be
 * streamed it is called after every chunk, otherwise once after the download.
 * Throws an Error for non-OK HTTP responses.
 */
async function fetchWithProgress(url, onProg){
  const response = await fetch(url);
  if(!response.ok){
    throw new Error(`HTTP ${response.status} for ${url}`);
  }
  const expected = Number(response.headers.get('Content-Length')) || 0;
  const canStream = Boolean(response.body) && Boolean(window.ReadableStream);
  if(!canStream){
    // No streaming support: fetch everything, then report once.
    const whole = await response.arrayBuffer();
    onProg?.(whole.byteLength, expected);
    return whole;
  }
  const reader = response.body.getReader();
  const parts = [];
  let received = 0;
  let step = await reader.read();
  while(!step.done){
    parts.push(step.value);
    received += step.value.byteLength;
    onProg?.(received, expected);
    step = await reader.read();
  }
  // Stitch the chunks into one contiguous buffer.
  const merged = new Uint8Array(received);
  parts.reduce((offset, part)=>{ merged.set(part, offset); return offset + part.byteLength; }, 0);
  return merged.buffer;
}
// ===== ORT env =====
// Cross-origin isolation flag; false when the global is not defined at all.
const coi = typeof crossOriginIsolated !== 'undefined' && crossOriginIsolated;
// Enable SIMD; use several WASM threads only when cross-origin isolated, else one.
Object.assign(ort.env.wasm, {
  simd: true,
  numThreads: coi ? (navigator.hardwareConcurrency || 4) : 1,
});
// ===== preload model (with cache) =====
// Current inference session and the execution provider label shown in the UI.
let session=null, sessionEP=null;
// Promise of the preload currently running, or null when none is in flight.
let preloadInFlight=null;

// Does the actual work for preloadModel(): cache lookup, download, caching and
// session creation. Never rejects; failures are logged and shown in the status pill.
async function runPreload(){
  try{
    const url = modelUrl.value.trim();
    preloadBtn.disabled = true; dlBar.style.width='0%'; dlText.textContent=''; preloadStatus.textContent='checking cache…';
    // The cache is an optimisation only: if it cannot be read, download instead.
    let buf = null;
    try{ buf = await idbGet(url); }
    catch(e){ log('Cache read failed; downloading instead:', e); }
    if(buf){ preloadStatus.textContent='cache hit'; dlBar.style.width='100%'; dlText.textContent='Loaded from IndexedDB'; }
    else{
      preloadStatus.textContent='downloading…';
      buf = await fetchWithProgress(url, (got,total)=>{
        // Clamp: the received byte count can exceed the advertised Content-Length.
        const pct = total ? Math.min(100, Math.round(100*got/total)) : 0;
        dlBar.style.width = `${pct}%`;
        dlText.textContent = total ? `Downloading: ${pct}% (${(got/1e6).toFixed(1)} / ${(total/1e6).toFixed(1)} MB)` :
          `Downloading: ${(got/1e6).toFixed(1)} MB`;
      });
      // A failed cache write must not prevent the session from being created.
      try{ await idbSet(url, buf); log('Cached model to IndexedDB'); }
      catch(e){ log('Could not cache model (continuing without cache):', e); }
    }
    // Request WebGPU first only when selected and exposed by the browser; WASM is the fallback.
    const want = (providerSel.value==='webgpu' && 'gpu' in navigator) ? ['webgpu','wasm'] : ['wasm'];
    const t0=performance.now();
    preloadStatus.textContent='compiling…';
    session = await ort.InferenceSession.create(buf, { executionProviders: want, graphOptimizationLevel: 'all' });
    const t1=performance.now(); sessionEP = session.executionProvider ?? want[0];
    preloadStatus.textContent=`ready (${(t1-t0).toFixed(1)} ms) via ${sessionEP}`;
    log('Session ready. EP:', sessionEP, 'compile_ms:', (t1-t0).toFixed(1));
  }catch(e){
    preloadStatus.textContent='error'; // do not leave a stale "downloading…"/"compiling…" label
    log('Preload error:', e);
  }
  finally{ preloadBtn.disabled=false; }
}

/**
 * Loads the model (from the IndexedDB cache when present, otherwise over the
 * network) and creates the inference session, updating the status pill and
 * progress bar along the way.
 * Calls made while a preload is already running share that preload instead of
 * starting a second download/compile. The returned promise never rejects;
 * on failure `session` keeps its previous value.
 */
function preloadModel(){
  if(!preloadInFlight){
    preloadInFlight = runPreload().finally(()=>{ preloadInFlight=null; });
  }
  return preloadInFlight;
}
preloadBtn.onclick = preloadModel;
// Start preloading as soon as the page has loaded.
window.addEventListener('load', ()=>preloadModel().catch(e=>log('Preload error:',e)));
// ===== load codes =====
// "Load codes" button: fetches the text at the Codes URL and puts it, trimmed,
// into the JSON textbox. Any failure (including non-OK HTTP status) is logged.
loadCodesBtn.onclick = async ()=>{
  try{
    const response = await fetch(codesUrl.value.trim());
    if(!response.ok){
      throw new Error(`HTTP ${response.status}`);
    }
    const body = await response.text();
    snacIn.value = body.trim();
    log('Loaded codes text.');
  }catch(e){
    log('Load codes error:', e);
  }
};
// ===== SNAC helpers =====
// SPF: samples per L2 frame; frame counts are multiplied by it to get sample offsets.
// SR: sample rate in Hz, passed to the AudioContext, the AudioBuffers and the WAV header.
const SPF = 512, SR = 24000;
/**
 * Splits a flattened token stream into three code arrays.
 * Each group of 7 consecutive entries of `flat` describes one L0 frame:
 *   slot 0 -> c0[i]; slots 1 and 4 -> c1[2i], c1[2i+1];
 *   slots 2, 3, 5, 6 -> c2[4i] .. c2[4i+3].
 * From the entry in slot k, (A + k*K) is subtracted and the result is reduced
 * to the range [0, K).
 * @param flat  array of values convertible with BigInt()
 * @param L0    number of L0 frames to read
 * @param A     offset subtracted from every entry (default 0)
 * @param K     per-slot stride and modulus (default 4096)
 * @returns {{c0:BigInt64Array, c1:BigInt64Array, c2:BigInt64Array}} of lengths L0, 2*L0, 4*L0
 */
function unflatten(flat,L0,A=0,K=4096){
  const coarse=new BigInt64Array(L0), mid=new BigInt64Array(2*L0), fine=new BigInt64Array(4*L0);
  const base=BigInt(A), size=BigInt(K);
  // Decode slot `slot` of frame `frame` into a non-negative code index.
  const decode=(frame,slot)=>{
    const raw=BigInt(flat[7*frame+slot])-(base+BigInt(slot)*size);
    return ((raw%size)+size)%size;
  };
  let frame=0;
  while(frame<L0){
    const slots=[0,1,2,3,4,5,6].map(slot=>decode(frame,slot));
    coarse[frame]=slots[0];
    mid[2*frame]=slots[1];
    mid[2*frame+1]=slots[4];
    fine[4*frame]=slots[2];
    fine[4*frame+1]=slots[3];
    fine[4*frame+2]=slots[5];
    fine[4*frame+3]=slots[6];
    frame+=1;
  }
  return {c0:coarse,c1:mid,c2:fine};
}
/**
 * Returns `len` elements of `src` starting at index `start`, as a new
 * BigInt64Array. Positions that fall before the first or after the last
 * element of `src` are filled with the nearest edge element.
 */
function sliceEdgePad(src,start,len){
  const lastIndex=src.length-1;
  const out=new BigInt64Array(len);
  for(let k=0;k<len;k+=1){
    // Clamp the source index into [0, lastIndex].
    out[k]=src[Math.min(Math.max(start+k,0),lastIndex)];
  }
  return out;
}
// Returns a new Float32Array holding the elements of `a` followed by those of `b`.
function concatFloat32(a,b){
  const joined=new Float32Array(a.length+b.length);
  joined.set(a);
  joined.set(b,a.length);
  return joined;
}
// ===== robust scheduler: gain-node crossfades + safety headroom + detailed logging =====
async function generate(){
try{
// --- Setup for one Generate run: reset the UI, make sure a session exists,
// --- parse and validate the pasted JSON, read the options, create the audio clock.
genBtn.disabled=true; metrics.textContent=''; logbox.textContent='';
if(!session) await preloadModel();
// preloadModel() logs its own failure and leaves `session` null; stop with a clear message.
if(!session) throw new Error('Model session is not available (preload failed; see log).');
const jsonText = snacIn.value.trim();
if(!jsonText) throw new Error('No SNAC JSON provided.');
const blob = JSON.parse(jsonText);
if(blob===null || typeof blob!=='object') throw new Error('SNAC JSON must be an object.');
// Validate `flattened` before anything reads its length.
const flat = blob.flattened; if(!Array.isArray(flat)) throw new Error("flattened missing");
const A = blob.audio_tokens_start ?? 0;
const K = blob.codebook_size ?? 4096;
const L0 = blob.lengths?.L0 ?? Math.floor(flat.length/7);
if(!Number.isInteger(L0) || L0<1) throw new Error(`Invalid L0 frame count: ${L0}`);
// unflatten() reads 7 entries per L0 frame.
if(flat.length < 7*L0) throw new Error(`flattened has ${flat.length} entries but L0=${L0} needs ${7*L0}`);
const {c0,c1,c2} = unflatten(flat, L0, A, K);
const L2=c2.length, T_true=L2*SPF;
log(`Parsed L0/L1/L2 = ${L0}/${c1.length}/${L2} -> true samples ${T_true}`);
// streaming params (with sane defaults)
const s = blob.streaming || {};
const K2 = Number(winFrames.value)||48;
const H2 = Number(hopFrames.value)||Number(s.hop_frames||40);
const keepF = Number(keepFrames.value)||Number(s.center_keep_frames||40);
// An input's .value is always a string, so test for "empty" explicitly:
// 0 is a valid crossfade, while an empty box falls back to the JSON default, then 12.
const xfadeText = xfadeMs.value.trim();
const xfade = Number(xfadeText!=='' ? xfadeText : (s.xfade_ms_default ?? 12));
if(!Number.isFinite(xfade) || xfade<0) throw new Error(`Invalid crossfade (ms): ${xfadeMs.value}`);
const sequential = !!sequentialChk.checked;
const overlapSec = sequential ? 0 : (xfade/1000);
const SAFETY = 0.040; // 40 ms safety to avoid underschedule
const ctx = new (window.AudioContext||window.webkitAudioContext)({sampleRate: SR});
let playClock = ctx.currentTime + 0.10; // first chunk start
let scheduledEnd = null; // end time of last scheduled chunk (audio time)
let windows=0, samples=0; const t0=performance.now();
/**
 * Decodes one window of codes and schedules the kept part for playback.
 * @param c0w,c1w,c2w  BigInt64Array code windows fed to the model as int64 tensors of shape [1, length]
 * @param keepStart,keepEnd  sample range of the model's `audio` output to keep
 * Side effects: adds to `samples` and `windows`, updates `scheduledEnd` and
 * `lastGainNode`, and writes one line to the log.
 */
async function runWindow(c0w,c1w,c2w, keepStart, keepEnd){
  const feeds = {
    codes0: new ort.Tensor('int64', c0w,[1,c0w.length]),
    codes1: new ort.Tensor('int64', c1w,[1,c1w.length]),
    codes2: new ort.Tensor('int64', c2w,[1,c2w.length]),
  };
  const tA=performance.now();
  const out = await session.run(feeds);
  const tB=performance.now();
  const audio = out.audio.data; // Float32 [24576]
  const kept = audio.subarray(keepStart, keepEnd);
  const segSec = kept.length / SR;
  samples += kept.length;
  // Build nodes: buffer source -> gain -> destination
  const buf = ctx.createBuffer(1, kept.length, SR);
  buf.copyToChannel(kept, 0, 0);
  const src = ctx.createBufferSource(); src.buffer = buf;
  const g = ctx.createGain(); g.gain.setValueAtTime(1, ctx.currentTime);
  src.connect(g).connect(ctx.destination);
  // Decide start time: first chunk at playClock, later chunks at the previous end
  // (minus the overlap unless sequential), but never closer than SAFETY to "now".
  const desiredStart = (scheduledEnd==null) ? playClock : (sequential ? scheduledEnd : scheduledEnd - overlapSec);
  const now = ctx.currentTime;
  const startAt = Math.max(desiredStart, now + SAFETY);
  const endAt = startAt + segSec;
  // Crossfade only over the time the two chunks really overlap. If this chunk was
  // pushed back to (or past) the previous chunk's end there is nothing to fade
  // against, and a ramp ending before it starts would retroactively change the
  // previous chunk's gain.
  // NOTE: the ramps are linear, although the page text calls the fade "equal-power".
  const fadeSec = (!sequential && scheduledEnd!==null) ? Math.min(overlapSec, scheduledEnd - startAt) : 0;
  if(fadeSec>0){
    // fade previous node down if we have a handle:
    const prev = lastGainNode;
    if(prev){
      prev.gain.setValueAtTime(1, startAt);
      prev.gain.linearRampToValueAtTime(0, startAt + fadeSec);
    }
    // fade current node up over the same interval
    g.gain.setValueAtTime(0, startAt);
    g.gain.linearRampToValueAtTime(1, startAt + fadeSec);
  }
  src.start(startAt);
  windows += 1;
  scheduledEnd = endAt;
  lastGainNode = g;
  // How far the chunk had to be pushed past its desired start time.
  const late = Math.max(0, (now + SAFETY) - desiredStart);
  log(`win#${windows} infer ${(tB-tA).toFixed(2)} ms now ${now.toFixed(3)} desired ${desiredStart.toFixed(3)} start ${startAt.toFixed(3)} end ${endAt.toFixed(3)} overlap ${sequential?0:(overlapSec*1000)} ms late ${ (late*1000).toFixed(1)} ms`);
}
// window generators
/**
 * Yields back-to-back, non-overlapping windows of 48 L2 frames (24 L1, 12 L0)
 * covering all L2 frames; the last window is edge-padded by sliceEdgePad.
 * Every window keeps its full output: samples [0, 48*SPF).
 */
function* fixedWindows(){
  const WIN=48;
  const total=Math.ceil(L2/WIN);
  let idx=0;
  while(idx<total){
    yield {
      c0w:sliceEdgePad(c0, idx*(WIN/4), WIN/4),
      c1w:sliceEdgePad(c1, idx*(WIN/2), WIN/2),
      c2w:sliceEdgePad(c2, idx*WIN, WIN),
      keepStart:0,
      keepEnd:WIN*SPF,
    };
    idx+=1;
  }
}
/**
 * Yields overlapping K2-frame windows for center-keep decoding.
 * The kept regions start at L2 frames 0, H2, 2*H2, ...; each window starts
 * `left` frames before its kept region, so the kept region has decoder context
 * on both sides and the very first frames of the stream are not dropped.
 * Frames outside the stream are edge-padded by sliceEdgePad.
 * The last kept region is cut at the true end (L2) so decoded padding is not emitted.
 * Yields nothing when L2 is 0. Throws if the hop or keep size is not positive.
 */
function* slidingCenterWindows(){
  if(!(H2>0)) throw new Error(`Hop must be a positive number of frames (got ${H2})`);
  const keep = Math.min(keepF, K2);
  if(!(keep>0)) throw new Error(`Keep must be a positive number of frames (got ${keepF})`);
  // Left context: half of the unused frames, rounded down to a multiple of 4 so the
  // window start stays aligned across the three code levels (4 L2 : 2 L1 : 1 L0).
  const left = Math.max(0, Math.floor((K2-keep)/8)*4);
  for(let k2=0; k2<L2; k2+=H2){
    const w2 = k2 - left; // window start in L2 frames; negative for the first window
    const keepNow = Math.min(keep, L2 - k2); // trim the final region at the true end
    yield {
      c0w:sliceEdgePad(c0, Math.floor(w2/4), K2/4),
      c1w:sliceEdgePad(c1, Math.floor(w2/2), K2/2),
      c2w:sliceEdgePad(c2, w2, K2),
      keepStart: left*SPF,
      keepEnd: (left+keepNow)*SPF,
    };
  }
}
// choose mode
const mode=modeSel.value;
const gen = (mode==='fixed') ? fixedWindows : slidingCenterWindows;
// scheduling
let lastGainNode = null; // gain node of the most recently scheduled chunk (read and set by runWindow)
// Each run creates its own AudioContext; close it once everything scheduled on it
// has played, so repeated runs do not pile up open contexts.
const releaseAudioContext = ()=>{
  const pendingSec = (scheduledEnd==null) ? 0 : Math.max(0, scheduledEnd - ctx.currentTime);
  setTimeout(()=>{
    try{ if(ctx.state!=='closed') Promise.resolve(ctx.close()).catch(()=>{}); }catch{ /* best effort */ }
  }, pendingSec*1000 + 250); // small margin after the last scheduled sample
};
try{
  if(mode==='whole'){
    // assemble then play (still center-keep in sliding path)
    let full = new Float32Array(0);
    for(const {c0w,c1w,c2w,keepStart,keepEnd} of slidingCenterWindows()){
      const out = await session.run({
        codes0:new ort.Tensor('int64', c0w,[1,c0w.length]),
        codes1:new ort.Tensor('int64', c1w,[1,c1w.length]),
        codes2:new ort.Tensor('int64', c2w,[1,c2w.length]),
      });
      const a = out.audio.data.subarray(keepStart, keepEnd);
      full = concatFloat32(full, a);
      windows++; samples += a.length;
    }
    // Drop anything past the true length.
    full = full.subarray(0, T_true);
    const wav = pcm16Wav(full, SR);
    const previousUrl = player.src;
    player.src = URL.createObjectURL(new Blob([wav], {type:'audio/wav'}));
    // Release the blob of the previous run now that the player points at the new one.
    if(previousUrl.startsWith('blob:')) URL.revokeObjectURL(previousUrl);
    // Playback may be refused by the browser; the user can still press play.
    await player.play().catch(()=>{});
  } else {
    for(const w of gen()){
      await runWindow(w.c0w, w.c1w, w.c2w, w.keepStart, w.keepEnd);
    }
  }
}finally{
  // Also runs when decoding fails part-way; already scheduled chunks still play out.
  releaseAudioContext();
}
const t1=performance.now();
// Summary of this run. inference_ms is the wall time since setup finished (t0);
// rtf is seconds of audio produced per second of that wall time.
const r = {
  usedEP: sessionEP || providerSel.value,
  threads: ort.env.wasm.numThreads||1,
  simd: ort.env.wasm.simd===true,
  coi, windows, samples,
  audio_seconds: samples/SR,
  inference_ms: (t1-t0),
  rtf: ( (samples/SR) / ((t1-t0)/1000) ).toFixed(3)
};
metrics.textContent = fmt({env:{coi, hwc:navigator.hardwareConcurrency||1}, providers:['webgpu','wasm']}) + "\n" + fmt(r);
log('Done.', r);
/**
 * Encodes mono float samples as a 16-bit PCM WAV file.
 * @param float32  samples; values are clamped to [-1, 1] and scaled by 32767
 * @param sr       sample rate written to the header
 * @returns Uint8Array with the 44-byte RIFF/WAVE header followed by the samples
 * Every multi-byte field, including each sample, is written little-endian
 * through the DataView, so the result does not depend on host byte order.
 */
function pcm16Wav(float32, sr){
  const HEADER_BYTES=44;
  const dataBytes=float32.length*2;
  const buf=new ArrayBuffer(HEADER_BYTES+dataBytes), dv=new DataView(buf);
  let p=0; // write cursor
  const ascii=s=>{ for(let i=0;i<s.length;i++) dv.setUint8(p++, s.charCodeAt(i)); };
  const u32=v=>{ dv.setUint32(p,v,true); p+=4; };
  const u16=v=>{ dv.setUint16(p,v,true); p+=2; };
  ascii('RIFF'); u32(HEADER_BYTES+dataBytes-8); ascii('WAVE');
  ascii('fmt '); u32(16);      // fmt chunk size
  u16(1);                      // format tag 1 = PCM
  u16(1);                      // channels
  u32(sr);                     // sample rate
  u32(sr*2);                   // byte rate = sr * block align
  u16(2);                      // block align (1 channel * 2 bytes)
  u16(16);                     // bits per sample
  ascii('data'); u32(dataBytes);
  for(let i=0;i<float32.length;i++){
    const v=Math.max(-1,Math.min(1,float32[i]));
    dv.setInt16(p, Math.round(v*32767), true); p+=2;
  }
  return new Uint8Array(buf);
}
}catch(e){ console.error(e); log('ERROR:', e.message||e); }
finally{ genBtn.disabled=false; }
}
genBtn.onclick = generate;
})();
</script>
</body>
</html>