[feat] add extend

- pipeline_ace_step.py  +100 -16
- ui/components.py  +120 -5
pipeline_ace_step.py
CHANGED
@@ -595,23 +595,83 @@ class ACEStepPipeline:
         target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)

         is_repaint = False
+        is_extend = False
         if add_retake_noise:
+            n_min = int(infer_steps * (1 - retake_variance))
             retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
             retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
             repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
             repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
-
+            x0 = src_latents
             # retake
-            is_repaint = repaint_end_frame - repaint_start_frame != frame_length
+            is_repaint = (repaint_end_frame - repaint_start_frame != frame_length)
+
+            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
+            if is_extend:
+                is_repaint = True
+
+            # TODO: train a mask aware repainting controlnet
             # to make sure mean = 0, std = 1
             if not is_repaint:
                 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-
+            elif not is_extend:
+                # if repaint_end_frame
                 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
                 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
                 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
                 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
                 z0 = repaint_noise
+            elif is_extend:
+                to_right_pad_gt_latents = None
+                to_left_pad_gt_latents = None
+                gt_latents = src_latents
+                src_latents_length = gt_latents.shape[-1]
+                max_infer_fame_length = int(240 * 44100 / 512 / 8)
+                left_pad_frame_length = 0
+                right_pad_frame_length = 0
+                right_trim_length = 0
+                left_trim_length = 0
+                if repaint_start_frame < 0:
+                    left_pad_frame_length = abs(repaint_start_frame)
+                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
+                    if frame_length > max_infer_fame_length:
+                        right_trim_length = frame_length - max_infer_fame_length
+                        extend_gt_latents = extend_gt_latents[:,:,:,:max_infer_fame_length]
+                        to_right_pad_gt_latents = extend_gt_latents[:,:,:,-right_trim_length:]
+                        frame_length = max_infer_fame_length
+                    repaint_start_frame = 0
+                    gt_latents = extend_gt_latents
+
+                if repaint_end_frame > src_latents_length:
+                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
+                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
+                    if frame_length > max_infer_fame_length:
+                        left_trim_length = frame_length - max_infer_fame_length
+                        extend_gt_latents = extend_gt_latents[:,:,:,-max_infer_fame_length:]
+                        to_left_pad_gt_latents = extend_gt_latents[:,:,:,:left_trim_length]
+                        frame_length = max_infer_fame_length
+                    repaint_end_frame = frame_length
+                    gt_latents = extend_gt_latents
+
+                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
+                if left_pad_frame_length > 0:
+                    repaint_mask[:,:,:,:left_pad_frame_length] = 1.0
+                if right_pad_frame_length > 0:
+                    repaint_mask[:,:,:,-right_pad_frame_length:] = 1.0
+                x0 = gt_latents
+                padd_list = []
+                if left_pad_frame_length > 0:
+                    padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
+                padd_list.append(target_latents[:,:,:,left_trim_length:target_latents.shape[-1]-right_trim_length])
+                if right_pad_frame_length > 0:
+                    padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
+                target_latents = torch.cat(padd_list, dim=-1)
+                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
+
+                zt_edit = x0.clone()
+                z0 = target_latents

         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)

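Note on the two conversions used above: repaint bounds arrive in seconds and are mapped to latent frames via 44100 Hz audio, a 512-sample hop, and 8x temporal compression (so 240 s is about 2583 frames, which is the max_infer_fame_length cap), and the retake noise is blended with cos/sin weights so the mixture stays unit-variance (cos^2 + sin^2 = 1). A minimal standalone sketch of just that arithmetic; the helper names are illustrative and not part of the commit:

    import math
    import torch

    def seconds_to_latent_frames(seconds: float) -> int:
        # 44100 Hz waveform -> 512-sample hop -> 8x temporal downsampling in the latent
        return int(seconds * 44100 / 512 / 8)

    def variance_preserving_mix(a: torch.Tensor, b: torch.Tensor, variance: float) -> torch.Tensor:
        # variance in [0, 1] is mapped to an angle; cos^2 + sin^2 = 1 keeps std close to 1
        theta = torch.tensor(variance * math.pi / 2)
        return torch.cos(theta) * a + torch.sin(theta) * b

    a = torch.randn(2, 8, 16, 100)
    b = torch.randn(2, 8, 16, 100)
    mixed = variance_preserving_mix(a, b, 0.5)
    print(seconds_to_latent_frames(240), mixed.std())  # ~2583 frames for 240 s; std stays near 1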
@@ -716,6 +776,16 @@ class ACEStepPipeline:
            return sample

        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
+
+            if is_repaint:
+                if i < n_min:
+                    continue
+                elif i == n_min:
+                    t_i = t / 1000
+                    zt_src = (1 - t_i) * x0 + (t_i) * z0
+                    target_latents = zt_edit + zt_src - x0
+                    logger.info(f"repaint start from {n_min} add {t_i} level of noise")
+
            # expand the latents if we are doing classifier free guidance
            latents = target_latents

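The loop prologue added here skips the first n_min steps entirely and then jumps the latents to a partially noised state: with the timestep rescaled to t in (0, 1], the clean source latent x0 and the repaint noise z0 are interpolated linearly, in the usual flow-matching fashion. A small sketch of that re-noising step under those assumptions (variable names are illustrative):

    import torch

    def renoise_at_step(x0: torch.Tensor, z0: torch.Tensor, t: float) -> torch.Tensor:
        # flow-matching style linear interpolation between the clean latent x0 and noise z0
        return (1 - t) * x0 + t * z0

    x0 = torch.randn(1, 8, 16, 2583)      # clean source latents
    z0 = torch.randn_like(x0)             # repaint noise
    zt = renoise_at_step(x0, z0, t=0.4)   # roughly what the loop builds when i == n_min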
@@ -818,14 +888,27 @@ class ACEStepPipeline:
                timestep=timestep,
            ).sample

-
-
-
-
-
-
+            if is_repaint and i >= n_min:
+                t_i = t/1000
+                if i+1 < len(timesteps):
+                    t_im1 = (timesteps[i+1])/1000
+                else:
+                    t_im1 = torch.zeros_like(t_i).to(t_i.device)
+                dtype = noise_pred.dtype
+                target_latents = target_latents.to(torch.float32)
+                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
+                prev_sample = prev_sample.to(dtype)
+                target_latents = prev_sample
+                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
+                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
+            else:
+                target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]

-
+        if is_extend:
+            if to_right_pad_gt_latents is not None:
+                target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
+            if to_left_pad_gt_latents is not None:
+                target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
        return target_latents

    def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
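In the repaint branch the scheduler call is replaced by a plain Euler update of the flow ODE from t_i to t_im1, and after each step the latents outside the repaint mask are snapped back to the source latent re-noised to the same level, so only the masked region is actually re-generated. Roughly, as a standalone sketch with illustrative names (not the pipeline's API):

    import torch

    def masked_euler_step(latents, noise_pred, t_i, t_im1, x0, z0, repaint_mask):
        # Euler update of the flow ODE: move from noise level t_i to t_im1
        latents = latents + (t_im1 - t_i) * noise_pred
        # outside the repainted region, re-impose the source latent noised to t_im1
        zt_src = (1 - t_im1) * x0 + t_im1 * z0
        return torch.where(repaint_mask == 1.0, latents, zt_src)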
@@ -899,6 +982,7 @@ class ACEStepPipeline:
        save_path: str = None,
        format: str = "flac",
        batch_size: int = 1,
+        debug: bool = False,
    ):

        start_time = time.time()
@@ -936,7 +1020,7 @@ class ACEStepPipeline:
        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if len(lyrics) > 0:
-            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=
+            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
            lyric_mask = [1] * len(lyric_token_idx)
            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
@@ -949,7 +1033,7 @@ class ACEStepPipeline:
        preprocess_time_cost = end_time - start_time
        start_time = end_time

-        add_retake_noise = task in ("retake", "repaint")
+        add_retake_noise = task in ("retake", "repaint", "extend")
        # retake equal to repaint
        if task == "retake":
            repaint_start = 0
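With "extend" routed through add_retake_noise, extending is just a repaint whose bounds fall outside the source audio: a left extension gives a negative repaint_start and a right extension pushes repaint_end past the source duration, which is what later produces the negative or overflowing frame indices handled by the is_extend branch. A tiny sketch of that mapping (function name is illustrative):

    def extend_to_repaint_bounds(audio_duration: float, left_extend: float, right_extend: float):
        # extending left means a negative start time; extending right ends past the source
        repaint_start = -left_extend
        repaint_end = audio_duration + right_extend
        return repaint_start, repaint_end

    print(extend_to_repaint_bounds(120.0, 10.0, 30.0))  # (-10.0, 150.0)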
@@ -957,7 +1041,7 @@ class ACEStepPipeline:

        src_latents = None
        if src_audio_path is not None:
-            assert src_audio_path is not None and task in ("repaint", "edit"), "src_audio_path is required for repaint task"
+            assert src_audio_path is not None and task in ("repaint", "edit", "extend"), "src_audio_path is required for retake/repaint/extend task"
            assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
            src_latents = self.infer_latents(src_audio_path)

@@ -989,7 +1073,7 @@ class ACEStepPipeline:
            target_lyric_token_ids=target_lyric_token_idx,
            target_lyric_mask=target_lyric_mask,
            src_latents=src_latents,
-            random_generators=
+            random_generators=retake_random_generators, # more diversity
            infer_steps=infer_step,
            guidance_scale=guidance_scale,
            n_min=edit_n_min,
@@ -1048,8 +1132,8 @@ class ACEStepPipeline:

        input_params_json = {
            "task": task,
-            "prompt": prompt,
-            "lyrics": lyrics,
+            "prompt": prompt if task != "edit" else edit_target_prompt,
+            "lyrics": lyrics if task != "edit" else edit_target_lyrics,
            "audio_duration": audio_duration,
            "infer_step": infer_step,
            "guidance_scale": guidance_scale,
ui/components.py
CHANGED
@@ -65,7 +65,7 @@ def create_text2music_ui(
        with gr.Column():
            with gr.Row(equal_height=True):
                # add markdown, tags and lyrics examples are from ai music generation community
-                audio_duration = gr.Slider(-1, 240.0, step=0.00001, value
+                audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
                sample_bnt = gr.Button("Sample", variant="primary", scale=1)

            prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")
@@ -252,14 +252,15 @@ def create_text2music_ui(
                with gr.Tab("edit"):
                    edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
                    edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
-
+                    retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
+
                    edit_type = gr.Radio(["only_lyrics", "remix"], value="only_lyrics", label="Edit Type", elem_id="edit_type", info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre")
-                    edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.
+                    edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.6, label="edit_n_min", interactive=True)
                    edit_n_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="edit_n_max", interactive=True)

                    def edit_type_change_func(edit_type):
                        if edit_type == "only_lyrics":
-                            n_min = 0.
+                            n_min = 0.6
                            n_max = 1.0
                        elif edit_type == "remix":
                            n_min = 0.2
@@ -309,6 +310,7 @@ def create_text2music_ui(
                        oss_steps,
                        guidance_scale_text,
                        guidance_scale_lyric,
+                        retake_seeds,
                    ):
                        if edit_source == "upload":
                            src_audio_path = edit_source_audio_upload
@@ -349,7 +351,8 @@ def create_text2music_ui(
                            edit_target_prompt=edit_prompt,
                            edit_target_lyrics=edit_lyrics,
                            edit_n_min=edit_n_min,
-                            edit_n_max=edit_n_max
+                            edit_n_max=edit_n_max,
+                            retake_seeds=retake_seeds,
                        )

                    edit_bnt.click(
@@ -380,9 +383,121 @@ def create_text2music_ui(
                            oss_steps,
                            guidance_scale_text,
                            guidance_scale_lyric,
+                            retake_seeds,
                        ],
                        outputs=edit_outputs + [edit_input_params_json],
                    )
+                with gr.Tab("extend"):
+                    extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None)
+                    left_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Left Extend Length", interactive=True)
+                    right_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Right Extend Length", interactive=True)
+                    extend_source = gr.Radio(["text2music", "last_extend", "upload"], value="text2music", label="Extend Source", elem_id="extend_source")
+
+                    extend_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="extend_source_audio_upload")
+                    extend_source.change(
+                        fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
+                        inputs=[extend_source],
+                        outputs=[extend_source_audio_upload],
+                    )
+
+                    extend_bnt = gr.Button("Extend", variant="primary")
+                    extend_outputs, extend_input_params_json = create_output_ui("Extend")
+
+                    def extend_process_func(
+                        text2music_json_data,
+                        extend_input_params_json,
+                        extend_seeds,
+                        left_extend_length,
+                        right_extend_length,
+                        extend_source,
+                        extend_source_audio_upload,
+                        prompt,
+                        lyrics,
+                        infer_step,
+                        guidance_scale,
+                        scheduler_type,
+                        cfg_type,
+                        omega_scale,
+                        manual_seeds,
+                        guidance_interval,
+                        guidance_interval_decay,
+                        min_guidance_scale,
+                        use_erg_tag,
+                        use_erg_lyric,
+                        use_erg_diffusion,
+                        oss_steps,
+                        guidance_scale_text,
+                        guidance_scale_lyric,
+                    ):
+                        if extend_source == "upload":
+                            src_audio_path = extend_source_audio_upload
+                            json_data = text2music_json_data
+                        elif extend_source == "text2music":
+                            json_data = text2music_json_data
+                            src_audio_path = json_data["audio_path"]
+                        elif extend_source == "last_extend":
+                            json_data = extend_input_params_json
+                            src_audio_path = json_data["audio_path"]
+
+                        repaint_start = -left_extend_length
+                        repaint_end = json_data["audio_duration"] + right_extend_length
+                        return text2music_process_func(
+                            json_data["audio_duration"],
+                            prompt,
+                            lyrics,
+                            infer_step,
+                            guidance_scale,
+                            scheduler_type,
+                            cfg_type,
+                            omega_scale,
+                            manual_seeds,
+                            guidance_interval,
+                            guidance_interval_decay,
+                            min_guidance_scale,
+                            use_erg_tag,
+                            use_erg_lyric,
+                            use_erg_diffusion,
+                            oss_steps,
+                            guidance_scale_text,
+                            guidance_scale_lyric,
+                            retake_seeds=extend_seeds,
+                            retake_variance=1.0,
+                            task="extend",
+                            repaint_start=repaint_start,
+                            repaint_end=repaint_end,
+                            src_audio_path=src_audio_path,
+                        )
+
+                    extend_bnt.click(
+                        fn=extend_process_func,
+                        inputs=[
+                            input_params_json,
+                            extend_input_params_json,
+                            extend_seeds,
+                            left_extend_length,
+                            right_extend_length,
+                            extend_source,
+                            extend_source_audio_upload,
+                            prompt,
+                            lyrics,
+                            infer_step,
+                            guidance_scale,
+                            scheduler_type,
+                            cfg_type,
+                            omega_scale,
+                            manual_seeds,
+                            guidance_interval,
+                            guidance_interval_decay,
+                            min_guidance_scale,
+                            use_erg_tag,
+                            use_erg_lyric,
+                            use_erg_diffusion,
+                            oss_steps,
+                            guidance_scale_text,
+                            guidance_scale_lyric,
+                        ],
+                        outputs=extend_outputs + [extend_input_params_json],
+                    )

                def sample_data():
                    json_data = sample_data_func()
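The extend tab reuses the same show-the-uploader-only-when-needed pattern as the other source selectors: a Radio choice toggles the visibility of an Audio upload widget. A minimal, self-contained Gradio reproduction of just that wiring, stripped of the pipeline call:

    import gradio as gr

    # Minimal reproduction of the source-selection pattern used in the extend tab:
    # the upload widget is only shown when the user picks "upload" as the source.
    with gr.Blocks() as demo:
        source = gr.Radio(["text2music", "last_extend", "upload"], value="text2music", label="Extend Source")
        upload = gr.Audio(label="Upload Audio", type="filepath", visible=False)
        source.change(fn=lambda x: gr.update(visible=(x == "upload")), inputs=[source], outputs=[upload])

    # demo.launch()  # uncomment to try it standalone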