Commit 
							
							·
						
						1b98b73
	
1
								Parent(s):
							
							783cbeb
								
fixing continuity
Browse files- jam_worker.py +106 -72
- utils.py +4 -2
    	
        jam_worker.py
    CHANGED
    
    | @@ -350,88 +350,122 @@ class JamWorker(threading.Thread): | |
| 350 | 
             
                        self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
         | 
| 351 |  | 
| 352 | 
             
                def run(self):
         | 
| 353 | 
            -
                    """ | 
| 354 | 
            -
                    sr_model = int(self.mrt.sample_rate)
         | 
| 355 | 
             
                    spb = self._seconds_per_bar()
         | 
| 356 | 
            -
                    chunk_secs =  | 
| 357 | 
            -
                     | 
| 358 | 
            -
             | 
| 359 | 
            -
             | 
| 360 | 
            -
                    #  | 
| 361 | 
            -
                     | 
| 362 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 363 |  | 
| 364 | 
            -
                    print("π JamWorker  | 
| 365 |  | 
| 366 | 
             
                    while not self._stop_event.is_set():
         | 
| 367 | 
            -
                        #  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 368 | 
             
                        with self._lock:
         | 
| 369 | 
            -
                            if self.idx > self._last_delivered_index + self._max_buffer_ahead:
         | 
| 370 | 
            -
                                time.sleep(0.25)
         | 
| 371 | 
            -
                                continue
         | 
| 372 | 
             
                            style_vec = self.params.style_vec
         | 
| 373 | 
            -
                            self.mrt.guidance_weight = self.params.guidance_weight
         | 
| 374 | 
            -
                            self.mrt.temperature     = self.params.temperature
         | 
| 375 | 
            -
                            self.mrt.topk            = self.params.topk
         | 
|  | |
| 376 |  | 
| 377 | 
            -
                         | 
| 378 | 
             
                        self.last_chunk_started_at = time.time()
         | 
| 379 | 
            -
                        wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
         | 
| 380 | 
            -
                        self._append_model_chunk_to_stream(wav)
         | 
| 381 | 
            -
                        if getattr(self, "_needs_bar_realign", False):
         | 
| 382 | 
            -
                            self._realign_emit_pointer_to_bar(sr_model)
         | 
| 383 | 
            -
                            self._needs_bar_realign = False
         | 
| 384 | 
            -
                            # DEBUG
         | 
| 385 | 
            -
                            bar_samps = int(round(self._seconds_per_bar() * sr_model))
         | 
| 386 | 
            -
                            if bar_samps > 0 and (self._next_emit_start % bar_samps) != 0:
         | 
| 387 | 
            -
                                print(f"⚠️ emit pointer not aligned: phase={self._next_emit_start % bar_samps}")
         | 
| 388 | 
            -
                            else:
         | 
| 389 | 
            -
                                print("✅ emit pointer aligned to bar")
         | 
| 390 | 
            -
             | 
| 391 | 
            -
                        self.last_chunk_completed_at = time.time()
         | 
| 392 | 
            -
             | 
| 393 | 
            -
                        # While we have at least one full 8-bar window available, emit it
         | 
| 394 | 
            -
                        while (getattr(self, "_stream", None) is not None and
         | 
| 395 | 
            -
                            self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
         | 
| 396 | 
            -
                            not self._stop_event.is_set()):
         | 
| 397 | 
            -
             | 
| 398 | 
            -
                            seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
         | 
| 399 |  | 
| 400 | 
            -
             | 
| 401 | 
            -
             | 
| 402 | 
            -
             | 
| 403 | 
            -
             | 
| 404 | 
            -
             | 
| 405 | 
            -
             | 
| 406 | 
            -
                             | 
| 407 | 
            -
                             | 
| 408 | 
            -
             | 
| 409 | 
            -
             | 
| 410 | 
            -
             | 
| 411 | 
            -
             | 
| 412 | 
            -
             | 
| 413 | 
            -
             | 
| 414 | 
            -
             | 
| 415 | 
            -
             | 
| 416 | 
            -
             | 
| 417 | 
            -
             | 
| 418 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 419 | 
             
                            )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 420 |  | 
| 421 | 
            -
             | 
| 422 | 
            -
             | 
| 423 | 
            -
             | 
| 424 | 
            -
             | 
| 425 | 
            -
             | 
| 426 | 
            -
             | 
| 427 | 
            -
             | 
| 428 | 
            -
             | 
| 429 | 
            -
                            self._next_emit_start += chunk_n_model
         | 
| 430 |  | 
| 431 | 
            -
             | 
| 432 | 
            -
             | 
| 433 | 
            -
                            if keep_from > 0:
         | 
| 434 | 
            -
                                self._stream = self._stream[keep_from:]
         | 
| 435 | 
            -
                                self._next_emit_start -= keep_from
         | 
| 436 |  | 
| 437 | 
            -
                    print("π JamWorker  | 
|  | |
| 350 | 
             
                        self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
         | 
| 351 |  | 
| 352 | 
             
                def run(self):
         | 
| 353 | 
            +
                    """Main worker loop - generate chunks continuously but don't get too far ahead"""
         | 
