Commit
·
5af7cde
1
Parent(s):
c1e4dcd
Fix bar-aligned context inside the /generate route, matching the earlier fix in jam_worker
Browse files
utils.py
CHANGED
@@ -111,44 +111,40 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
111 |
# ---------- Token context helpers ----------
|
112 |
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
|
113 |
"""
|
114 |
-
Return a ctx_frames-long slice of `tokens` whose **end** lands on
|
115 |
-
|
116 |
-
|
117 |
-
tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
|
118 |
-
bpm: float
|
119 |
-
fps: float (codec frames per second; keep this as float)
|
120 |
-
ctx_frames: int (length of context window in codec frames)
|
121 |
-
beats_per_bar: int
|
122 |
"""
|
123 |
-
|
124 |
|
125 |
if tokens is None:
|
126 |
raise ValueError("tokens is None")
|
127 |
tokens = np.asarray(tokens)
|
128 |
if tokens.ndim == 1:
|
129 |
-
tokens = tokens[:, None]
|
130 |
|
131 |
T = tokens.shape[0]
|
132 |
if T == 0:
|
133 |
return tokens
|
134 |
|
135 |
fps = float(fps)
|
136 |
-
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
|
137 |
|
138 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
140 |
tiled = np.tile(tokens, (reps, 1))
|
141 |
total = tiled.shape[0]
|
142 |
|
143 |
-
# How many whole bars fit?
|
144 |
-
k_bars =
|
145 |
if k_bars <= 0:
|
146 |
-
|
147 |
-
window = tiled[-ctx_frames:]
|
148 |
-
return window
|
149 |
|
150 |
-
# Snap END
|
151 |
-
end_idx = int(
|
152 |
end_idx = min(max(end_idx, ctx_frames), total)
|
153 |
start_idx = end_idx - ctx_frames
|
154 |
if start_idx < 0:
|
@@ -157,7 +153,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
|
|
157 |
|
158 |
window = tiled[start_idx:end_idx]
|
159 |
|
160 |
-
# Guard
|
161 |
if window.shape[0] < ctx_frames:
|
162 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
163 |
window = np.vstack([window, pad])[:ctx_frames]
|
@@ -167,6 +163,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
|
|
167 |
return window
|
168 |
|
169 |
|
|
|
170 |
def take_bar_aligned_tail(
|
171 |
wav: au.Waveform,
|
172 |
bpm: float,
|
|
|
111 |
# ---------- Token context helpers ----------
|
112 |
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
|
113 |
"""
|
114 |
+
Return a ctx_frames-long slice of `tokens` whose **end** lands on an integer
|
115 |
+
bar boundary in codec-frame space (model runs at `fps`, typically 25).
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
"""
|
|
|
117 |
|
118 |
if tokens is None:
|
119 |
raise ValueError("tokens is None")
|
120 |
tokens = np.asarray(tokens)
|
121 |
if tokens.ndim == 1:
|
122 |
+
tokens = tokens[:, None]
|
123 |
|
124 |
T = tokens.shape[0]
|
125 |
if T == 0:
|
126 |
return tokens
|
127 |
|
128 |
fps = float(fps)
|
|
|
129 |
|
130 |
+
# float frames per bar (e.g., ~65.934 at 91 BPM for 4/4 @ 25fps)
|
131 |
+
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
|
132 |
+
|
133 |
+
# >>> KEY FIX: quantize bar length to an integer number of codec frames
|
134 |
+
frames_per_bar_i = max(1, int(round(frames_per_bar_f)))
|
135 |
+
|
136 |
+
# Tile so we can always snap the *end* to a bar boundary and still have ctx_frames
|
137 |
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
138 |
tiled = np.tile(tokens, (reps, 1))
|
139 |
total = tiled.shape[0]
|
140 |
|
141 |
+
# How many whole integer bars fit in the tiled sequence?
|
142 |
+
k_bars = total // frames_per_bar_i
|
143 |
if k_bars <= 0:
|
144 |
+
return tiled[-ctx_frames:]
|
|
|
|
|
145 |
|
146 |
+
# Snap END to an exact integer multiple of frames_per_bar_i
|
147 |
+
end_idx = int(k_bars * frames_per_bar_i)
|
148 |
end_idx = min(max(end_idx, ctx_frames), total)
|
149 |
start_idx = end_idx - ctx_frames
|
150 |
if start_idx < 0:
|
|
|
153 |
|
154 |
window = tiled[start_idx:end_idx]
|
155 |
|
156 |
+
# Guard off-by-one
|
157 |
if window.shape[0] < ctx_frames:
|
158 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
159 |
window = np.vstack([window, pad])[:ctx_frames]
|
|
|
163 |
return window
|
164 |
|
165 |
|
166 |
+
|
167 |
def take_bar_aligned_tail(
|
168 |
wav: au.Waveform,
|
169 |
bpm: float,
|