# magenta-retry / jam_worker.py
# Author: thecollabagepatch
# Commit dfa1fc4: "smarter loudness matching in /jam/start"
# jam_worker.py - Bar-locked spool rewrite
from __future__ import annotations
import os
import threading, time
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Dict, Tuple, List
import numpy as np
from magenta_rt import audio as au
from utils import (
StreamingResampler,
match_loudness_to_reference,
make_bar_aligned_context,
take_bar_aligned_tail,
wav_bytes_base64,
)
def _dbg_rms_dbfs(x: np.ndarray) -> float:
if x.ndim == 2:
x = x.mean(axis=1)
r = float(np.sqrt(np.mean(x * x) + 1e-12))
return 20.0 * np.log10(max(r, 1e-12))
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
# x is model-rate, shape [S,C] or [S]
if x.ndim == 2:
x = x.mean(axis=1)
r = float(np.sqrt(np.mean(x * x) + 1e-12))
return 20.0 * np.log10(max(r, 1e-12))
def _dbg_shape(x):
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
# -----------------------------
# Data classes
# -----------------------------
@dataclass
class JamParams:
    """Settings for one jam session.

    Covers the tempo grid, output sample rate, loudness matching, sampling
    knobs and style.
    """

    # tempo grid: chunks are cut on bar boundaries derived from these
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    # sample rate of the emitted chunks
    target_sr: int
    # per-chunk loudness matching against the combined loop ("none" disables it)
    loudness_mode: str = "auto"
    headroom_db: float = 1.0
    # style embedding and reference audio
    style_vec: np.ndarray | None = None
    ref_loop: au.Waveform | None = None
    combined_loop: au.Waveform | None = None
    # sampling knobs copied onto the model object
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40
    # seconds to glide between style vectors:
    # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
    style_ramp_seconds: float = 8.0
@dataclass
class JamChunk:
    """One emitted, bar-aligned audio chunk, ready to hand to a client."""

    index: int  # sequential chunk number, starting at 0
    audio_base64: str  # encoded audio payload
    metadata: Dict[str, object]  # tempo / format / sampling-knob info for this chunk
# -----------------------------
# Helpers
# -----------------------------
class BarClock:
    """Sample-domain bar clock with drift-free absolute boundaries.

    The bar length is held as an exact ``Fraction`` of samples (bpm is parsed
    from its decimal string). Every boundary is computed from bar 0 plus
    ``base``, so rounding error never accumulates across chunks.
    """

    def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0):
        self.sr = int(target_sr)
        self.bpm = Fraction(str(bpm))  # exact decimal to avoid FP drift
        self.beats_per_bar = int(beats_per_bar)
        samples_times_bpm = Fraction(self.sr * 60 * self.beats_per_bar, 1)
        self.bar_samps = samples_times_bpm / self.bpm
        self.base = int(base_offset_samples)

    def _boundary(self, bar_index: int) -> int:
        """Absolute sample index (rounded) at which bar ``bar_index`` starts."""
        exact = self.base + self.bar_samps * bar_index
        return int(round(exact))

    def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]:
        """Return ``(start, end)`` sample indices of chunk ``chunk_index``."""
        first_bar = chunk_index * bars_per_chunk
        return self._boundary(first_bar), self._boundary(first_bar + bars_per_chunk)

    def seconds_per_bar(self) -> float:
        beat_seconds = 60.0 / float(self.bpm)
        return float(self.beats_per_bar) * beat_seconds
