|  |  | 
					
						
						|  | from __future__ import annotations | 
					
						
						|  | import io, base64, math | 
					
						
						|  | from math import gcd | 
					
						
						|  | import numpy as np | 
					
						
						|  | import soundfile as sf | 
					
						
						|  | from scipy.signal import resample_poly | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from magenta_rt import audio as au | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | import pyloudnorm as pyln | 
					
						
						|  | _HAS_LOUDNORM = True | 
					
						
						|  | except Exception: | 
					
						
						|  | _HAS_LOUDNORM = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _measure_lufs(wav: au.Waveform) -> float: | 
					
						
						|  | meter = pyln.Meter(wav.sample_rate) | 
					
						
						|  | return float(meter.integrated_loudness(wav.samples)) | 
					
						
						|  |  | 
					
						
						|  | def _rms(x: np.ndarray) -> float: | 
					
						
						|  | if x.size == 0: return 0.0 | 
					
						
						|  | return float(np.sqrt(np.mean(x**2))) | 
					
						
						|  |  | 
					
						
						|  | def match_loudness_to_reference( | 
					
						
						|  | ref: au.Waveform, | 
					
						
						|  | target: au.Waveform, | 
					
						
						|  | method: str = "auto", | 
					
						
						|  | headroom_db: float = 1.0 | 
					
						
						|  | ) -> tuple[au.Waveform, dict]: | 
					
						
						|  | stats = {"method": method, "applied_gain_db": 0.0} | 
					
						
						|  | if method == "none": | 
					
						
						|  | return target, stats | 
					
						
						|  |  | 
					
						
						|  | if method == "auto": | 
					
						
						|  | method = "lufs" if _HAS_LOUDNORM else "rms" | 
					
						
						|  |  | 
					
						
						|  | if method == "lufs" and _HAS_LOUDNORM: | 
					
						
						|  | L_ref = _measure_lufs(ref) | 
					
						
						|  | L_tgt = _measure_lufs(target) | 
					
						
						|  | delta_db = L_ref - L_tgt | 
					
						
						|  | gain = 10.0 ** (delta_db / 20.0) | 
					
						
						|  | y = target.samples.astype(np.float32) * gain | 
					
						
						|  | stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db}) | 
					
						
						|  | else: | 
					
						
						|  | ra = _rms(ref.samples) | 
					
						
						|  | rb = _rms(target.samples) | 
					
						
						|  | if rb <= 1e-12: | 
					
						
						|  | return target, stats | 
					
						
						|  | gain = ra / rb | 
					
						
						|  | y = target.samples.astype(np.float32) * gain | 
					
						
						|  | stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain,1e-12))}) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | limit = 10 ** (-headroom_db / 20.0) | 
					
						
						|  | peak = float(np.max(np.abs(y))) if y.size else 0.0 | 
					
						
						|  | if peak > limit: | 
					
						
						|  | y *= (limit / peak) | 
					
						
						|  | stats["post_peak_limited"] = True | 
					
						
						|  | else: | 
					
						
						|  | stats["post_peak_limited"] = False | 
					
						
						|  |  | 
					
						
						|  | target.samples = y.astype(np.float32) | 
					
						
						|  | return target, stats | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True): | 
					
						
						|  | if not chunks: | 
					
						
						|  | raise ValueError("no chunks") | 
					
						
						|  | xfade_n = int(round(xfade_s * sr)) | 
					
						
						|  | if xfade_n <= 0: | 
					
						
						|  | return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr) | 
					
						
						|  |  | 
					
						
						|  | t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32) | 
					
						
						|  | eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None] | 
					
						
						|  |  | 
					
						
						|  | first = chunks[0].samples | 
					
						
						|  | if first.shape[0] < xfade_n: | 
					
						
						|  | raise ValueError("chunk shorter than crossfade prefix") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy() | 
					
						
						|  |  | 
					
						
						|  | for i in range(1, len(chunks)): | 
					
						
						|  | cur = chunks[i].samples | 
					
						
						|  | if cur.shape[0] < xfade_n: | 
					
						
						|  | continue | 
					
						
						|  | head, tail = cur[:xfade_n], cur[xfade_n:] | 
					
						
						|  | mixed = out[-xfade_n:] * eq_out + head * eq_in | 
					
						
						|  | out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0) | 
					
						
						|  |  | 
					
						
						|  | return au.Waveform(out, sr) | 
					
						
						|  |  | 
					
						
						|  | def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform: | 
					
						
						|  | n = int(round(seconds * wav.sample_rate)) | 
					
						
						|  | return au.Waveform(wav.samples[:n], wav.sample_rate) | 
					
						
						|  |  | 
					
						
						|  | def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None: | 
					
						
						|  | n = int(wav.sample_rate * ms / 1000.0) | 
					
						
						|  | if n > 0 and wav.samples.shape[0] > 2*n: | 
					
						
						|  | env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None] | 
					
						
						|  | wav.samples[:n]  *= env | 
					
						
						|  | wav.samples[-n:] *= env[::-1] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4): | 
					
						
						|  | """ | 
					
						
						|  | Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest | 
					
						
						|  | whole-bar boundary in codec-frame space, even when frames_per_bar is fractional. | 
					
						
						|  |  | 
					
						
						|  | tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames | 
					
						
						|  | bpm: float | 
					
						
						|  | fps: float (codec frames per second; keep this as float) | 
					
						
						|  | ctx_frames: int (length of context window in codec frames) | 
					
						
						|  | beats_per_bar: int | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if tokens is None: | 
					
						
						|  | raise ValueError("tokens is None") | 
					
						
						|  | tokens = np.asarray(tokens) | 
					
						
						|  | if tokens.ndim == 1: | 
					
						
						|  | tokens = tokens[:, None] | 
					
						
						|  |  | 
					
						
						|  | T = tokens.shape[0] | 
					
						
						|  | if T == 0: | 
					
						
						|  | return tokens | 
					
						
						|  |  | 
					
						
						|  | fps = float(fps) | 
					
						
						|  | frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | reps = int(np.ceil((ctx_frames + T) / float(T))) + 1 | 
					
						
						|  | tiled = np.tile(tokens, (reps, 1)) | 
					
						
						|  | total = tiled.shape[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | k_bars = int(np.floor(total / frames_per_bar_f)) | 
					
						
						|  | if k_bars <= 0: | 
					
						
						|  |  | 
					
						
						|  | window = tiled[-ctx_frames:] | 
					
						
						|  | return window | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | end_idx = int(round(k_bars * frames_per_bar_f)) | 
					
						
						|  | end_idx = min(max(end_idx, ctx_frames), total) | 
					
						
						|  | start_idx = end_idx - ctx_frames | 
					
						
						|  | if start_idx < 0: | 
					
						
						|  | start_idx = 0 | 
					
						
						|  | end_idx = ctx_frames | 
					
						
						|  |  | 
					
						
						|  | window = tiled[start_idx:end_idx] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if window.shape[0] < ctx_frames: | 
					
						
						|  | pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1)) | 
					
						
						|  | window = np.vstack([window, pad])[:ctx_frames] | 
					
						
						|  | elif window.shape[0] > ctx_frames: | 
					
						
						|  | window = window[-ctx_frames:] | 
					
						
						|  |  | 
					
						
						|  | return window | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def take_bar_aligned_tail( | 
					
						
						|  | wav: au.Waveform, | 
					
						
						|  | bpm: float, | 
					
						
						|  | beats_per_bar: int, | 
					
						
						|  | ctx_seconds: float, | 
					
						
						|  | max_bars=None | 
					
						
						|  | ) -> au.Waveform: | 
					
						
						|  | """ | 
					
						
						|  | Take a tail whose length is an integer number of bars, with the END aligned | 
					
						
						|  | to a bar boundary. Uses ceil for bars_needed so we never under-fill the context. | 
					
						
						|  | """ | 
					
						
						|  | import math | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | spb = (60.0 / float(bpm)) * float(beats_per_bar) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | eps = 1e-9 | 
					
						
						|  | bars_needed = max(1, int(math.ceil((float(ctx_seconds) - eps) / spb))) | 
					
						
						|  |  | 
					
						
						|  | if max_bars is not None: | 
					
						
						|  | bars_needed = min(bars_needed, int(max_bars)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | samples_per_bar_f = spb * float(wav.sample_rate) | 
					
						
						|  | n = int(round(bars_needed * samples_per_bar_f)) | 
					
						
						|  |  | 
					
						
						|  | total = int(wav.samples.shape[0]) | 
					
						
						|  | if n >= total: | 
					
						
						|  |  | 
					
						
						|  | return wav | 
					
						
						|  |  | 
					
						
						|  | start = total - n | 
					
						
						|  | return au.Waveform(wav.samples[start:], wav.sample_rate) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray: | 
					
						
						|  | """ | 
					
						
						|  | x: np.ndarray shape (S, C), float32 | 
					
						
						|  | Returns: exact-length array (round(seconds*target_sr), C) | 
					
						
						|  | """ | 
					
						
						|  | if x.ndim == 1: | 
					
						
						|  | x = x[:, None] | 
					
						
						|  | if cur_sr != target_sr: | 
					
						
						|  | g = gcd(cur_sr, target_sr) | 
					
						
						|  | up, down = target_sr // g, cur_sr // g | 
					
						
						|  | x = resample_poly(x, up, down, axis=0) | 
					
						
						|  |  | 
					
						
						|  | expected_len = int(round(seconds * target_sr)) | 
					
						
						|  | if x.shape[0] < expected_len: | 
					
						
						|  | pad = np.zeros((expected_len - x.shape[0], x.shape[1]), dtype=x.dtype) | 
					
						
						|  | x = np.vstack([x, pad]) | 
					
						
						|  | elif x.shape[0] > expected_len: | 
					
						
						|  | x = x[:expected_len, :] | 
					
						
						|  | return x.astype(np.float32, copy=False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]: | 
					
						
						|  | """ | 
					
						
						|  | x: np.ndarray shape (S, C) | 
					
						
						|  | returns: (base64_wav, total_samples, channels) | 
					
						
						|  | """ | 
					
						
						|  | buf = io.BytesIO() | 
					
						
						|  | sf.write(buf, x, sr, subtype="FLOAT", format="WAV") | 
					
						
						|  | buf.seek(0) | 
					
						
						|  | b64 = base64.b64encode(buf.read()).decode("utf-8") | 
					
						
						|  | return b64, int(x.shape[0]), int(x.shape[1]) | 
					
						
						|  |  | 
					
						
						|  | def _ratio(out_sr: int, in_sr: int) -> tuple[int, int]: | 
					
						
						|  | g = gcd(int(out_sr), int(in_sr)) | 
					
						
						|  | return int(out_sr) // g, int(in_sr) // g | 
					
						
						|  |  | 
					
						
						|  | class StreamingResampler: | 
					
						
						|  | """ | 
					
						
						|  | Stateful streaming resampler. | 
					
						
						|  | Prefers soxr (best), then libsamplerate; final fallback is block resample_poly. | 
					
						
						|  | Always pass float32 arrays shaped (S, C). | 
					
						
						|  | """ | 
					
						
						|  | def __init__(self, in_sr: int, out_sr: int, channels: int = 2, quality: str = "VHQ"): | 
					
						
						|  | self.in_sr = int(in_sr) | 
					
						
						|  | self.out_sr = int(out_sr) | 
					
						
						|  | self.channels = int(channels) | 
					
						
						|  | self.quality = quality | 
					
						
						|  | self._backend = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | import soxr | 
					
						
						|  | self._backend = "soxr" | 
					
						
						|  |  | 
					
						
						|  | self._rs = soxr.Resampler( | 
					
						
						|  | self.in_sr, | 
					
						
						|  | self.out_sr, | 
					
						
						|  | channels=self.channels, | 
					
						
						|  | dtype="float32", | 
					
						
						|  | quality=self.quality, | 
					
						
						|  | ) | 
					
						
						|  | except Exception: | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | import samplerate | 
					
						
						|  | self._backend = "samplerate" | 
					
						
						|  |  | 
					
						
						|  | self._rs = samplerate.Resampler(converter_type="sinc_best", channels=self.channels) | 
					
						
						|  | except Exception: | 
					
						
						|  |  | 
					
						
						|  | from scipy.signal import resample_poly | 
					
						
						|  | self._backend = "scipy" | 
					
						
						|  | self._resample_poly = resample_poly | 
					
						
						|  | self._L, self._M = _ratio(self.out_sr, self.in_sr) | 
					
						
						|  |  | 
					
						
						|  | self._hist = np.zeros((0, self.channels), dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  | def process(self, x: np.ndarray, final: bool = False) -> np.ndarray: | 
					
						
						|  | """Feed a chunk (S, C) and get resampled chunk (S', C). Keep calling in order.""" | 
					
						
						|  | if x.size == 0 and not final: | 
					
						
						|  |  | 
					
						
						|  | return np.zeros((0, self.channels), dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  | if self._backend == "soxr": | 
					
						
						|  | return self._rs.process(x, final=final) | 
					
						
						|  |  | 
					
						
						|  | elif self._backend == "samplerate": | 
					
						
						|  | import samplerate | 
					
						
						|  | ratio = float(self.out_sr) / float(self.in_sr) | 
					
						
						|  |  | 
					
						
						|  | y = self._rs.process(x, ratio, end_of_input=final) | 
					
						
						|  |  | 
					
						
						|  | return y.astype(np.float32, copy=False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_ext = x if self._hist.size == 0 else np.vstack([self._hist, x]) | 
					
						
						|  | y = self._resample_poly(x_ext, up=self._L, down=self._M, axis=0).astype(np.float32, copy=False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | drop = int(round(self._hist.shape[0] * self.out_sr / self.in_sr)) | 
					
						
						|  | y = y[drop:] if drop < y.shape[0] else np.zeros((0, self.channels), dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tail_samples = max(int(0.004 * self.in_sr), 1) | 
					
						
						|  | self._hist = x[-tail_samples:] if x.shape[0] >= tail_samples else x.copy() | 
					
						
						|  | if final: | 
					
						
						|  | self._hist = np.zeros((0, self.channels), dtype=np.float32) | 
					
						
						|  | return y | 
					
						
						|  |  | 
					
						
						|  | def flush(self) -> np.ndarray: | 
					
						
						|  | """Drain converter tail (call at stop).""" | 
					
						
						|  | if self._backend == "soxr": | 
					
						
						|  | return self._rs.process(np.zeros((0, self.channels), dtype=np.float32), final=True) | 
					
						
						|  | elif self._backend == "samplerate": | 
					
						
						|  | ratio = float(self.out_sr) / float(self.in_sr) | 
					
						
						|  | return self._rs.process(np.zeros((0, self.channels), dtype=np.float32), ratio, end_of_input=True) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | return np.zeros((0, self.channels), dtype=np.float32) | 
					
						
						|  |  |