[feat] add extend
- pipeline_ace_step.py +100 -16
- ui/components.py +120 -5
pipeline_ace_step.py
CHANGED
@@ -595,23 +595,83 @@ class ACEStepPipeline:
         target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)

         is_repaint = False
+        is_extend = False
         if add_retake_noise:
+            n_min = int(infer_steps * (1 - retake_variance))
             retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
             retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
             repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
             repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
-
+            x0 = src_latents
             # retake
-            is_repaint = repaint_end_frame - repaint_start_frame != frame_length
+            is_repaint = (repaint_end_frame - repaint_start_frame != frame_length)
+
+            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
+            if is_extend:
+                is_repaint = True
+
+            # TODO: train a mask aware repainting controlnet
             # to make sure mean = 0, std = 1
             if not is_repaint:
                 target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-
+            elif not is_extend:
+                # if repaint_end_frame
                 repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
                 repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
                 repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
                 repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
                 z0 = repaint_noise
+            elif is_extend:
+                to_right_pad_gt_latents = None
+                to_left_pad_gt_latents = None
+                gt_latents = src_latents
+                src_latents_length = gt_latents.shape[-1]
+                max_infer_fame_length = int(240 * 44100 / 512 / 8)
+                left_pad_frame_length = 0
+                right_pad_frame_length = 0
+                right_trim_length = 0
+                left_trim_length = 0
+                if repaint_start_frame < 0:
+                    left_pad_frame_length = abs(repaint_start_frame)
+                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
+                    if frame_length > max_infer_fame_length:
+                        right_trim_length = frame_length - max_infer_fame_length
+                        extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length]
+                        to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
+                        frame_length = max_infer_fame_length
+                    repaint_start_frame = 0
+                    gt_latents = extend_gt_latents
+
+                if repaint_end_frame > src_latents_length:
+                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
+                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
+                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
+                    if frame_length > max_infer_fame_length:
+                        left_trim_length = frame_length - max_infer_fame_length
+                        extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:]
+                        to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
+                        frame_length = max_infer_fame_length
+                    repaint_end_frame = frame_length
+                    gt_latents = extend_gt_latents
+
+                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
+                if left_pad_frame_length > 0:
+                    repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
+                if right_pad_frame_length > 0:
+                    repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
+                x0 = gt_latents
+                padd_list = []
+                if left_pad_frame_length > 0:
+                    padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
+                padd_list.append(target_latents[:, :, :, left_trim_length:target_latents.shape[-1] - right_trim_length])
+                if right_pad_frame_length > 0:
+                    padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
+                target_latents = torch.cat(padd_list, dim=-1)
+                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
+
+            zt_edit = x0.clone()
+            z0 = target_latents

         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)

@@ -716,6 +776,16 @@ class ACEStepPipeline:
            return sample

        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
+
+            if is_repaint:
+                if i < n_min:
+                    continue
+                elif i == n_min:
+                    t_i = t / 1000
+                    zt_src = (1 - t_i) * x0 + (t_i) * z0
+                    target_latents = zt_edit + zt_src - x0
+                    logger.info(f"repaint start from {n_min} add {t_i} level of noise")
+
            # expand the latents if we are doing classifier free guidance
            latents = target_latents

@@ -818,14 +888,27 @@ class ACEStepPipeline:
                timestep=timestep,
            ).sample

-
-
-
-
-
-
+            if is_repaint and i >= n_min:
+                t_i = t / 1000
+                if i + 1 < len(timesteps):
+                    t_im1 = (timesteps[i + 1]) / 1000
+                else:
+                    t_im1 = torch.zeros_like(t_i).to(t_i.device)
+                dtype = noise_pred.dtype
+                target_latents = target_latents.to(torch.float32)
+                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
+                prev_sample = prev_sample.to(dtype)
+                target_latents = prev_sample
+                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
+                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
+            else:
+                target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]

-
+            if is_extend:
+                if to_right_pad_gt_latents is not None:
+                    target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
+                if to_left_pad_gt_latents is not None:
+                    target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
        return target_latents

    def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):

@@ -899,6 +982,7 @@ class ACEStepPipeline:
        save_path: str = None,
        format: str = "flac",
        batch_size: int = 1,
+        debug: bool = False,
    ):

        start_time = time.time()

@@ -936,7 +1020,7 @@ class ACEStepPipeline:
        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if len(lyrics) > 0:
-            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=
+            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=debug)
            lyric_mask = [1] * len(lyric_token_idx)
            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)

@@ -949,7 +1033,7 @@ class ACEStepPipeline:
        preprocess_time_cost = end_time - start_time
        start_time = end_time

-        add_retake_noise = task in ("retake", "repaint")
+        add_retake_noise = task in ("retake", "repaint", "extend")
        # retake equal to repaint
        if task == "retake":
            repaint_start = 0

@@ -957,7 +1041,7 @@ class ACEStepPipeline:

        src_latents = None
        if src_audio_path is not None:
-            assert src_audio_path is not None and task in ("repaint", "edit"), "src_audio_path is required for repaint task"
+            assert src_audio_path is not None and task in ("repaint", "edit", "extend"), "src_audio_path is required for retake/repaint/extend task"
            assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
            src_latents = self.infer_latents(src_audio_path)

@@ -989,7 +1073,7 @@ class ACEStepPipeline:
                target_lyric_token_ids=target_lyric_token_idx,
                target_lyric_mask=target_lyric_mask,
                src_latents=src_latents,
-                random_generators=
+                random_generators=retake_random_generators,  # more diversity
                infer_steps=infer_step,
                guidance_scale=guidance_scale,
                n_min=edit_n_min,

@@ -1048,8 +1132,8 @@ class ACEStepPipeline:

        input_params_json = {
            "task": task,
-            "prompt": prompt,
-            "lyrics": lyrics,
+            "prompt": prompt if task != "edit" else edit_target_prompt,
+            "lyrics": lyrics if task != "edit" else edit_target_lyrics,
            "audio_duration": audio_duration,
            "infer_step": infer_step,
            "guidance_scale": guidance_scale,
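Sketch (editor's note, not part of the commit): the "# to make sure mean = 0, std = 1" comment above refers to the cos/sin blend of the two noise tensors. Because cos(theta)**2 + sin(theta)**2 = 1, mixing two independent standard-normal latents this way yields another standard-normal latent, so retake_variance only controls how much of the retake seed is blended in without changing the statistics the scheduler expects:

import math
import torch

theta = torch.tensor(0.7 * math.pi / 2)   # retake_variance = 0.7, mapped as in the diff
a = torch.randn(4, 8, 16, 1024)           # original noise (random_generators)
b = torch.randn(4, 8, 16, 1024)           # retake noise (retake_random_generators)
mixed = torch.cos(theta) * a + torch.sin(theta) * b
print(round(mixed.mean().item(), 3), round(mixed.std().item(), 3))   # ~0.0, ~1.0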
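Sketch (editor's note): the repaint/extend path above bypasses scheduler.step and performs an explicit Euler update on the linear noise path. At step n_min the latents are re-noised to the matching level with zt_src = (1 - t) * x0 + t * z0; each later step applies z = z + (t_next - t) * noise_pred, and the repaint mask pins every frame outside the edited region back onto the source trajectory. A standalone toy version with made-up tensors (the real loop uses the transformer's noise_pred as the velocity):

import torch

torch.manual_seed(0)
x0 = torch.randn(1, 8, 16, 64)            # source latents
z0 = torch.randn(1, 8, 16, 64)            # repaint noise
repaint_mask = torch.zeros_like(x0)
repaint_mask[..., 16:32] = 1.0            # region to regenerate

timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0, 0.0])
zt = (1 - timesteps[0] / 1000) * x0 + (timesteps[0] / 1000) * z0   # re-noise to the start level

for i in range(len(timesteps) - 1):
    t_i, t_im1 = timesteps[i] / 1000, timesteps[i + 1] / 1000
    velocity = z0 - x0                                   # stand-in for the model's noise_pred
    zt = zt + (t_im1 - t_i) * velocity                   # explicit Euler step toward t_im1
    zt_src = (1 - t_im1) * x0 + t_im1 * z0
    zt = torch.where(repaint_mask == 1.0, zt, zt_src)    # keep unedited frames on the source path

print(torch.allclose(zt[..., :16], x0[..., :16]))        # True: outside the mask we recover x0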
ui/components.py
CHANGED
@@ -65,7 +65,7 @@ def create_text2music_ui(
    with gr.Column():
        with gr.Row(equal_height=True):
            # add markdown, tags and lyrics examples are from ai music generation community
-            audio_duration = gr.Slider(-1, 240.0, step=0.00001, value
+            audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
            sample_bnt = gr.Button("Sample", variant="primary", scale=1)

        prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")

@@ -252,14 +252,15 @@ def create_text2music_ui(
            with gr.Tab("edit"):
                edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
                edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
-
+                retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
+
                edit_type = gr.Radio(["only_lyrics", "remix"], value="only_lyrics", label="Edit Type", elem_id="edit_type", info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre")
-                edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.
+                edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.6, label="edit_n_min", interactive=True)
                edit_n_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="edit_n_max", interactive=True)

                def edit_type_change_func(edit_type):
                    if edit_type == "only_lyrics":
-                        n_min = 0.
+                        n_min = 0.6
                        n_max = 1.0
                    elif edit_type == "remix":
                        n_min = 0.2

@@ -309,6 +310,7 @@ def create_text2music_ui(
                    oss_steps,
                    guidance_scale_text,
                    guidance_scale_lyric,
+                    retake_seeds,
                ):
                    if edit_source == "upload":
                        src_audio_path = edit_source_audio_upload

@@ -349,7 +351,8 @@ def create_text2music_ui(
                        edit_target_prompt=edit_prompt,
                        edit_target_lyrics=edit_lyrics,
                        edit_n_min=edit_n_min,
-                        edit_n_max=edit_n_max
+                        edit_n_max=edit_n_max,
+                        retake_seeds=retake_seeds,
                    )

                edit_bnt.click(

@@ -380,9 +383,121 @@ def create_text2music_ui(
                        oss_steps,
                        guidance_scale_text,
                        guidance_scale_lyric,
+                        retake_seeds,
                    ],
                    outputs=edit_outputs + [edit_input_params_json],
                )
+            with gr.Tab("extend"):
+                extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None)
+                left_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Left Extend Length", interactive=True)
+                right_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Right Extend Length", interactive=True)
+                extend_source = gr.Radio(["text2music", "last_extend", "upload"], value="text2music", label="Extend Source", elem_id="extend_source")
+
+                extend_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="extend_source_audio_upload")
+                extend_source.change(
+                    fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
+                    inputs=[extend_source],
+                    outputs=[extend_source_audio_upload],
+                )
+
+                extend_bnt = gr.Button("Extend", variant="primary")
+                extend_outputs, extend_input_params_json = create_output_ui("Extend")
+
+                def extend_process_func(
+                    text2music_json_data,
+                    extend_input_params_json,
+                    extend_seeds,
+                    left_extend_length,
+                    right_extend_length,
+                    extend_source,
+                    extend_source_audio_upload,
+                    prompt,
+                    lyrics,
+                    infer_step,
+                    guidance_scale,
+                    scheduler_type,
+                    cfg_type,
+                    omega_scale,
+                    manual_seeds,
+                    guidance_interval,
+                    guidance_interval_decay,
+                    min_guidance_scale,
+                    use_erg_tag,
+                    use_erg_lyric,
+                    use_erg_diffusion,
+                    oss_steps,
+                    guidance_scale_text,
+                    guidance_scale_lyric,
+                ):
+                    if extend_source == "upload":
+                        src_audio_path = extend_source_audio_upload
+                        json_data = text2music_json_data
+                    elif extend_source == "text2music":
+                        json_data = text2music_json_data
+                        src_audio_path = json_data["audio_path"]
+                    elif extend_source == "last_extend":
+                        json_data = extend_input_params_json
+                        src_audio_path = json_data["audio_path"]
+
+                    repaint_start = -left_extend_length
+                    repaint_end = json_data["audio_duration"] + right_extend_length
+                    return text2music_process_func(
+                        json_data["audio_duration"],
+                        prompt,
+                        lyrics,
+                        infer_step,
+                        guidance_scale,
+                        scheduler_type,
+                        cfg_type,
+                        omega_scale,
+                        manual_seeds,
+                        guidance_interval,
+                        guidance_interval_decay,
+                        min_guidance_scale,
+                        use_erg_tag,
+                        use_erg_lyric,
+                        use_erg_diffusion,
+                        oss_steps,
+                        guidance_scale_text,
+                        guidance_scale_lyric,
+                        retake_seeds=extend_seeds,
+                        retake_variance=1.0,
+                        task="extend",
+                        repaint_start=repaint_start,
+                        repaint_end=repaint_end,
+                        src_audio_path=src_audio_path,
+                    )
+
+                extend_bnt.click(
+                    fn=extend_process_func,
+                    inputs=[
+                        input_params_json,
+                        extend_input_params_json,
+                        extend_seeds,
+                        left_extend_length,
+                        right_extend_length,
+                        extend_source,
+                        extend_source_audio_upload,
+                        prompt,
+                        lyrics,
+                        infer_step,
+                        guidance_scale,
+                        scheduler_type,
+                        cfg_type,
+                        omega_scale,
+                        manual_seeds,
+                        guidance_interval,
+                        guidance_interval_decay,
+                        min_guidance_scale,
+                        use_erg_tag,
+                        use_erg_lyric,
+                        use_erg_diffusion,
+                        oss_steps,
+                        guidance_scale_text,
+                        guidance_scale_lyric,
+                    ],
+                    outputs=extend_outputs + [extend_input_params_json],
+                )

    def sample_data():
        json_data = sample_data_func()
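Sketch (editor's note, hypothetical numbers): the extend tab reuses the repaint window rather than adding new pipeline arguments, setting repaint_start = -left_extend_length and repaint_end = audio_duration + right_extend_length. The pipeline converts seconds to latent frames with int(t * 44100 / 512 / 8) (44.1 kHz audio, 512-sample hop, 8x temporal downsampling, inferred from the factor in the diff), so a negative start frame or an end frame past the source length is exactly what flips is_extend:

def seconds_to_latent_frames(seconds: float) -> int:
    return int(seconds * 44100 / 512 / 8)

audio_duration = 120.0        # seconds of the source song
left_extend_length = 10.0     # seconds to generate before the song
right_extend_length = 30.0    # seconds to generate after the song

repaint_start = -left_extend_length
repaint_end = audio_duration + right_extend_length

frame_length = seconds_to_latent_frames(audio_duration)
start_frame = seconds_to_latent_frames(repaint_start)
end_frame = seconds_to_latent_frames(repaint_end)
is_extend = (start_frame < 0) or (end_frame > frame_length)
print(start_frame, end_frame, frame_length, is_extend)   # -107 1614 1291 True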
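Sketch (editor's note, assumed shapes): in the elif is_extend: branch of the pipeline diff above, the source latents are zero-padded on the side being extended; if the padded length exceeds the 240-second inference cap, the overflow is kept aside (to_right_pad_gt_latents / to_left_pad_gt_latents) and concatenated back after sampling. A simplified left-extension case, with the overflow captured before trimming:

import torch

max_infer_frame_length = int(240 * 44100 / 512 / 8)    # 2583 latent frames (the diff's max_infer_fame_length)
gt_latents = torch.randn(1, 8, 16, 2500)                # source latents, hypothetical length
left_pad_frame_length = 200                             # frames of new audio to generate on the left

extended = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
to_right_pad_gt_latents = None
if extended.shape[-1] > max_infer_frame_length:
    right_trim_length = extended.shape[-1] - max_infer_frame_length
    to_right_pad_gt_latents = extended[:, :, :, -right_trim_length:]   # tail re-attached after denoising
    extended = extended[:, :, :, :max_infer_frame_length]

print(extended.shape[-1], to_right_pad_gt_latents.shape[-1])   # 2583 117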