# -----------------------------
# Worker
# -----------------------------
class JamWorker(threading.Thread):
    """Generates continuous audio with MagentaRT, spools it at target SR,
    and emits *sample-accurate*, bar-aligned chunks (no FPS drift).

    Threading as used in this class:
      * ``run()`` is the worker loop. It generates a model chunk, appends it
        to the target-SR spool, then emits every bar-aligned chunk the spool
        can cover.
      * Other threads (e.g. FastAPI endpoints) call ``get_next_chunk``,
        ``mark_chunk_consumed``, ``set_buffer_*``, ``update_knobs`` and the
        reseed methods.
      * ``self._lock`` (an RLock) is taken around generation-state changes
        (``state``, style, pending reseeds).
      * ``self._cv`` is taken around the outbox and the delivery/buffering
        counters.
    """

    # Codec frames per second; filled in __init__ once the codec is available.
    FRAMES_PER_SECOND: float | None = None

    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params

        # external callers (FastAPI endpoints) use this for atomic updates
        self._lock = threading.RLock()

        # generation state + sampling knobs pushed into the model
        self.state = self.mrt.init_state()
        self.mrt.guidance_weight = float(self.params.guidance_weight)
        self.mrt.temperature = float(self.params.temperature)
        self.mrt.topk = int(self.params.topk)

        # codec/setup
        self._codec_fps = float(self.mrt.codec.frame_rate)
        JamWorker.FRAMES_PER_SECOND = self._codec_fps
        self._ctx_frames = int(self.mrt.config.context_length_frames)
        self._ctx_seconds = self._ctx_frames / self._codec_fps

        # model stream (model SR) for internal continuity/crossfades
        self._model_stream: Optional[np.ndarray] = None
        self._model_sr = int(self.mrt.sample_rate)

        # style vector (already normalized upstream)
        self._style_vec = (None if self.params.style_vec is None
                           else np.array(self.params.style_vec, dtype=np.float32, copy=True))
        # duration of one generated model chunk; used as the style-glide step
        self._chunk_secs = (
            self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
        ) / float(self._model_sr)  # ≈ 2.0 s by default

        # target-SR in-RAM spool (what we cut loops from)
        if int(self.params.target_sr) != int(self._model_sr):
            self._rs = StreamingResampler(self._model_sr, int(self.params.target_sr), channels=2)
        else:
            self._rs = None  # model SR already equals target SR: no resampling
        # NOTE(review): the spool is indexed by absolute sample position and is
        # never trimmed, so its memory grows with session length.
        self._spool = np.zeros((0, 2), dtype=np.float32)  # (S,2) target SR
        self._spool_written = 0  # absolute frames written into spool
        self._pending_tail_model: Optional[np.ndarray] = None  # last tail at model SR
        self._pending_tail_target_len = 0  # target-SR samples the last tail contributed

        # combined loop resampled to target SR, cached for loudness references
        # (re-resampled only when params.combined_loop or target_sr changes)
        self._ref_cache_src = None
        self._ref_cache_sr: Optional[int] = None
        self._ref_cache: Optional[np.ndarray] = None

        # bar clock: start with offset 0; if you have a downbeat estimator, set base later
        self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar,
                                   base_offset_samples=0)

        # emission counters
        self.idx = 0  # next chunk index to *produce*
        self._next_to_deliver = 0  # next chunk index to hand out via get_next_chunk()
        self._last_consumed_index = -1  # updated via mark_chunk_consumed(); generation throttle uses this

        # outbox and synchronization
        self._outbox: Dict[int, JamChunk] = {}
        self._cv = threading.Condition()

        # control flags
        self._stop_event = threading.Event()
        self._max_buffer_ahead = 1

        # reseed queues (installed right after the next emitted chunk)
        self._pending_reseed: Optional[dict] = None  # legacy full reset path (kept for fallback)
        self._pending_token_splice: Optional[dict] = None  # seamless token splice

        # Prepare initial context from combined loop (best musical alignment)
        if self.params.combined_loop is not None:
            self._install_context_from_loop(self.params.combined_loop)

    # ---------- lifecycle ----------

    def set_buffer_seconds(self, seconds: float):
        """Clamp how far ahead we allow, in *seconds* of audio (rounded to whole chunks)."""
        chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar()
        max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6))))
        with self._cv:
            self._max_buffer_ahead = max_chunks

    def set_buffer_chunks(self, k: int):
        """Clamp how far ahead we allow, in whole chunks (negative values clamp to 0)."""
        with self._cv:
            self._max_buffer_ahead = max(0, int(k))

    def stop(self):
        """Ask ``run()`` to finish after its current iteration."""
        self._stop_event.set()

    def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
        """Block until the next sequential chunk is ready and return it.

        FastAPI reads this to wait for the next chunk. Returns None if nothing
        arrives within ``timeout`` seconds. The deadline uses a monotonic
        clock, so wall-clock adjustments cannot shorten or stretch the wait.
        """
        deadline = time.monotonic() + timeout
        with self._cv:
            while True:
                chunk = self._outbox.get(self._next_to_deliver)
                if chunk is not None:
                    self._next_to_deliver += 1
                    return chunk
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                # short waits so the deadline is re-checked even without a notify
                self._cv.wait(timeout=min(0.25, remaining))

    def mark_chunk_consumed(self, chunk_index: int):
        """Record an explicit client ack; lets the generator run ahead, but not too far.

        Also purges chunks older than ``last_consumed - 1`` from the outbox to
        cap memory.
        """
        with self._cv:
            self._last_consumed_index = max(self._last_consumed_index, int(chunk_index))
            for k in list(self._outbox.keys()):
                if k < self._last_consumed_index - 1:
                    self._outbox.pop(k, None)

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Atomically update any subset of the sampling knobs and push them into ``mrt``."""
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)
            # push into mrt
            self.mrt.guidance_weight = float(self.params.guidance_weight)
            self.mrt.temperature = float(self.params.temperature)
            self.mrt.topk = int(self.params.topk)

    # ---------- context / reseed ----------

    def _expected_token_shape(self) -> Tuple[int, int]:
        """(context_length_frames, decoder RVQ depth) expected for context tokens."""
        F = int(self._ctx_frames)
        D = int(self.mrt.config.decoder_codec_rvq_depth)
        return F, D

    def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray:
        """Force tokens to (context_length_frames, rvq_depth) int32, padding/trimming as needed.

        Depth: extra columns are dropped; missing columns repeat the last column
        (all-zero columns if there were none). Frames: missing frames are
        prepended as copies of the last frame; excess frames are trimmed from
        the front, keeping the most recent ones.
        """
        F, D = self._expected_token_shape()
        if toks.ndim != 2:
            toks = np.atleast_2d(toks)
        # depth first
        if toks.shape[1] == 0:
            toks = np.zeros((toks.shape[0], D), dtype=np.int32)
        elif toks.shape[1] > D:
            toks = toks[:, :D]
        elif toks.shape[1] < D:
            pad_cols = np.tile(toks[:, -1:], (1, D - toks.shape[1]))
            toks = np.concatenate([toks, pad_cols], axis=1)
        # frames
        if toks.shape[0] < F:
            if toks.shape[0] == 0:
                toks = np.zeros((1, D), dtype=np.int32)
            pad = np.repeat(toks[-1:, :], F - toks.shape[0], axis=0)
            toks = np.concatenate([pad, toks], axis=0)
        elif toks.shape[0] > F:
            toks = toks[-F:, :]
        if toks.dtype != np.int32:
            toks = toks.astype(np.int32, copy=False)
        return toks

    def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
        """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps).

        Also ensures the *end* of the encoded audio lands on a bar boundary.
        Steps:
          1. Take the largest integer number of bars <= ctx_seconds (at least
             one) as the tail.
          2. Left-fill from just before that tail, wrapping the loop as needed,
             to reach exactly ctx_seconds.
          3. Snap to exact samples, encode, then coerce the tokens to the
             expected (frames, depth).
        """
        wav = loop.as_stereo().resample(self._model_sr)
        data = wav.samples.astype(np.float32, copy=False)
        if data.ndim == 1:
            data = data[:, None]

        spb = self._bar_clock.seconds_per_bar()
        ctx_sec = float(self._ctx_seconds)
        sr = int(self._model_sr)

        # bars that fit fully inside ctx_sec (at least 1)
        bars_fit = max(1, int(ctx_sec // spb))
        tail_len_samps = int(round(bars_fit * spb * sr))

        # ensure we have enough source by tiling (the tiled end is a loop end)
        need = int(round(ctx_sec * sr)) + tail_len_samps
        if data.shape[0] == 0:
            data = np.zeros((1, 2), dtype=np.float32)
        reps = int(np.ceil(need / float(data.shape[0])))
        tiled = np.tile(data, (reps, 1))
        end = tiled.shape[0]
        tail = tiled[end - tail_len_samps:end]

        # left-fill to reach exact ctx samples (keeps end-of-bar alignment)
        ctx_samps = int(round(ctx_sec * sr))
        pad_len = ctx_samps - tail.shape[0]
        if pad_len > 0:
            pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
            ctx = np.concatenate([pre, tail], axis=0)
        else:
            ctx = tail[-ctx_samps:]

        # final snap to *exact* ctx samples
        if ctx.shape[0] < ctx_samps:
            pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
            ctx = np.concatenate([pad, ctx], axis=0)
        elif ctx.shape[0] > ctx_samps:
            ctx = ctx[-ctx_samps:]

        exact = au.Waveform(ctx, sr)
        tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
        depth = int(self.mrt.config.decoder_codec_rvq_depth)
        tokens = tokens_full[:, :depth]
        # Force expected (F,D) at *return time*
        return self._coerce_tokens(tokens)

    def _install_context_from_loop(self, loop: au.Waveform):
        """Replace the generation state with a fresh one seeded from ``loop`` (bar-locked, exact length)."""
        context_tokens = self._encode_exact_context_tokens(loop)
        s = self.mrt.init_state()
        s.context_tokens = context_tokens
        self.state = s
        self._original_context_tokens = np.copy(context_tokens)

    def reseed_from_waveform(self, wav: au.Waveform):
        """Immediate reseed: replace context from provided wave (bar-locked, exact length).

        The state swap happens under ``self._lock``. ``run()`` also checks
        under that lock, so a generation step already in flight will not
        overwrite this new state.
        """
        context_tokens = self._encode_exact_context_tokens(wav)
        with self._lock:
            s = self.mrt.init_state()
            s.context_tokens = context_tokens
            self.state = s
            self._model_stream = None  # drop model-domain continuity so next chunk starts cleanly
            self._original_context_tokens = np.copy(context_tokens)

    def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
        """Queue a *seamless* reseed by token splicing instead of a full restart.

        Computes a fresh, bar-locked context token tensor of exact length
        (e.g., 250 frames). Only the *tail* corresponding to ``anchor_bars``
        (at least one bar) is spliced onto the current context, so generation
        continues without resetting state.

        The splice is installed right after the next emitted chunk. If the
        current state has no context tokens, a full reseed is queued instead.
        """
        new_ctx = self._encode_exact_context_tokens(recent_wav)  # coerced to (F,D)
        F, D = self._expected_token_shape()

        # how many frames correspond to the requested anchor bars
        spb = self._bar_clock.seconds_per_bar()
        frames_per_bar = max(1, int(round(self._codec_fps * spb)))
        splice_frames = max(1, min(int(round(max(1.0, float(anchor_bars)) * frames_per_bar)), F))

        with self._lock:
            cur = getattr(self.state, "context_tokens", None)
            if cur is None:
                # fall back to full reseed (still coerced)
                self._pending_reseed = {"ctx": new_ctx}
                return
            cur = self._coerce_tokens(cur)

            # keep left (F - splice) frames from cur, take right (splice) frames from new
            left = cur[:F - splice_frames, :]
            right = new_ctx[F - splice_frames:, :]
            spliced = self._coerce_tokens(np.concatenate([left, right], axis=0))

            self._pending_token_splice = {
                "tokens": spliced,
                "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
            }

    # ---------- core streaming helpers ----------

    def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
        """
        Append one MagentaRT chunk into the target-SR spool.

        Uses an energy-aware, deferred-overwrite crossfade to avoid writing
        near-silence at bar edges.

        Key behavior:
        - Append BODY and TAIL of *this* chunk right away (resampled to target SR).
        - Keep THIS chunk's model-rate TAIL (+ its target-SR length if appended).
          On the *next* call it repairs the previous boundary by mixing
          (prev_tail*cos + new_head*sin).
        - If either side is very quiet (< -45 dBFS) the crossfade is shortened
          to a quarter. The shortened window uses the LAST samples of both
          prev_tail and new_head, so it stays adjacent to the body that follows.
        - When the correction length Lpop would be 0 (e.g., the tail produced
          no target samples last time), we APPEND the mixed overlap to bridge
          the gap instead of overwriting 0 samples.
        - Before overwriting with the mixed overlap, we guard against writing
          ultra-quiet audio. If it is >20 dB below the existing spool end, it
          is lifted by a bounded gain.

        The bar clock and external timing are unchanged.
        """
        import math

        def _rms_dbfs(x: np.ndarray) -> float:
            # RMS of the channel mean in dBFS; -120 for empty input, NaN/inf counted as 0
            if x.size == 0:
                return -120.0
            if x.ndim == 2 and x.shape[1] > 1:
                x_m = x.mean(axis=1, dtype=np.float32)
            else:
                x_m = x.astype(np.float32, copy=False).reshape(-1)
            x_m = np.nan_to_num(x_m, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
            r = float(np.sqrt(np.mean(x_m * x_m) + 1e-12))
            return 20.0 * math.log10(max(r, 1e-12))

        def _two_channels(x: np.ndarray) -> np.ndarray:
            # Coerce [S] or [S,C] to [S,2] by duplicating the first channel.
            # ndim is checked first so 1-D input cannot raise on shape[1].
            if x.ndim == 1:
                x = x[:, None]
            if x.shape[1] != 2:
                x = np.repeat(x[:, :1], 2, axis=1)
            return x

        def to_target(y: np.ndarray) -> np.ndarray:
            return y if self._rs is None else self._rs.process(y, final=False)

        # ---- unpack model-rate samples ----
        s = wav.samples.astype(np.float32, copy=False)
        if s.ndim == 1:
            s = s[:, None]
        if s.shape[1] == 1:
            # ensure stereo shape for consistency with the (S,2) spool
            s = np.repeat(s, 2, axis=1)
        n_samps = int(s.shape[0])

        # crossfade length in model samples (0 if the config has none)
        try:
            xfade_s = float(self.mrt.config.crossfade_length)
        except (AttributeError, TypeError, ValueError):
            xfade_s = 0.0
        xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))

        # carve head/body/tail in model domain
        if xfade_n > 0 and n_samps >= (2 * xfade_n):
            head_m = s[:xfade_n, :]
            body_m = s[xfade_n:n_samps - xfade_n, :]
            tail_m = s[n_samps - xfade_n:, :]
        else:
            # too short or no xfade configured — treat everything as body
            head_m = np.zeros((0, 2), dtype=np.float32)
            body_m = s
            tail_m = np.zeros((0, 2), dtype=np.float32)

        # ------------------------------------------
        # (A) Repair the PREVIOUS boundary if we have a pending model-tail
        # ------------------------------------------
        if (self._pending_tail_model is not None) and (xfade_n > 0) and (n_samps >= xfade_n):
            tail_prev_m = _two_channels(self._pending_tail_model)
            head_now_m = _two_channels(head_m)

            # shorten the crossfade when either side is very quiet
            xfade_use = int(xfade_n)
            if min(_rms_dbfs(tail_prev_m), _rms_dbfs(head_now_m)) < -45.0:
                xfade_use = max(1, xfade_n // 4)

            # windowed overlap (model domain)
            Lm = min(xfade_use, tail_prev_m.shape[0], head_now_m.shape[0])
            if Lm > 0:
                t = np.linspace(0.0, math.pi / 2.0, Lm, endpoint=False, dtype=np.float32)[:, None]
                cosw = np.cos(t, dtype=np.float32)
                sinw = np.sin(t, dtype=np.float32)
                # Mix the LAST Lm samples of both overlap windows: prev_tail and
                # new_head are treated as overlapping sample-for-sample, and the
                # body starts right after head[-1]. (Mixing head[:Lm] when
                # Lm < xfade_n skipped head[Lm:xfade_n], an audible jump.)
                mixed_m = tail_prev_m[-Lm:, :] * cosw + head_now_m[-Lm:, :] * sinw

                # resample to target and correct the end of the spool
                y_mixed = to_target(mixed_m)
                Lcorr = int(y_mixed.shape[0])
                if Lcorr > 0:
                    # how many samples from last time's tail did we append?
                    # (may be zero if the resampler yielded nothing then)
                    Lpop = int(min(self._pending_tail_target_len, self._spool.shape[0], Lcorr))
                    if Lpop > 0:
                        # energy-aware overwrite of last Lpop samples
                        prev_end = self._spool[-Lpop:, :]
                        new_seg = y_mixed[-Lpop:, :]
                        prev_r = _rms_dbfs(prev_end)
                        new_r = _rms_dbfs(new_seg)
                        # If the new overlap is >20 dB quieter than what's there, lift it (bounded)
                        if new_r < (prev_r - 20.0):
                            lift_db = max(0.0, min(20.0, (prev_r - 6.0) - new_r))  # cap boost; leave ~6 dB headroom
                            scale = 10.0 ** (lift_db / 20.0)
                            new_seg = np.clip(new_seg * scale, -1.0, 1.0).astype(np.float32, copy=False)
                        self._spool[-Lpop:, :] = new_seg
                        print(f"[append] mixedOverlap len={Lpop} rms={_rms_dbfs(new_seg):+.1f} dBFS")
                    else:
                        # Nothing to overwrite (e.g., last tail produced 0 target samples).
                        # Bridge by APPENDING the mixed-overlap.
                        self._spool = np.concatenate([self._spool, y_mixed], axis=0)
                        self._spool_written += int(y_mixed.shape[0])
                        print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_rms_dbfs(y_mixed):+.1f} dBFS")

            # clear pending once we attempted the repair
            self._pending_tail_model = None
            self._pending_tail_target_len = 0

        # ------------------------------------------
        # (B) Append this chunk's BODY then TAIL (target SR)
        # ------------------------------------------
        y_body = to_target(body_m) if body_m.size else np.zeros((0, 2), dtype=np.float32)
        if y_body.size:
            self._spool = np.concatenate([self._spool, y_body], axis=0)
            self._spool_written += int(y_body.shape[0])
        print(f"[append] body len={y_body.shape[0] if y_body.size else 0} rms={_rms_dbfs(y_body):+.1f} dBFS")

        # TAIL: appended now for continuity; the next call may overwrite it
        y_tail = to_target(tail_m) if tail_m.size else np.zeros((0, 2), dtype=np.float32)
        if y_tail.size:
            self._spool = np.concatenate([self._spool, y_tail], axis=0)
            self._spool_written += int(y_tail.shape[0])
            self._pending_tail_target_len = int(y_tail.shape[0])  # how much we just added at target SR
        else:
            # resampler returned nothing for the tail; mark 0 so next Lpop==0
            self._pending_tail_target_len = 0
        print(f"[append] tail len={y_tail.shape[0] if y_tail.size else 0} rms={_rms_dbfs(y_tail):+.1f} dBFS")

        # keep THIS chunk's model tail to mix with next chunk's head
        # (even if y_tail had 0 target samples; then we bridge by appending the mix)
        self._pending_tail_model = tail_m if tail_m.size else None

    def _should_generate_next_chunk(self) -> bool:
        """True while production is within ``_max_buffer_ahead`` chunks of the client.

        The anchor is whichever is larger: last *consumed* (explicit ack from
        the client) or last *delivered* (implicit ack).
        """
        implicit_consumed = self._next_to_deliver - 1  # last chunk handed to client
        horizon_anchor = max(self._last_consumed_index, implicit_consumed)
        return self.idx <= (horizon_anchor + self._max_buffer_ahead)

    @staticmethod
    def _tile_reference(comb: np.ndarray, start: int, need: int) -> np.ndarray:
        """Return ``need`` frames of ``comb`` starting at absolute frame ``start``.

        Wraps around ``comb`` as many times as required. ``comb`` must be
        non-empty.
        """
        idx = (int(start) + np.arange(int(need))) % comb.shape[0]
        return comb[idx]

    def _loudness_reference(self, start: int, end: int) -> Optional[np.ndarray]:
        """Combined loop at target SR, tiled to cover spool frames [start, end).

        Spool frame ``start`` maps to loop frame ``start % len(loop)``. The
        resampled loop is cached and recomputed only when
        ``params.combined_loop`` (by identity) or ``target_sr`` changes.

        Returns None when there is no usable reference.
        """
        loop_obj = self.params.combined_loop
        if loop_obj is None:
            return None
        sr = int(self.params.target_sr)
        if self._ref_cache_src is not loop_obj or self._ref_cache_sr != sr:
            comb = loop_obj.as_stereo().resample(sr).samples.astype(np.float32, copy=False)
            if comb.ndim == 1:
                comb = comb[:, None]
            if comb.shape[1] == 1:
                comb = np.repeat(comb, 2, axis=1)
            self._ref_cache_src, self._ref_cache_sr, self._ref_cache = loop_obj, sr, comb
        comb = self._ref_cache
        need = end - start
        if comb.shape[0] == 0 or need <= 0:
            return None
        return self._tile_reference(comb, start, need)

    def _install_pending_reseed(self) -> None:
        """Install a queued token splice (preferred) or full reseed.

        The caller must hold ``self._lock``.
        """
        if self._pending_token_splice is not None:
            spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
            self._pending_token_splice = None
            try:
                # in-place update keeps the rest of the generation state (no reset)
                self.state.context_tokens = spliced
            except Exception:
                # state rejected the assignment: fall back to a fresh state with the spliced tokens
                new_state = self.mrt.init_state()
                new_state.context_tokens = spliced
                self.state = new_state
                self._model_stream = None
        elif self._pending_reseed is not None:
            ctx = self._coerce_tokens(self._pending_reseed["ctx"])
            new_state = self.mrt.init_state()
            new_state.context_tokens = ctx
            self.state = new_state
            self._model_stream = None
            self._pending_reseed = None

    def _emit_ready(self, final: bool = False):
        """Emit every complete bar-aligned chunk the spool currently covers.

        Args:
            final: when False (normal streaming), the last
                ``_pending_tail_target_len`` spool samples are held back. The
                next ``_append_model_chunk_and_spool`` call may overwrite them
                with the crossfaded boundary, so emitting them early would ship
                audio that later changes. ``run()`` passes True after
                generation stops so everything left is flushed.
        """
        committed = self._spool_written
        if not final:
            committed -= int(self._pending_tail_target_len)

        while True:
            start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
            if end > committed:
                break  # need more (settled) audio
            loop = self._spool[start:end]

            # Loudness match per chunk against a bar-aligned slice of the combined loop
            if self.params.loudness_mode != "none" and self.params.combined_loop is not None:
                ref_slice = self._loudness_reference(start, end)
                if ref_slice is not None:
                    sr = int(self.params.target_sr)
                    ref = au.Waveform(ref_slice, sr)
                    tgt = au.Waveform(loop.copy(), sr)
                    matched, _stats = match_loudness_to_reference(
                        ref, tgt,
                        method=self.params.loudness_mode,
                        headroom_db=self.params.headroom_db
                    )
                    loop = matched.samples

            audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
            meta = {
                "bpm": float(self.params.bpm),
                "bars": int(self.params.bars_per_chunk),
                "beats_per_bar": int(self.params.beats_per_bar),
                "sample_rate": int(self.params.target_sr),
                "channels": int(channels),
                "total_samples": int(total_samples),
                "seconds_per_bar": self._bar_clock.seconds_per_bar(),
                "loop_duration_seconds": self.params.bars_per_chunk * self._bar_clock.seconds_per_bar(),
                "guidance_weight": float(self.params.guidance_weight),
                "temperature": float(self.params.temperature),
                "topk": int(self.params.topk),
            }
            chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)

            if os.getenv("MRT_DEBUG_RMS", "0") == "1":
                spb = self._bar_clock.bar_samps
                seg = int(max(1, spb // 4))  # quarter-bar window
                rms = [float(np.sqrt(np.mean(loop[i:i+seg]**2))) for i in range(0, loop.shape[0], seg)]
                print(f"[emit idx={self.idx}] quarter-bar RMS: {rms[:8]}")

            with self._cv:
                self._outbox[self.idx] = chunk
                self._cv.notify_all()
            self.idx += 1

            # If a reseed is queued, install it *right after* we finish a chunk
            with self._lock:
                self._install_pending_reseed()

    def _next_style_vector(self) -> Optional[np.ndarray]:
        """Advance the style glide by one model chunk toward ``params.style_vec``.

        With ``style_ramp_seconds <= 0`` (or unset) the target is reached
        immediately. The first use starts exactly at the target (no glide).

        Returns the internal vector (updated in place on later calls), or None
        when no style is set. The caller must hold ``self._lock``.
        """
        target = self.params.style_vec
        if target is None:
            return None
        if self._style_vec is None:
            self._style_vec = np.array(target, dtype=np.float32, copy=True)
        else:
            ramp = float(self.params.style_ramp_seconds or 0.0)
            step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp)
            # linear ramp in embedding space
            self._style_vec += step * (np.asarray(target, dtype=np.float32) - self._style_vec)
        return self._style_vec

    # ---------- main loop ----------

    def run(self):
        """Worker loop: generate, spool and emit until ``stop()`` is called, then flush."""
        while not self._stop_event.is_set():
            # throttle generation if we are far ahead of the client
            if not self._should_generate_next_chunk():
                # still try to emit if spool already has enough
                self._emit_ready()
                time.sleep(0.01)
                continue

            # snapshot style + state under the lock for this step
            with self._lock:
                style_to_use = self._next_style_vector()
                state_in = self.state

            wav, state_out = self.mrt.generate_chunk(state=state_in, style=style_to_use)

            with self._lock:
                # reseed_from_waveform() on another thread may have replaced
                # self.state while we were generating; don't clobber it.
                if self.state is state_in:
                    self.state = state_out

            self._append_model_chunk_and_spool(wav)
            # try emitting zero or more chunks if available
            self._emit_ready()

        # finalize resampler (flush); without a resampler there is nothing to flush
        if self._rs is not None:
            tail = self._rs.process(np.zeros((0, 2), np.float32), final=True)
            if tail.size:
                self._spool = np.concatenate([self._spool, tail], axis=0)
                self._spool_written += tail.shape[0]

        # one last emit attempt, including the no-longer-pending crossfade tail
        self._emit_ready(final=True)