|  | |
| 354 | 
             
                    spb = self._seconds_per_bar()
         | 
| 355 | 
            +
                    chunk_secs = self.params.bars_per_chunk * spb
         | 
| 356 | 
            +
                    xfade = float(self.mrt.config.crossfade_length)  # seconds
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    # local fallback stitcher that *keeps* the first head if utils.stitch_generated
         | 
| 359 | 
            +
                    # doesn't yet support drop_first_pre_roll
         | 
| 360 | 
            +
                    def _stitch_keep_head(chunks, sr: int, xfade_s: float):
         | 
| 361 | 
            +
                        from magenta_rt import audio as au
         | 
| 362 | 
            +
                        import numpy as _np
         | 
| 363 | 
            +
                        if not chunks:
         | 
| 364 | 
            +
                            raise ValueError("no chunks to stitch")
         | 
| 365 | 
            +
                        xfade_n = int(round(max(0.0, xfade_s) * sr))
         | 
| 366 | 
            +
                        # Fast-path: no crossfade
         | 
| 367 | 
            +
                        if xfade_n <= 0:
         | 
| 368 | 
            +
                            out = _np.concatenate([c.samples for c in chunks], axis=0)
         | 
| 369 | 
            +
                            return au.Waveform(out, sr)
         | 
| 370 | 
            +
                        # build equal-power curves
         | 
| 371 | 
            +
                        t = _np.linspace(0, _np.pi / 2, xfade_n, endpoint=False, dtype=_np.float32)
         | 
| 372 | 
            +
                        eq_in, eq_out = _np.sin(t)[:, None], _np.cos(t)[:, None]
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                        first = chunks[0].samples
         | 
| 375 | 
            +
                        if first.shape[0] < xfade_n:
         | 
| 376 | 
            +
                            raise ValueError("chunk shorter than crossfade prefix")
         | 
| 377 | 
            +
                        out = first.copy()  # π keep the head for live seam
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                        for i in range(1, len(chunks)):
         | 
| 380 | 
            +
                            cur = chunks[i].samples
         | 
| 381 | 
            +
                            if cur.shape[0] < xfade_n:
         | 
| 382 | 
            +
                                # too short to crossfade; just butt-join
         | 
| 383 | 
            +
                                out = _np.concatenate([out, cur], axis=0)
         | 
| 384 | 
            +
                                continue
         | 
| 385 | 
            +
                            head, tail = cur[:xfade_n], cur[xfade_n:]
         | 
| 386 | 
            +
                            mixed = out[-xfade_n:] * eq_out + head * eq_in
         | 
| 387 | 
            +
                            out = _np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
         | 
| 388 | 
            +
                        return au.Waveform(out, sr)
         | 
| 389 |  | 
| 390 | 
            +
                    print("π JamWorker started with flow control...")
         | 
| 391 |  | 
| 392 | 
             
                    while not self._stop_event.is_set():
         | 
| 393 | 
            +
                        # Don’t get too far ahead of the consumer
         | 
| 394 | 
            +
                        if not self._should_generate_next_chunk():
         | 
| 395 | 
            +
                            # We're ahead enough, wait a bit for frontend to catch up
         | 
| 396 | 
            +
                            # (kept short so stop() stays responsive)
         | 
| 397 | 
            +
                            time.sleep(0.5)
         | 
| 398 | 
            +
                            continue
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                        # Snapshot knobs + compute index atomically
         | 
| 401 | 
             
                        with self._lock:
         | 
|  | |
|  | |
|  | |
| 402 | 
             
                            style_vec = self.params.style_vec
         | 
| 403 | 
            +
                            self.mrt.guidance_weight = float(self.params.guidance_weight)
         | 
| 404 | 
            +
                            self.mrt.temperature     = float(self.params.temperature)
         | 
| 405 | 
            +
                            self.mrt.topk            = int(self.params.topk)
         | 
| 406 | 
            +
                            next_idx = self.idx + 1
         | 
| 407 |  | 
| 408 | 
            +
                        print(f"🎹 Generating chunk {next_idx}...")
         | 
| 409 | 
             
                        self.last_chunk_started_at = time.time()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 410 |  | 
| 411 | 
            +
                        # ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
         | 
| 412 | 
            +
                        # Count the first chunk at full length L, and each subsequent at (L - xfade)
         | 
| 413 | 
            +
                        assembled = 0.0
         | 
| 414 | 
            +
                        chunks = []
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                        while assembled < chunk_secs and not self._stop_event.is_set():
         | 
| 417 | 
            +
                            # generate_chunk returns (au.Waveform, new_state)
         | 
| 418 | 
            +
                            wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
         | 
| 419 | 
            +
                            chunks.append(wav)
         | 
| 420 | 
            +
                            L = wav.samples.shape[0] / float(self.mrt.sample_rate)
         | 
| 421 | 
            +
                            assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                        if self._stop_event.is_set():
         | 
