Commit: cfg
Space status: Build error

Files changed:
- app.py (+19 -3)
- diffrhythm/infer/infer.py (+3 -2)
- diffrhythm/model/cfm.py (+4 -1)
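In short, the commit threads two new sampling controls end to end: cfg_strength, the classifier-free guidance scale, and odeint_method, the solver used to integrate the flow-matching ODE. Both start as Gradio widgets in app.py, pass through inference() in diffrhythm/infer/infer.py, and land in CFM.sample() in diffrhythm/model/cfm.py.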
app.py CHANGED

@@ -31,7 +31,7 @@ cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
 @spaces.GPU(duration=20)
-def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
+def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', max_frames=2048, device='cuda'):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -56,10 +56,12 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
         steps=steps,
+        cfg_strength=cfg_strength,
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
         file_type=file_type,
-        vocal_flag=vocal_flag
+        vocal_flag=vocal_flag,
+        odeint_method=odeint_method,
     )
     return generated_song
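For context, cfg_strength is the classifier-free guidance scale applied at each sampling step. A minimal sketch of the standard formulation follows; model, x, t, cond, and null_cond are illustrative names, and the exact blend DiffRhythm applies lives in cfm.py's sample():

    def guided_velocity(model, x, t, cond, null_cond, cfg_strength=4.0):
        # Classifier-free guidance: run the model with the real conditioning
        # and with the negative/null conditioning, then extrapolate away from
        # the unconditional prediction. cfg_strength=1 recovers the plain
        # conditional output; larger values follow the conditioning harder.
        v_cond = model(x, t, cond)
        v_null = model(x, t, null_cond)
        return v_null + cfg_strength * (v_cond - v_null)

The negative_style_prompt passed just above presumably serves as that null/negative conditioning, which is why a single strength knob is all the UI needs to expose.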
@@ -223,6 +225,10 @@ with gr.Blocks(css=css) as demo:
     4. **Supported Languages**
     - **Chinese and English**
     - More languages coming soon
+
+    5. **Others**
+    - If the generated audio loads slowly, select mp3 as the Output Format in Advanced Settings.
+
     """)
 
     lyrics_btn = gr.Button("Generate", variant="primary")
@@ -246,6 +252,16 @@
             interactive=True,
             elem_id="step_slider"
         )
+        cfg_strength = gr.Slider(
+            minimum=1,
+            maximum=10,
+            value=4.0,
+            step=0.5,
+            label="CFG Strength",
+            interactive=True,
+            elem_id="cfg_slider"
+        )
+        odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
         file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
 
 
@@ -387,7 +403,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, file_type],
+        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method],
         outputs=audio_output
     )
 
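Gradio binds click inputs to handler parameters positionally, so the new entries must sit in the same order as infer_music's updated signature (cfg_strength after steps, odeint_method after file_type), which they do. A self-contained sketch of the same wiring pattern, using a stand-in handler and illustrative slider ranges:

    import gradio as gr

    def echo_settings(steps, cfg_strength, odeint_method):
        # Stand-in for infer_music: just report what it would receive.
        return f"steps={steps}, cfg_strength={cfg_strength}, solver={odeint_method}"

    with gr.Blocks() as demo:
        steps = gr.Slider(minimum=10, maximum=100, value=32, step=1, label="Diffusion Steps")
        cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.0, step=0.5, label="CFG Strength")
        odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
        output = gr.Textbox(label="Settings")
        # The inputs list maps positionally onto echo_settings' parameters.
        gr.Button("Generate").click(fn=echo_settings, inputs=[steps, cfg_strength, odeint_method], outputs=output)

    demo.launch()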
diffrhythm/infer/infer.py CHANGED

@@ -74,7 +74,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type, vocal_flag):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
@@ -84,10 +84,11 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt,
             style_prompt=style_prompt,
             negative_style_prompt=negative_style_prompt,
             steps=steps,
-            cfg_strength=
+            cfg_strength=cfg_strength,
             sway_sampling_coef=sway_sampling_coef,
             start_time=start_time,
             vocal_flag=vocal_flag,
+            odeint_method=odeint_method,
         )
 
         generated = generated.to(torch.float32)
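Note the removed line: the previous revision left a dangling cfg_strength= keyword with no value, which is a Python syntax error and the likely cause of the "Build error" status shown above. This hunk supplies the missing value and threads the solver choice through to the model.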
diffrhythm/model/cfm.py CHANGED

@@ -114,9 +114,12 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
-        vocal_flag=False
+        vocal_flag=False,
+        odeint_method="euler"
     ):
         self.eval()
+
+        self.odeint_kwargs = dict(method=odeint_method)
 
         if next(self.parameters()).dtype == torch.float16:
             cond = cond.half()
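CFM.sample() presumably hands self.odeint_kwargs to its ODE integrator; all four UI choices (euler, midpoint, rk4, implicit_adams) are valid method names in torchdiffeq, the usual integrator for this style of flow-matching model. A minimal sketch of that pattern, assuming a torchdiffeq backend; integrate_flow, velocity_fn, and y0 are illustrative names:

    import torch
    from torchdiffeq import odeint

    def integrate_flow(velocity_fn, y0, steps=32, odeint_method="euler"):
        # Integrate dy/dt = velocity_fn(t, y) from t=0 to t=1 on a fixed grid,
        # selecting the solver through the same kwargs dict the diff introduces.
        odeint_kwargs = dict(method=odeint_method)
        t = torch.linspace(0, 1, steps, device=y0.device)
        trajectory = odeint(velocity_fn, y0, t, **odeint_kwargs)
        return trajectory[-1]  # final state, i.e. the generated latent

Fixed-step solvers like euler are cheapest per step; midpoint and rk4 trade extra function evaluations for accuracy at the same step count, which is why exposing the choice alongside cfg_strength is a natural pairing.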