thecollabagepatch commited on
Commit
5af7cde
·
1 Parent(s): c1e4dcd

fixing bar-aligned context inside /generate route just like we did for jam_worker

Browse files
Files changed (1) hide show
  1. utils.py +17 -20
utils.py CHANGED
@@ -111,44 +111,40 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
111
  # ---------- Token context helpers ----------
112
  def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
113
  """
114
- Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
- whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
116
-
117
- tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
118
- bpm: float
119
- fps: float (codec frames per second; keep this as float)
120
- ctx_frames: int (length of context window in codec frames)
121
- beats_per_bar: int
122
  """
123
-
124
 
125
  if tokens is None:
126
  raise ValueError("tokens is None")
127
  tokens = np.asarray(tokens)
128
  if tokens.ndim == 1:
129
- tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
130
 
131
  T = tokens.shape[0]
132
  if T == 0:
133
  return tokens
134
 
135
  fps = float(fps)
136
- frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
137
 
138
- # Tile a little more than we need so we can always snap the END to a bar boundary
 
 
 
 
 
 
139
  reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
140
  tiled = np.tile(tokens, (reps, 1))
141
  total = tiled.shape[0]
142
 
143
- # How many whole bars fit?
144
- k_bars = int(np.floor(total / frames_per_bar_f))
145
  if k_bars <= 0:
146
- # Fallback: just take the last ctx_frames
147
- window = tiled[-ctx_frames:]
148
- return window
149
 
150
- # Snap END index to the nearest integer frame at a whole-bar boundary
151
- end_idx = int(round(k_bars * frames_per_bar_f))
152
  end_idx = min(max(end_idx, ctx_frames), total)
153
  start_idx = end_idx - ctx_frames
154
  if start_idx < 0:
@@ -157,7 +153,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
157
 
158
  window = tiled[start_idx:end_idx]
159
 
160
- # Guard against rare off-by-one due to rounding
161
  if window.shape[0] < ctx_frames:
162
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
163
  window = np.vstack([window, pad])[:ctx_frames]
@@ -167,6 +163,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
167
  return window
168
 
169
 
 
170
  def take_bar_aligned_tail(
171
  wav: au.Waveform,
172
  bpm: float,
 
111
  # ---------- Token context helpers ----------
112
  def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
113
  """
114
+ Return a ctx_frames-long slice of `tokens` whose **end** lands on an integer
115
+ bar boundary in codec-frame space (model runs at `fps`, typically 25).
 
 
 
 
 
 
116
  """
 
117
 
118
  if tokens is None:
119
  raise ValueError("tokens is None")
120
  tokens = np.asarray(tokens)
121
  if tokens.ndim == 1:
122
+ tokens = tokens[:, None]
123
 
124
  T = tokens.shape[0]
125
  if T == 0:
126
  return tokens
127
 
128
  fps = float(fps)
 
129
 
130
+ # float frames per bar (e.g., ~65.934 at 91 BPM for 4/4 @ 25fps)
131
+ frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
132
+
133
+ # >>> KEY FIX: quantize bar length to an integer number of codec frames
134
+ frames_per_bar_i = max(1, int(round(frames_per_bar_f)))
135
+
136
+ # Tile so we can always snap the *end* to a bar boundary and still have ctx_frames
137
  reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
138
  tiled = np.tile(tokens, (reps, 1))
139
  total = tiled.shape[0]
140
 
141
+ # How many whole integer bars fit in the tiled sequence?
142
+ k_bars = total // frames_per_bar_i
143
  if k_bars <= 0:
144
+ return tiled[-ctx_frames:]
 
 
145
 
146
+ # Snap END to an exact integer multiple of frames_per_bar_i
147
+ end_idx = int(k_bars * frames_per_bar_i)
148
  end_idx = min(max(end_idx, ctx_frames), total)
149
  start_idx = end_idx - ctx_frames
150
  if start_idx < 0:
 
153
 
154
  window = tiled[start_idx:end_idx]
155
 
156
+ # Guard off-by-one
157
  if window.shape[0] < ctx_frames:
158
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
159
  window = np.vstack([window, pad])[:ctx_frames]
 
163
  return window
164
 
165
 
166
+
167
  def take_bar_aligned_tail(
168
  wav: au.Waveform,
169
  bpm: float,