thecollabagepatch commited on
Commit
8ba62a7
·
1 Parent(s): 5af7cde

braindead progress updates for /generate to be used in JUCE

Browse files
Files changed (2) hide show
  1. app.py +155 -62
  2. one_shot_generation.py +9 -2
app.py CHANGED
@@ -205,6 +205,8 @@ _patch_t5x_for_gpu_coords()
205
  jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
207
 
 
 
208
  @contextmanager
209
  def mrt_overrides(mrt, **kwargs):
210
  """Temporarily set attributes on MRT if they exist; restore after."""
@@ -331,6 +333,33 @@ app.add_middleware(
331
  _MRT = None
332
  _MRT_LOCK = threading.Lock()
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  def get_mrt():
335
  global _MRT
336
  if _MRT is None:
@@ -441,6 +470,8 @@ def _boot():
441
  if os.getenv("MRT_WARMUP", "1") != "0":
442
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
443
 
 
 
444
  @app.get("/model/status")
445
  def model_status():
446
  mrt = get_mrt()
@@ -674,7 +705,9 @@ def model_select(req: ModelSelect):
674
  # one-shot generation
675
  # ----------------------------
676
 
677
-
 
 
678
 
679
  @app.post("/generate")
680
  def generate(
@@ -691,76 +724,136 @@ def generate(
691
  temperature: float = Form(1.1),
692
  topk: int = Form(40),
693
  target_sample_rate: int | None = Form(None),
694
- intro_bars_to_drop: int = Form(0), # <— NEW
 
695
  ):
696
- # Read file
697
- data = loop_audio.file.read()
698
- if not data:
699
- return {"error": "Empty file"}
700
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
701
- tmp.write(data)
702
- tmp_path = tmp.name
703
 
704
- # Parse styles + weights
705
- extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
706
- weights = [float(x) for x in style_weights.split(",")] if style_weights else None
 
 
 
 
 
 
 
 
 
 
707
 
708
- mrt = get_mrt() # warm once, in this worker thread
709
- # Temporarily override MRT inference knobs for this request
710
- with mrt_overrides(mrt,
711
- guidance_weight=guidance_weight,
712
- temperature=temperature,
713
- topk=topk):
714
- wav, loud_stats = generate_loop_continuation_with_mrt(
 
 
 
 
715
  mrt,
716
- input_wav_path=tmp_path,
717
- bpm=bpm,
718
- extra_styles=extra_styles,
719
- style_weights=weights,
720
- bars=bars,
721
- beats_per_bar=beats_per_bar,
722
- loop_weight=loop_weight,
723
- loudness_mode=loudness_mode,
724
- loudness_headroom_db=loudness_headroom_db,
725
- intro_bars_to_drop=intro_bars_to_drop, # <— pass through
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  )
727
 
728
- # 1) Figure out the desired SR
729
- inp_info = sf.info(tmp_path)
730
- input_sr = int(inp_info.samplerate)
731
- target_sr = int(target_sample_rate or input_sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
 
733
- # 2) Convert to target SR + snap to exact bars
734
- cur_sr = int(mrt.sample_rate)
735
- x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
736
- seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
737
- expected_secs = float(bars) * seconds_per_bar
738
- x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
739
 
740
- # 3) Encode WAV once (no extra write)
741
- audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
742
- loop_duration_seconds = total_samples / float(target_sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
 
744
- # 4) Metadata
745
- metadata = {
746
- "bpm": int(round(bpm)),
747
- "bars": int(bars),
748
- "beats_per_bar": int(beats_per_bar),
749
- "styles": extra_styles,
750
- "style_weights": weights,
751
- "loop_weight": loop_weight,
752
- "loudness": loud_stats,
753
- "sample_rate": int(target_sr),
754
- "channels": int(channels),
755
- "crossfade_seconds": mrt.config.crossfade_length,
756
- "total_samples": int(total_samples),
757
- "seconds_per_bar": seconds_per_bar,
758
- "loop_duration_seconds": loop_duration_seconds,
759
- "guidance_weight": guidance_weight,
760
- "temperature": temperature,
761
- "topk": topk,
762
- }
763
- return {"audio_base64": audio_b64, "metadata": metadata}
764
 
765
  # new endpoint to return a bar-aligned chunk without the need for combined audio
766
 
 
205
  jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
207
 
208
+
209
+
210
  @contextmanager
211
  def mrt_overrides(mrt, **kwargs):
212
  """Temporarily set attributes on MRT if they exist; restore after."""
 
333
  _MRT = None
334
  _MRT_LOCK = threading.Lock()
335
 
336
+ _PROGRESS = {}
337
+ _PROGRESS_LOCK = threading.Lock()
338
+
339
+ def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
340
+ if not req_id:
341
+ return
342
+ with _PROGRESS_LOCK:
343
+ _PROGRESS[req_id] = {
344
+ "n": int(n),
345
+ "total": int(total),
346
+ "percent": int(round(100.0 * max(0, min(n, total)) / max(1, total))),
347
+ "stage": stage,
348
+ "ts": time.time(),
349
+ }
350
+
351
+ def _progress_done(req_id: str):
352
+ if not req_id:
353
+ return
354
+ with _PROGRESS_LOCK:
355
+ st = _PROGRESS.get(req_id, {})
356
+ total = st.get("total") or st.get("n") or 1
357
+ _PROGRESS[req_id] = {"n": total, "total": total, "percent": 100, "stage": "done", "ts": time.time()}
358
+
359
+ def _progress_get(req_id: str):
360
+ with _PROGRESS_LOCK:
361
+ return _PROGRESS.get(req_id, {"percent": 0, "stage": "pending"})
362
+
363
  def get_mrt():
364
  global _MRT
365
  if _MRT is None:
 
470
  if os.getenv("MRT_WARMUP", "1") != "0":
471
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
472
 
473
+
474
+
475
  @app.get("/model/status")
476
  def model_status():
477
  mrt = get_mrt()
 
705
  # one-shot generation
706
  # ----------------------------
707
 
708
+ @app.get("/progress")
709
+ def progress(request_id: str):
710
+ return _progress_get(request_id)
711
 
712
  @app.post("/generate")
713
  def generate(
 
724
  temperature: float = Form(1.1),
725
  topk: int = Form(40),
726
  target_sample_rate: int | None = Form(None),
727
+ intro_bars_to_drop: int = Form(0),
728
+ request_id: str = Form(None),
729
  ):
730
+ req_id = request_id or str(uuid.uuid4())
731
+ tmp_path = None
 
 
 
 
 
732
 
733
+ try:
734
+ # 0) Read file -> tmp wav
735
+ data = loop_audio.file.read()
736
+ if not data:
737
+ # finalize progress as error and return
738
+ with _PROGRESS_LOCK:
739
+ _PROGRESS[req_id] = {
740
+ "percent": 100,
741
+ "stage": "error",
742
+ "error": "Empty file",
743
+ "ts": time.time(),
744
+ }
745
+ return {"error": "Empty file", "request_id": req_id}
746
 
747
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
748
+ tmp.write(data)
749
+ tmp_path = tmp.name
750
+
751
+ # 1) Parse styles + weights
752
+ extra_styles = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
753
+ weights = [float(x) for x in style_weights.split(",")] if style_weights else None
754
+
755
+ # 2) Get model and apply per-request overrides
756
+ mrt = get_mrt()
757
+ with mrt_overrides(
758
  mrt,
759
+ guidance_weight=guidance_weight,
760
+ temperature=temperature,
761
+ topk=topk,
762
+ ):
763
+ # progress callback (called from the generator loop)
764
+ def on_chunk(i, total):
765
+ _progress_update(req_id, i, total, stage="generating")
766
+
767
+ # 2a) (optional) emit initial 0% once steps are known:
768
+ # We'll do this inside the generator right after steps is computed.
769
+ wav, loud_stats = generate_loop_continuation_with_mrt(
770
+ mrt,
771
+ input_wav_path=tmp_path,
772
+ bpm=bpm,
773
+ extra_styles=extra_styles,
774
+ style_weights=weights,
775
+ bars=bars,
776
+ beats_per_bar=beats_per_bar,
777
+ loop_weight=loop_weight,
778
+ loudness_mode=loudness_mode,
779
+ loudness_headroom_db=loudness_headroom_db,
780
+ intro_bars_to_drop=intro_bars_to_drop,
781
+ progress_cb=on_chunk,
782
+ )
783
+
784
+ # 3) Post-process stages (optional: expose sub-stages for nicer UI)
785
+ # Mark "postprocess" before we resample/snap/encode.
786
+ st = _PROGRESS_GET(req_id) if False else None # (placeholder so lints don't complain)
787
+ _progress_update(
788
+ req_id,
789
+ _progress_get(req_id).get("total", 1),
790
+ _progress_get(req_id).get("total", 1),
791
+ "postprocess",
792
  )
793
 
794
+ # 3a) Determine SR
795
+ inp_info = sf.info(tmp_path)
796
+ input_sr = int(inp_info.samplerate)
797
+ target_sr_val = int(target_sample_rate or input_sr)
798
+
799
+ # 3b) Convert SR + snap to exact bars
800
+ cur_sr = int(mrt.sample_rate)
801
+ x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
802
+ seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
803
+ expected_secs = float(bars) * seconds_per_bar
804
+
805
+ # (optional) sub-stage
806
+ _progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "resample_and_snap")
807
+ x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr_val, seconds=expected_secs)
808
+
809
+ # 3c) Encode WAV -> base64
810
+ _progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "encode")
811
+ audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr_val)
812
+ loop_duration_seconds = total_samples / float(target_sr_val)
813
+
814
+ # 4) Metadata
815
+ metadata = {
816
+ "bpm": int(round(bpm)),
817
+ "bars": int(bars),
818
+ "beats_per_bar": int(beats_per_bar),
819
+ "styles": extra_styles,
820
+ "style_weights": weights,
821
+ "loop_weight": loop_weight,
822
+ "loudness": loud_stats,
823
+ "sample_rate": int(target_sr_val),
824
+ "channels": int(channels),
825
+ "crossfade_seconds": mrt.config.crossfade_length,
826
+ "total_samples": int(total_samples),
827
+ "seconds_per_bar": seconds_per_bar,
828
+ "loop_duration_seconds": loop_duration_seconds,
829
+ "guidance_weight": guidance_weight,
830
+ "temperature": temperature,
831
+ "topk": topk,
832
+ }
833
 