| 424 | 
            +
                            break
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                        # ---- Stitch and trim at model SR (keep first head for seamless handoff) ----
         | 
| 427 | 
            +
                        try:
         | 
| 428 | 
            +
                            # Preferred path if you've added the new param in utils.stitch_generated
         | 
| 429 | 
            +
                            y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
         | 
| 430 | 
            +
                        except TypeError:
         | 
| 431 | 
            +
                            # Backward-compatible: local stitcher that keeps the head
         | 
| 432 | 
            +
                            y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                        # Hard trim to the exact musical duration (still at model SR)
         | 
| 435 | 
            +
                        y = hard_trim_seconds(y, chunk_secs)
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                        # ---- Post-processing ----
         | 
| 438 | 
            +
                        if next_idx == 1 and self.params.ref_loop is not None:
         | 
| 439 | 
            +
                            # match loudness to the provided reference on the very first audible chunk
         | 
| 440 | 
            +
                            y, _ = match_loudness_to_reference(
         | 
| 441 | 
            +
                                self.params.ref_loop, y,
         | 
| 442 | 
            +
                                method=self.params.loudness_mode,
         | 
| 443 | 
            +
                                headroom_db=self.params.headroom_db
         | 
| 444 | 
             
                            )
         | 
| 445 | 
            +
                        else:
         | 
| 446 | 
            +
                            # light micro-fades to guard against clicks
         | 
| 447 | 
            +
                            apply_micro_fades(y, 3)
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                        # ---- Resample + bar-snap + encode ----
         | 
| 450 | 
            +
                        b64, meta = self._snap_and_encode(
         | 
| 451 | 
            +
                            y,
         | 
| 452 | 
            +
                            seconds=chunk_secs,
         | 
| 453 | 
            +
                            target_sr=self.params.target_sr,
         | 
| 454 | 
            +
                            bars=self.params.bars_per_chunk
         | 
| 455 | 
            +
                        )
         | 
| 456 | 
            +
                        # small hint for the client if you want UI butter between chunks
         | 
| 457 | 
            +
                        meta["xfade_seconds"] = xfade
         | 
| 458 |  | 
| 459 | 
            +
                        # ---- Publish the completed chunk ----
         | 
| 460 | 
            +
                        with self._lock:
         | 
| 461 | 
            +
                            self.idx = next_idx
         | 
| 462 | 
            +
                            self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
         | 
| 463 | 
            +
                            # Keep outbox bounded (trim far-behind entries)
         | 
| 464 | 
            +
                            if len(self.outbox) > 10:
         | 
| 465 | 
            +
                                cutoff = self._last_delivered_index - 5
         | 
| 466 | 
            +
                                self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
         | 
|  | |
| 467 |  | 
| 468 | 
            +
                        self.last_chunk_completed_at = time.time()
         | 
| 469 | 
            +
                        print(f"✅ Completed chunk {next_idx}")
         | 
|  | |
|  | |
|  | |
| 470 |  | 
| 471 | 
            +
                    print("π JamWorker stopped")
         | 
    	
        utils.py
    CHANGED
    
    | @@ -69,7 +69,7 @@ def match_loudness_to_reference( | |
| 69 |  | 
| 70 |  | 
| 71 | 
             
            # ---------- Stitch / fades / trims ----------
         | 
| 72 | 
            -
            def stitch_generated(chunks, sr: int, xfade_s: float | 
| 73 | 
             
                if not chunks:
         | 
| 74 | 
             
                    raise ValueError("no chunks")
         | 
| 75 | 
             
                xfade_n = int(round(xfade_s * sr))
         | 
| @@ -82,7 +82,9 @@ def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform: | |
| 82 | 
             
                first = chunks[0].samples
         | 
| 83 | 
             
                if first.shape[0] < xfade_n:
         | 
| 84 | 
             
                    raise ValueError("chunk shorter than crossfade prefix")
         | 
| 85 | 
            -
             | 
|  | |
|  | |
| 86 |  | 
| 87 | 
             
                for i in range(1, len(chunks)):
         | 
| 88 | 
             
                    cur = chunks[i].samples
         | 
|  | |
| 69 |  | 
| 70 |  | 
| 71 | 
             
            # ---------- Stitch / fades / trims ----------
         | 
| 72 | 
            +
            def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
         | 
| 73 | 
             
                if not chunks:
         | 
| 74 | 
             
                    raise ValueError("no chunks")
         | 
| 75 | 
             
                xfade_n = int(round(xfade_s * sr))
         | 
|  | |
| 82 | 
             
                first = chunks[0].samples
         | 
| 83 | 
             
                if first.shape[0] < xfade_n:
         | 
| 84 | 
             
                    raise ValueError("chunk shorter than crossfade prefix")
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                # 🔧 key change:
         | 
| 87 | 
            +
                out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()
         | 
| 88 |  | 
| 89 | 
             
                for i in range(1, len(chunks)):
         | 
| 90 | 
             
                    cur = chunks[i].samples
         | 
