Spaces:
Runtime error
Runtime error
Commit
·
8ba62a7
1
Parent(s):
5af7cde
braindead progress updates for /generate to be used in JUCE
Browse files- app.py +155 -62
- one_shot_generation.py +9 -2
app.py
CHANGED
@@ -205,6 +205,8 @@ _patch_t5x_for_gpu_coords()
|
|
205 |
jam_registry: dict[str, JamWorker] = {}
|
206 |
jam_lock = threading.Lock()
|
207 |
|
|
|
|
|
208 |
@contextmanager
|
209 |
def mrt_overrides(mrt, **kwargs):
|
210 |
"""Temporarily set attributes on MRT if they exist; restore after."""
|
@@ -331,6 +333,33 @@ app.add_middleware(
|
|
331 |
_MRT = None
|
332 |
_MRT_LOCK = threading.Lock()
|
333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
def get_mrt():
|
335 |
global _MRT
|
336 |
if _MRT is None:
|
@@ -441,6 +470,8 @@ def _boot():
|
|
441 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
442 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
443 |
|
|
|
|
|
444 |
@app.get("/model/status")
|
445 |
def model_status():
|
446 |
mrt = get_mrt()
|
@@ -674,7 +705,9 @@ def model_select(req: ModelSelect):
|
|
674 |
# one-shot generation
|
675 |
# ----------------------------
|
676 |
|
677 |
-
|
|
|
|
|
678 |
|
679 |
@app.post("/generate")
|
680 |
def generate(
|
@@ -691,76 +724,136 @@ def generate(
|
|
691 |
temperature: float = Form(1.1),
|
692 |
topk: int = Form(40),
|
693 |
target_sample_rate: int | None = Form(None),
|
694 |
-
intro_bars_to_drop: int = Form(0),
|
|
|
695 |
):
|
696 |
-
|
697 |
-
|
698 |
-
if not data:
|
699 |
-
return {"error": "Empty file"}
|
700 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
701 |
-
tmp.write(data)
|
702 |
-
tmp_path = tmp.name
|
703 |
|
704 |
-
|
705 |
-
|
706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
707 |
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
|
|
|
|
|
|
|
|
715 |
mrt,
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
)
|
727 |
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
|
733 |
-
|
734 |
-
|
735 |
-
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
736 |
-
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
737 |
-
expected_secs = float(bars) * seconds_per_bar
|
738 |
-
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
|
739 |
|
740 |
-
|
741 |
-
|
742 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
|
744 |
-
# 4) Metadata
|
745 |
-
metadata = {
|
746 |
-
"bpm": int(round(bpm)),
|
747 |
-
"bars": int(bars),
|
748 |
-
"beats_per_bar": int(beats_per_bar),
|
749 |
-
"styles": extra_styles,
|
750 |
-
"style_weights": weights,
|
751 |
-
"loop_weight": loop_weight,
|
752 |
-
"loudness": loud_stats,
|
753 |
-
"sample_rate": int(target_sr),
|
754 |
-
"channels": int(channels),
|
755 |
-
"crossfade_seconds": mrt.config.crossfade_length,
|
756 |
-
"total_samples": int(total_samples),
|
757 |
-
"seconds_per_bar": seconds_per_bar,
|
758 |
-
"loop_duration_seconds": loop_duration_seconds,
|
759 |
-
"guidance_weight": guidance_weight,
|
760 |
-
"temperature": temperature,
|
761 |
-
"topk": topk,
|
762 |
-
}
|
763 |
-
return {"audio_base64": audio_b64, "metadata": metadata}
|
764 |
|
765 |
# new endpoint to return a bar-aligned chunk without the need for combined audio
|
766 |
|
|
|
205 |
jam_registry: dict[str, JamWorker] = {}
|
206 |
jam_lock = threading.Lock()
|
207 |
|
208 |
+
|
209 |
+
|
210 |
@contextmanager
|
211 |
def mrt_overrides(mrt, **kwargs):
|
212 |
"""Temporarily set attributes on MRT if they exist; restore after."""
|
|
|
333 |
_MRT = None
|
334 |
_MRT_LOCK = threading.Lock()
|
335 |
|
336 |
+
_PROGRESS = {}
|
337 |
+
_PROGRESS_LOCK = threading.Lock()
|
338 |
+
|
339 |
+
def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
    """Store the latest progress snapshot for one request.

    The snapshot replaces any previous entry for ``req_id`` in the shared
    ``_PROGRESS`` table and holds ``n``, ``total``, a 0-100 ``percent``,
    the ``stage`` label and a ``ts`` wall-clock timestamp.

    Args:
        req_id: Identifier of the request; a falsy value makes this a no-op.
        n: Number of steps completed so far.
        total: Total number of steps expected.
        stage: Free-form label for the current phase of work.
    """
    if not req_id:
        return

    # Entries that have not been refreshed for this long are dropped. Without
    # eviction the table gains one entry per request and never shrinks.
    ttl_seconds = 3600.0

    now = time.time()
    # Clamp to [0, total] so an overshoot or a negative count cannot push the
    # percentage outside 0-100; max(1, total) avoids dividing by zero.
    completed = max(0, min(n, total))
    snapshot = {
        "n": int(n),
        "total": int(total),
        "percent": int(round(100.0 * completed / max(1, total))),
        "stage": stage,
        "ts": now,
    }

    with _PROGRESS_LOCK:
        cutoff = now - ttl_seconds
        # Collect keys first: deleting while iterating a dict raises.
        stale_ids = [
            key for key, entry in _PROGRESS.items()
            if entry.get("ts", now) < cutoff
        ]
        for key in stale_ids:
            del _PROGRESS[key]
        _PROGRESS[req_id] = snapshot
|
350 |
+
|
351 |
+
def _progress_done(req_id: str):
    """Flip a request's progress entry to its terminal "done" state.

    The entry is overwritten with ``percent`` 100 and ``stage`` "done".
    Both ``n`` and ``total`` are set to the previously recorded total, or
    to the previous ``n`` if the total is missing or falsy, or to 1 if
    neither is usable. A falsy ``req_id`` makes this a no-op.
    """
    if req_id:
        with _PROGRESS_LOCK:
            previous = _PROGRESS.get(req_id, {})
            final_count = previous.get("total") or previous.get("n") or 1
            _PROGRESS[req_id] = dict(
                n=final_count,
                total=final_count,
                percent=100,
                stage="done",
                ts=time.time(),
            )
|
358 |
+
|
359 |
+
def _progress_get(req_id: str):
    """Return the stored progress snapshot for ``req_id``.

    When nothing has been recorded for the id, a fresh placeholder with
    ``percent`` 0 and ``stage`` "pending" is returned instead. The lookup
    is performed while holding ``_PROGRESS_LOCK``.
    """
    with _PROGRESS_LOCK:
        try:
            return _PROGRESS[req_id]
        except KeyError:
            return {"percent": 0, "stage": "pending"}
|
362 |
+
|
363 |
def get_mrt():
|
364 |
global _MRT
|
365 |
if _MRT is None:
|
|
|
470 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
471 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
472 |
|
473 |
+
|
474 |
+
|
475 |
@app.get("/model/status")
|
476 |
def model_status():
|
477 |
mrt = get_mrt()
|
|
|
705 |
# one-shot generation
|
706 |
# ----------------------------
|
707 |
|
708 |
+
@app.get("/progress")
def progress(request_id: str):
    # Polling endpoint: hands back whatever snapshot _progress_get holds for
    # the given request_id (or its "pending" placeholder for unknown ids).
    snapshot = _progress_get(request_id)
    return snapshot
|
711 |
|
712 |
@app.post("/generate")
|
713 |
def generate(
|
|
|
724 |
temperature: float = Form(1.1),
|
725 |
topk: int = Form(40),
|
726 |
target_sample_rate: int | None = Form(None),
|
727 |
+
intro_bars_to_drop: int = Form(0),
|
728 |
+
request_id: str = Form(None),
|
729 |
):
|
730 |
+
req_id = request_id or str(uuid.uuid4())
|
731 |
+
tmp_path = None
|
|
|
|
|
|
|
|
|
|
|
732 |
|
733 |
+
try:
|
734 |
+
# 0) Read file -> tmp wav
|
735 |
+
data = loop_audio.file.read()
|
736 |
+
if not data:
|
737 |
+
# finalize progress as error and return
|
738 |
+
with _PROGRESS_LOCK:
|
739 |
+
_PROGRESS[req_id] = {
|
740 |
+
"percent": 100,
|
741 |
+
"stage": "error",
|
742 |
+
"error": "Empty file",
|
743 |
+
"ts": time.time(),
|
744 |
+
}
|
745 |
+
return {"error": "Empty file", "request_id": req_id}
|
746 |
|
747 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
748 |
+
tmp.write(data)
|
749 |
+
tmp_path = tmp.name
|
750 |
+
|
751 |
+
# 1) Parse styles + weights
|
752 |
+
extra_styles = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
|
753 |
+
weights = [float(x) for x in style_weights.split(",")] if style_weights else None
|
754 |
+
|
755 |
+
# 2) Get model and apply per-request overrides
|
756 |
+
mrt = get_mrt()
|
757 |
+
with mrt_overrides(
|
758 |
mrt,
|
759 |
+
guidance_weight=guidance_weight,
|
760 |
+
temperature=temperature,
|
761 |
+
topk=topk,
|
762 |
+
):
|
763 |
+
# progress callback (called from the generator loop)
|
764 |
+
def on_chunk(i, total):
|
765 |
+
_progress_update(req_id, i, total, stage="generating")
|
766 |
+
|
767 |
+
# 2a) (optional) emit initial 0% once steps are known:
|
768 |
+
# We'll do this inside the generator right after steps is computed.
|
769 |
+
wav, loud_stats = generate_loop_continuation_with_mrt(
|
770 |
+
mrt,
|
771 |
+
input_wav_path=tmp_path,
|
772 |
+
bpm=bpm,
|
773 |
+
extra_styles=extra_styles,
|
774 |
+
style_weights=weights,
|
775 |
+
bars=bars,
|
776 |
+
beats_per_bar=beats_per_bar,
|
777 |
+
loop_weight=loop_weight,
|
778 |
+
loudness_mode=loudness_mode,
|
779 |
+
loudness_headroom_db=loudness_headroom_db,
|
780 |
+
intro_bars_to_drop=intro_bars_to_drop,
|
781 |
+
progress_cb=on_chunk,
|
782 |
+
)
|
783 |
+
|
784 |
+
# 3) Post-process stages (optional: expose sub-stages for nicer UI)
|
785 |
+
# Mark "postprocess" before we resample/snap/encode.
|
786 |
+
st = _PROGRESS_GET(req_id) if False else None # (placeholder so lints don't complain)
|
787 |
+
_progress_update(
|
788 |
+
req_id,
|
789 |
+
_progress_get(req_id).get("total", 1),
|
790 |
+
_progress_get(req_id).get("total", 1),
|
791 |
+
"postprocess",
|
792 |
)
|
793 |
|
794 |
+
# 3a) Determine SR
|
795 |
+
inp_info = sf.info(tmp_path)
|
796 |
+
input_sr = int(inp_info.samplerate)
|
797 |
+
target_sr_val = int(target_sample_rate or input_sr)
|
798 |
+
|
799 |
+
# 3b) Convert SR + snap to exact bars
|
800 |
+
cur_sr = int(mrt.sample_rate)
|
801 |
+
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
802 |
+
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
803 |
+
expected_secs = float(bars) * seconds_per_bar
|
804 |
+
|
805 |
+
# (optional) sub-stage
|
806 |
+
_progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "resample_and_snap")
|
807 |
+
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr_val, seconds=expected_secs)
|
808 |
+
|
809 |
+
# 3c) Encode WAV -> base64
|
810 |
+
_progress_update(req_id, _progress_get(req_id).get("total", 1), _progress_get(req_id).get("total", 1), "encode")
|
811 |
+
audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr_val)
|
812 |
+
loop_duration_seconds = total_samples / float(target_sr_val)
|
813 |
+
|
814 |
+
# 4) Metadata
|
815 |
+
metadata = {
|
816 |
+
"bpm": int(round(bpm)),
|
817 |
+
"bars": int(bars),
|
818 |
+
"beats_per_bar": int(beats_per_bar),
|
819 |
+
"styles": extra_styles,
|
820 |
+
"style_weights": weights,
|
821 |
+
"loop_weight": loop_weight,
|
822 |
+
"loudness": loud_stats,
|
823 |
+
"sample_rate": int(target_sr_val),
|
824 |
+
"channels": int(channels),
|
825 |
+
"crossfade_seconds": mrt.config.crossfade_length,
|
826 |
+
"total_samples": int(total_samples),
|
827 |
+
"seconds_per_bar": seconds_per_bar,
|
828 |
+
"loop_duration_seconds": loop_duration_seconds,
|
829 |
+
"guidance_weight": guidance_weight,
|
830 |
+
"temperature": temperature,
|
831 |
+
"topk": topk,
|
832 |
+
}
|
833 |
|
834 |
+
_progress_done(req_id)
|
835 |
+
return {"audio_base64": audio_b64, "metadata": metadata, "request_id": req_id}
|
|
|
|
|
|
|
|
|
836 |
|
837 |
+
except Exception as e:
|
838 |
+
# Flip to error state so the UI stops polling and can show a message
|
839 |
+
with _PROGRESS_LOCK:
|
840 |
+
_PROGRESS[req_id] = {
|
841 |
+
"percent": 100,
|
842 |
+
"stage": "error",
|
843 |
+
"error": str(e),
|
844 |
+
"ts": time.time(),
|
845 |
+
}
|
846 |
+
# Re-raise so FastAPI returns a 500 (or your exception handler formats it)
|
847 |
+
raise
|
848 |
+
|
849 |
+
finally:
|
850 |
+
# Clean up temp file
|
851 |
+
if tmp_path:
|
852 |
+
try:
|
853 |
+
os.unlink(tmp_path)
|
854 |
+
except Exception:
|
855 |
+
pass
|
856 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
857 |
|
858 |
# new endpoint to return a bar-aligned chunk without the need for combined audio
|
859 |
|
one_shot_generation.py
CHANGED
@@ -29,6 +29,7 @@ def generate_loop_continuation_with_mrt(
|
|
29 |
loudness_mode: str = "auto",
|
30 |
loudness_headroom_db: float = 1.0,
|
31 |
intro_bars_to_drop: int = 0,
|
|
|
32 |
):
|
33 |
"""
|
34 |
Generate a continuation of an input loop using MagentaRT.
|
@@ -45,6 +46,7 @@ def generate_loop_continuation_with_mrt(
|
|
45 |
loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
|
46 |
loudness_headroom_db: Headroom in dB for peak limiting
|
47 |
intro_bars_to_drop: Number of intro bars to generate then drop
|
|
|
48 |
|
49 |
Returns:
|
50 |
Tuple of (au.Waveform output, dict loudness_stats)
|
@@ -90,13 +92,18 @@ def generate_loop_continuation_with_mrt(
|
|
90 |
|
91 |
# Chunk scheduling to cover gen_total_secs
|
92 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
93 |
-
steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
|
|
|
|
|
|
|
94 |
|
95 |
# Generate
|
96 |
chunks = []
|
97 |
-
for
|
98 |
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
99 |
chunks.append(wav)
|
|
|
|
|
100 |
|
101 |
# Stitch continuous audio
|
102 |
stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
|
|
|
29 |
loudness_mode: str = "auto",
|
30 |
loudness_headroom_db: float = 1.0,
|
31 |
intro_bars_to_drop: int = 0,
|
32 |
+
progress_cb=None
|
33 |
):
|
34 |
"""
|
35 |
Generate a continuation of an input loop using MagentaRT.
|
|
|
46 |
loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
|
47 |
loudness_headroom_db: Headroom in dB for peak limiting
|
48 |
intro_bars_to_drop: Number of intro bars to generate then drop
|
49 |
+
progress_cb: Braindead progress updates for JUCE
|
50 |
|
51 |
Returns:
|
52 |
Tuple of (au.Waveform output, dict loudness_stats)
|
|
|
92 |
|
93 |
# Chunk scheduling to cover gen_total_secs
|
94 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
95 |
+
steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
|
96 |
+
|
97 |
+
if progress_cb:
|
98 |
+
progress_cb(0, steps) # announce total before first chunk
|
99 |
|
100 |
# Generate
|
101 |
chunks = []
|
102 |
+
for i in range(steps):
|
103 |
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
104 |
chunks.append(wav)
|
105 |
+
if progress_cb:
|
106 |
+
progress_cb(i + 1, steps) # <-- report chunk progress
|
107 |
|
108 |
# Stitch continuous audio
|
109 |
stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
|