834
+ _progress_done(req_id)
835
+ return {"audio_base64": audio_b64, "metadata": metadata, "request_id": req_id}
 
 
 
 
836
 
837
+ except Exception as e:
838
+ # Flip to error state so the UI stops polling and can show a message
839
+ with _PROGRESS_LOCK:
840
+ _PROGRESS[req_id] = {
841
+ "percent": 100,
842
+ "stage": "error",
843
+ "error": str(e),
844
+ "ts": time.time(),
845
+ }
846
+ # Re-raise so FastAPI returns a 500 (or your exception handler formats it)
847
+ raise
848
+
849
+ finally:
850
+ # Clean up temp file
851
+ if tmp_path:
852
+ try:
853
+ os.unlink(tmp_path)
854
+ except Exception:
855
+ pass
856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
  # new endpoint to return a bar-aligned chunk without the need for combined audio
859
 
one_shot_generation.py CHANGED
@@ -29,6 +29,7 @@ def generate_loop_continuation_with_mrt(
29
  loudness_mode: str = "auto",
30
  loudness_headroom_db: float = 1.0,
31
  intro_bars_to_drop: int = 0,
 
32
  ):
33
  """
34
  Generate a continuation of an input loop using MagentaRT.
@@ -45,6 +46,7 @@ def generate_loop_continuation_with_mrt(
45
  loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
46
  loudness_headroom_db: Headroom in dB for peak limiting
47
  intro_bars_to_drop: Number of intro bars to generate then drop
 
48
 
49
  Returns:
50
  Tuple of (au.Waveform output, dict loudness_stats)
@@ -90,13 +92,18 @@ def generate_loop_continuation_with_mrt(
90
 
91
  # Chunk scheduling to cover gen_total_secs
92
  chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
93
- steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
 
 
 
94
 
95
  # Generate
96
  chunks = []
97
- for _ in range(steps):
98
  wav, state = mrt.generate_chunk(state=state, style=combined_style)
99
  chunks.append(wav)
 
 
100
 
101
  # Stitch continuous audio
102
  stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
 
29
  loudness_mode: str = "auto",
30
  loudness_headroom_db: float = 1.0,
31
  intro_bars_to_drop: int = 0,
32
+ progress_cb=None
33
  ):
34
  """
35
  Generate a continuation of an input loop using MagentaRT.
 
46
  loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
47
  loudness_headroom_db: Headroom in dB for peak limiting
48
  intro_bars_to_drop: Number of intro bars to generate then drop
49
+ progress_cb: Braindead progress updates for JUCE
50
 
51
  Returns:
52
  Tuple of (au.Waveform output, dict loudness_stats)
 
92
 
93
  # Chunk scheduling to cover gen_total_secs
94
  chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
95
+ steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
96
+
97
+ if progress_cb:
98
+ progress_cb(0, steps) # announce total before first chunk
99
 
100
  # Generate
101
  chunks = []
102
+ for i in range(steps):
103
  wav, state = mrt.generate_chunk(state=state, style=combined_style)
104
  chunks.append(wav)
105
+ if progress_cb:
106
+ progress_cb(i + 1, steps) # <-- report chunk progress
107
 
108
  # Stitch continuous audio
109
  stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()