Commit 71922e7 (parent: 0e834f7), committed by Sayoyo

[fix] fix some bugs

.gitignore CHANGED
@@ -1,3 +1,16 @@
+*.pt
+*.ckpt
+*.onnx
+t5_g2p_model/
+embeddings/
+checkpoints/
+
+val_images/
+val_audios/
+lightning_logs/
+lightning_logs_/
+train_images/
+train_audios/
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -7,6 +20,7 @@ __pycache__/
 *.so
 
 # Distribution / packaging
+.idea/
 .Python
 build/
 develop-eggs/
@@ -99,17 +113,15 @@ ipython_config.py
 # This is especially recommended for binary packages to ensure reproducibility, and is more
 # commonly ignored for libraries.
 # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-poetry.lock
+#poetry.lock
 
 # pdm
 # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
 #pdm.lock
 # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
 # in version control.
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+# https://pdm.fming.dev/#use-with-ide
 .pdm.toml
-.pdm-python
-.pdm-build/
 
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
 __pypackages__/
@@ -160,5 +172,31 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-
-checkpoints/
+*.txt
+!requirements.txt
+*.log
+*.flac
+minio_config.yaml
+.history/*
+__pycache__/*
+train.log
+*.mp3
+*.tar.gz
+__pycache__/
+demo_examples/
+nohup.out
+test_results/*
+nohup.out
+text_audio_align/*
+remote/*
+MG2P/*
+audio_getter.py
+refiner_loss_debug/
+outputs/*
+!outputs/
+save_checkpoint.ipynb
+repos/*
+app_demo.py
+ui/components_demo.py
+data_sampler_demo.py
+pipeline_ace_step_demo.py
LICENSE ADDED
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2025] Timedomain Inc. and stepfun
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
apg_guidance.py CHANGED
@@ -17,7 +17,10 @@ def project(
     dims=[-1, -2],
 ):
     dtype = v0.dtype
-    v0, v1 = v0.double(), v1.double()
+    if v0.device.type == "mps":
+        v0, v1 = v0.float(), v1.float()
+    else:
+        v0, v1 = v0.double(), v1.double()
     v1 = torch.nn.functional.normalize(v1, dim=dims)
     v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
     v0_orthogonal = v0 - v0_parallel
@@ -53,6 +56,7 @@ def apg_forward(
 def cfg_forward(cond_output, uncond_output, cfg_strength):
     return uncond_output + cfg_strength * (cond_output - uncond_output)
 
+
 def cfg_double_condition_forward(
     cond_output,
     uncond_output,
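
Note: the project() change above is an MPS workaround. PyTorch's MPS backend has no float64 kernels, so calling .double() on an MPS tensor raises an error at runtime; the projection now falls back to float32 on that device. A minimal standalone sketch of the same dtype-fallback pattern (names here are illustrative, not from the repo):

    import torch

    def widest_float_dtype(t: torch.Tensor) -> torch.dtype:
        # MPS cannot represent float64; use float32 there, float64 elsewhere.
        return torch.float32 if t.device.type == "mps" else torch.float64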
music_dcae/music_dcae_pipeline.py CHANGED
@@ -3,7 +3,6 @@ import torch
 from diffusers import AutoencoderDC
 import torchaudio
 import torchvision.transforms as transforms
-import torchaudio
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.loaders import FromOriginalModelMixin
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -30,7 +29,7 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         if source_sample_rate is None:
             source_sample_rate = 48000
-
+
         self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
 
         self.transform = transforms.Compose([
@@ -95,29 +94,21 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def decode(self, latents, audio_lengths=None, sr=None):
         latents = latents / self.scale_factor + self.shift_factor
 
-        mels = []
+        pred_wavs = []
 
         for latent in latents:
-            mel = self.dcae.decoder(latent.unsqueeze(0))
-            mels.append(mel)
-        mels = torch.cat(mels, dim=0)
-
-        mels = mels * 0.5 + 0.5
-        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
-        bsz, channels, num_mel, mel_width = mels.shape
-        pred_wavs = []
-        for i in range(bsz):
-            mel = mels[i]
-            wav = self.vocoder.decode(mel).squeeze(1)
+            mels = self.dcae.decoder(latent.unsqueeze(0))
+            mels = mels * 0.5 + 0.5
+            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
+            wav = self.vocoder.decode(mels[0]).squeeze(1)
+
+            if sr is not None:
+                resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
+                wav = resampler(wav)
+            else:
+                sr = 44100
             pred_wavs.append(wav)
 
-        pred_wavs = torch.stack(pred_wavs)
-
-        if sr is not None:
-            resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
-            pred_wavs = [resampler(wav) for wav in pred_wavs]
-        else:
-            sr = 44100
         if audio_lengths is not None:
             pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
         return sr, pred_wavs
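
Note: besides dropping the duplicate torchaudio import, decode() now runs the mel rescaling, vocoder call, and optional resampling one latent at a time, so only a single item's mel spectrogram is alive at once; the old path concatenated every mel and stacked every waveform, which also required equal lengths. A minimal sketch of the pattern, with hypothetical decode_mel/vocode stand-ins for self.dcae.decoder and self.vocoder.decode:

    import torch

    def decode_one_at_a_time(latents, decode_mel, vocode):
        # Process items singly to cap peak memory.
        wavs = []
        for latent in latents:
            mel = decode_mel(latent.unsqueeze(0))  # batch of one
            wavs.append(vocode(mel[0]))
        return wavs  # a list, so per-item lengths may differ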
pipeline_ace_step.py CHANGED
@@ -11,6 +11,7 @@ import json
 import math
 from huggingface_hub import hf_hub_download
 
+# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
@@ -29,6 +30,7 @@ torch.backends.cudnn.benchmark = False
 torch.set_float32_matmul_precision('high')
 torch.backends.cudnn.deterministic = True
 torch.backends.cuda.matmul.allow_tf32 = True
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 SUPPORT_LANGUAGES = {
@@ -54,7 +56,7 @@ REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
 class ACEStepPipeline:
 
     def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, **kwargs):
-        if checkpoint_dir is None:
+        if not checkpoint_dir:
             if persistent_storage_path is None:
                 checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
             else:
@@ -63,7 +65,11 @@ class ACEStepPipeline:
 
         self.checkpoint_dir = checkpoint_dir
         device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else torch.device("cpu")
+        if device.type == "cpu" and torch.backends.mps.is_available():
+            device = torch.device("mps")
         self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
+        if device.type == "mps" and self.dtype == torch.bfloat16:
+            self.dtype = torch.float16
         self.device = device
         self.loaded = False
         self.torch_compile = torch_compile
@@ -620,6 +626,7 @@
             repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
             repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
             repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
+            zt_edit = x0.clone()
            z0 = repaint_noise
         elif is_extend:
             to_right_pad_gt_latents = None
@@ -669,9 +676,8 @@
                 padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
             target_latents = torch.cat(padd_list, dim=-1)
             assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
-
-            zt_edit = x0.clone()
-            z0 = target_latents
+            zt_edit = x0.clone()
+            z0 = target_latents
 
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
 
@@ -774,7 +780,7 @@
                 hook.remove()
 
             return sample
-
+
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
 
             if is_repaint:
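
Note: two behavioral fixes sit in this file alongside whitespace cleanups. First, zt_edit = x0.clone() is now also set in the repaint branch, which previously left zt_edit undefined on that path. Second, the constructor gains a device/dtype fallback: CUDA if available, else MPS, else CPU, with bfloat16 swapped for float16 on MPS, where bfloat16 support is limited and float16 is the safer half-precision choice. A standalone sketch of that selection logic (the function name is illustrative):

    import torch

    def pick_device_and_dtype(device_id: int = 0, dtype: str = "bfloat16"):
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{device_id}")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        torch_dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
        if device.type == "mps" and torch_dtype == torch.bfloat16:
            torch_dtype = torch.float16  # bfloat16 is unreliable on MPS
        return device, torch_dtype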
ui/components.py CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+import librosa
 
 
 TAG_PLACEHOLDER = "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic"
@@ -67,25 +68,25 @@ def create_text2music_ui(
         # add markdown, tags and lyrics examples are from ai music generation community
         audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=-1, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
         sample_bnt = gr.Button("Sample", variant="primary", scale=1)
-
+
         prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Support tags, descriptions, and scene. Use commas to separate different tags.\ntags and lyrics examples are from ai music generation community")
         lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, placeholder=LYRIC_PLACEHOLDER, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
 
         with gr.Accordion("Basic Settings", open=False):
-            infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
+            infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=27, label="Infer Steps", interactive=True)
             guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.")
-            guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=5.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start")
-            guidance_scale_lyric = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.5, label="Guidance Scale Lyric", interactive=True)
+            guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=0.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start")
+            guidance_scale_lyric = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=0.0, label="Guidance Scale Lyric", interactive=True)
 
             manual_seeds = gr.Textbox(label="manual seeds (default None)", placeholder="1,2,3,4", value=None, info="Seed for the generation")
-
+
         with gr.Accordion("Advanced Settings", open=False):
             scheduler_type = gr.Radio(["euler", "heun"], value="euler", label="Scheduler Type", elem_id="scheduler_type", info="Scheduler type for the generation. euler is recommended. heun will take more time.")
            cfg_type = gr.Radio(["cfg", "apg", "cfg_star"], value="apg", label="CFG Type", elem_id="cfg_type", info="CFG type for the generation. apg is recommended. cfg and cfg_star are almost the same.")
             use_erg_tag = gr.Checkbox(label="use ERG for tag", value=True, info="Use Entropy Rectifying Guidance for tag. It will multiple a temperature to the attention to make a weaker tag condition and make better diversity.")
             use_erg_lyric = gr.Checkbox(label="use ERG for lyric", value=True, info="The same but apply to lyric encoder's attention.")
             use_erg_diffusion = gr.Checkbox(label="use ERG for diffusion", value=True, info="The same but apply to diffusion model's attention.")
-
+
             omega_scale = gr.Slider(minimum=-100.0, maximum=100.0, step=0.1, value=10.0, label="Granularity Scale", interactive=True, info="Granularity scale for the generation. Higher values can reduce artifacts")
 
             guidance_interval = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Guidance Interval", interactive=True, info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)")
@@ -102,7 +103,7 @@
             retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None)
             retake_bnt = gr.Button("Retake", variant="primary")
             retake_outputs, retake_input_params_json = create_output_ui("Retake")
-
+
             def retake_process_func(json_data, retake_variance, retake_seeds):
                 return text2music_process_func(
                     json_data["audio_duration"],
@@ -143,7 +144,7 @@
             repaint_start = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Repaint Start Time", interactive=True)
             repaint_end = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Repaint End Time", interactive=True)
             repaint_source = gr.Radio(["text2music", "last_repaint", "upload"], value="text2music", label="Repaint Source", elem_id="repaint_source")
-
+
             repaint_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="repaint_source_audio_upload")
             repaint_source.change(
                 fn=lambda x: gr.update(visible=x == "upload", elem_id="repaint_source_audio_upload"),
@@ -153,7 +154,7 @@
 
             repaint_bnt = gr.Button("Repaint", variant="primary")
             repaint_outputs, repaint_input_params_json = create_output_ui("Repaint")
-
+
             def repaint_process_func(
                 text2music_json_data,
                 repaint_json_data,
@@ -183,7 +184,10 @@
             ):
                 if repaint_source == "upload":
                     src_audio_path = repaint_source_audio_upload
-                    json_data = text2music_json_data
+                    audio_duration = librosa.get_duration(filename=src_audio_path)
+                    json_data = {
+                        "audio_duration": audio_duration
+                    }
                 elif repaint_source == "text2music":
                     json_data = text2music_json_data
                     src_audio_path = json_data["audio_path"]
@@ -217,7 +221,7 @@
                     repaint_end=repaint_end,
                     src_audio_path=src_audio_path,
                 )
-
+
             repaint_bnt.click(
                 fn=repaint_process_func,
                 inputs=[
@@ -253,11 +257,11 @@
             edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
             edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
             retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None)
-
+
             edit_type = gr.Radio(["only_lyrics", "remix"], value="only_lyrics", label="Edit Type", elem_id="edit_type", info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre")
             edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.6, label="edit_n_min", interactive=True)
             edit_n_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="edit_n_max", interactive=True)
-
+
             def edit_type_change_func(edit_type):
                 if edit_type == "only_lyrics":
                     n_min = 0.6
@@ -266,7 +270,7 @@
                     n_min = 0.2
                     n_max = 0.4
                 return n_min, n_max
-
+
             edit_type.change(
                 edit_type_change_func,
                 inputs=[edit_type],
@@ -283,7 +287,7 @@
 
             edit_bnt = gr.Button("Edit", variant="primary")
             edit_outputs, edit_input_params_json = create_output_ui("Edit")
-
+
             def edit_process_func(
                 text2music_json_data,
                 edit_input_params_json,
@@ -314,7 +318,10 @@
             ):
                 if edit_source == "upload":
                     src_audio_path = edit_source_audio_upload
-                    json_data = text2music_json_data
+                    audio_duration = librosa.get_duration(filename=src_audio_path)
+                    json_data = {
+                        "audio_duration": audio_duration
+                    }
                 elif edit_source == "text2music":
                     json_data = text2music_json_data
                     src_audio_path = json_data["audio_path"]
@@ -354,7 +361,7 @@
                     edit_n_max=edit_n_max,
                     retake_seeds=retake_seeds,
                 )
-
+
             edit_bnt.click(
                 fn=edit_process_func,
                 inputs=[
@@ -392,7 +399,7 @@
             left_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Left Extend Length", interactive=True)
             right_extend_length = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Right Extend Length", interactive=True)
             extend_source = gr.Radio(["text2music", "last_extend", "upload"], value="text2music", label="Extend Source", elem_id="extend_source")
-
+
             extend_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="extend_source_audio_upload")
             extend_source.change(
                 fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"),
@@ -402,7 +409,7 @@
 
             extend_bnt = gr.Button("Extend", variant="primary")
             extend_outputs, extend_input_params_json = create_output_ui("Extend")
-
+
             def extend_process_func(
                 text2music_json_data,
                 extend_input_params_json,
@@ -431,11 +438,15 @@
             ):
                 if extend_source == "upload":
                     src_audio_path = extend_source_audio_upload
-                    json_data = text2music_json_data
+                    # get audio duration
+                    audio_duration = librosa.get_duration(filename=src_audio_path)
+                    json_data = {
+                        "audio_duration": audio_duration
+                    }
                 elif extend_source == "text2music":
                     json_data = text2music_json_data
                     src_audio_path = json_data["audio_path"]
-                elif extend_source == "last_repaint":
+                elif extend_source == "last_extend":
                     json_data = extend_input_params_json
                     src_audio_path = json_data["audio_path"]
 
@@ -467,7 +478,7 @@
                     repaint_end=repaint_end,
                     src_audio_path=src_audio_path,
                 )
-
+
             extend_bnt.click(
                 fn=extend_process_func,
                 inputs=[
@@ -521,7 +532,7 @@
                 json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
                 json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
             )
-
+
             sample_bnt.click(
                 sample_data,
                 outputs=[
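
Note: the upload branches above derive audio_duration via librosa.get_duration(filename=src_audio_path). That keyword was deprecated in librosa 0.9 and removed in 0.10, where the call is librosa.get_duration(path=...), so with a newer librosa this code would raise a TypeError. A version-tolerant sketch (assuming only that librosa is installed):

    import librosa

    def get_audio_duration(audio_path: str) -> float:
        # librosa >= 0.10 takes path=; older releases take filename=.
        try:
            return librosa.get_duration(path=audio_path)
        except TypeError:
            return librosa.get_duration(filename=audio_path)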