ing0 committed
Commit 07e6e48 · 1 Parent(s): 18ea198

Extra mode

app.py CHANGED
@@ -18,7 +18,8 @@ import base64
 from diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
-    get_style_prompt,
+    get_audio_style_prompt,
+    get_text_style_prompt,
     prepare_model,
     get_negative_style_prompt
 )
@@ -29,16 +30,19 @@ device='cuda'
 cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
-@spaces.GPU(duration=20)
-def infer_music(lrc, ref_audio_path, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
+def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     torch.manual_seed(seed)
     sway_sampling_coef = -1 if steps < 32 else None
+    vocal_flag = False
     try:
         lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
-        style_prompt = get_style_prompt(muq, ref_audio_path)
+        if current_prompt_type == 'audio':
+            style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
+        else:
+            style_prompt = get_text_style_prompt(muq, text_prompt)
     except Exception as e:
         raise gr.Error(f"Error: {str(e)}")
     negative_style_prompt = get_negative_style_prompt(device)
@@ -53,7 +57,8 @@ def infer_music(lrc, ref_audio_path, seed=42, randomize_seed=False, steps=32, fi
         steps=steps,
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
-        file_type=file_type
+        file_type=file_type,
+        vocal_flag=vocal_flag
     )
     return generated_song
 
@@ -179,7 +184,23 @@ with gr.Blocks(css=css) as demo:
             elem_classes="lyrics-scroll-box",
             value="""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""
         )
-        audio_prompt = gr.Audio(label="Audio Prompt", type="filepath", value="./src/prompt/default.wav")
+
+        current_prompt_type = gr.State(value="audio")
+        with gr.Tabs() as inside_tabs:
+            with gr.Tab("Audio Prompt"):
+                audio_prompt = gr.Audio(label="Audio Prompt", type="filepath", value="./src/prompt/default.wav")
+            with gr.Tab("Text Prompt"):
+                text_prompt = gr.Textbox(
+                    label="Text Prompt",
+                    placeholder="Enter the Text Prompt, eg: emotional piano pop",
+                )
+            def update_prompt_type(evt: gr.SelectData):
+                return "audio" if evt.index == 0 else "text"
+
+            inside_tabs.select(
+                fn=update_prompt_type,
+                outputs=current_prompt_type
+            )
 
     with gr.Column():
         with gr.Accordion("Best Practices Guide", open=True):
@@ -218,7 +239,7 @@ with gr.Blocks(css=css) as demo:
         steps = gr.Slider(
             minimum=10,
             maximum=100,
-            value=32,
+            value=32,
             step=1,
             label="Diffusion Steps",
             interactive=True,
@@ -248,6 +269,19 @@ with gr.Blocks(css=css) as demo:
            examples_per_page=13,
            elem_id="audio-examples-container"
        )
+
+       gr.Examples(
+           examples=[
+               ["Pop Emotional Piano"],
+               ["流行 情感 钢琴"],
+               ["Indie folk ballad, coming-of-age themes, acoustic guitar picking with harmonica interludes"],
+               ["独立民谣, 成长主题, 原声吉他弹奏与口琴间奏"]
+           ],
+           inputs=[text_prompt],
+           label="Text Examples",
+           examples_per_page=4,
+           elem_id="text-examples-container"
+       )
 
        gr.Examples(
            examples=[
@@ -352,7 +386,7 @@ with gr.Blocks(css=css) as demo:
 
    lyrics_btn.click(
        fn=infer_music,
-       inputs=[lrc, audio_prompt, seed, randomize_seed, steps, file_type],
+       inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, file_type],
        outputs=audio_output
    )
 
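The hunks above turn infer_music into a dual-prompt entry point: the style embedding now comes either from a reference clip or from a free-text description, selected by the hidden current_prompt_type state that the new tab selector keeps in sync. Below is a minimal usage sketch, not part of the commit, assuming app.py is importable as a module, a CUDA device is available, and the lyrics path is a placeholder; in text mode the reference audio path is ignored by the new branch.

    # Usage sketch only; assumes app.py can be imported and its models load successfully.
    from app import infer_music

    with open("./example.lrc") as f:      # placeholder lyrics file
        lrc = f.read()

    # Audio-prompt mode: style comes from the reference clip.
    song = infer_music(lrc, "./src/prompt/default.wav", "", "audio",
                       seed=42, randomize_seed=False, steps=32, file_type="wav")

    # Text-prompt mode: style comes from the description; the audio path is unused.
    song = infer_music(lrc, "./src/prompt/default.wav", "emotional piano pop", "text",
                       seed=42, randomize_seed=False, steps=32, file_type="wav")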
diffrhythm/infer/infer.py CHANGED
@@ -14,7 +14,7 @@ import pydub
 from diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
-    get_style_prompt,
+    get_audio_style_prompt,
     prepare_model,
     get_negative_style_prompt
 )
@@ -74,7 +74,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
            y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type, vocal_flag):
 
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
@@ -86,7 +86,8 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             steps=steps,
             cfg_strength=4.0,
             sway_sampling_coef=sway_sampling_coef,
-            start_time=start_time
+            start_time=start_time,
+            vocal_flag=vocal_flag,
         )
 
     generated = generated.to(torch.float32)
@@ -133,7 +134,7 @@ if __name__ == "__main__":
        lrc = f.read()
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
 
-   style_prompt = get_style_prompt(muq, args.ref_audio_path)
+   style_prompt = get_audio_style_prompt(muq, args.ref_audio_path)
 
    negative_style_prompt = get_negative_style_prompt(device)
 
diffrhythm/infer/infer_utils.py CHANGED
@@ -51,13 +51,14 @@ def get_negative_style_prompt(device):
 
     return vocal_stlye
 
-def get_style_prompt(model, wav_path):
+def get_audio_style_prompt(model, wav_path):
+    vocal_flag = False
     mulan = model
     audio, _ = librosa.load(wav_path, sr=24000)
     audio_len = librosa.get_duration(y=audio, sr=24000)
 
-
-    assert audio_len >= 1, "Input audio length shorter than 1 second"
+    if audio_len <= 1:
+        vocal_flag = True
 
     if audio_len > 10:
         start_time = int(audio_len // 2 - 5)
@@ -70,10 +71,20 @@ def get_style_prompt(model, wav_path):
     with torch.no_grad():
         audio_emb = mulan(wavs = wav) # [1, 512]
 
-    audio_emb = audio_emb
     audio_emb = audio_emb.half()
 
-    return audio_emb
+    return audio_emb, vocal_flag
+
+def get_text_style_prompt(model, text_prompt):
+    mulan = model
+
+    with torch.no_grad():
+        text_emb = mulan(texts = text_prompt) # [1, 512]
+        text_emb = text_emb.half()
+
+    return text_emb
+
+
 
 def parse_lyrics(lyrics: str):
     lyrics_with_time = []
@@ -94,7 +105,6 @@ class CNENTokenizer():
         with open('./diffrhythm/g2p/g2p/vocab.json', 'r') as file:
             self.phone2id:dict = json.load(file)['vocab']
         self.id2phone = {v:k for (k, v) in self.phone2id.items()}
-        # from f5_tts.g2p.g2p_generation import chn_eng_g2p
         from diffrhythm.g2p.g2p_generation import chn_eng_g2p
         self.tokenizer = chn_eng_g2p
     def encode(self, text):
@@ -115,6 +125,8 @@ def get_lrc_token(text, tokenizer, device):
     pad_token_id = 0
     comma_token_id = 1
     period_token_id = 2
+    if text == "":
+        return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
 
     lrc_with_time = parse_lyrics(text)
 
@@ -146,7 +158,7 @@ def get_lrc_token(text, tokenizer, device):
        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)
 
-       #print(gt_frame_start, frame_shift, frame_start, frame_len, tokens_count, last_end_pos, full_pos_emb.shape)
+
 
        lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]
 
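Taken together, the infer_utils.py hunks replace the single get_style_prompt helper with two: get_audio_style_prompt now returns an (embedding, vocal_flag) pair and flags references of one second or less instead of asserting, while get_text_style_prompt embeds a free-form style description with the same model object (muq); get_lrc_token additionally short-circuits on empty lyrics. A sketch of the new contracts as they appear in this diff, with muq, tokenizer and device assumed to come from prepare_model as in app.py:

    # Sketch only; mirrors the signatures added in this commit.
    style_emb, vocal_flag = get_audio_style_prompt(muq, "./src/prompt/default.wav")
    # vocal_flag is True when the reference clip is <= 1 second long.

    style_emb = get_text_style_prompt(muq, "emotional piano pop")   # text-driven style

    # Empty lyrics no longer fail: an all-zero token tensor and a 0.0 start time
    # are returned, so purely instrumental generation can proceed.
    lrc_tokens, start_time = get_lrc_token("", tokenizer, device)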
diffrhythm/model/cfm.py CHANGED
@@ -42,10 +42,7 @@ class CFM(nn.Module):
         transformer: nn.Module,
         sigma=0.0,
         odeint_kwargs: dict = dict(
-            # atol = 1e-5,
-            # rtol = 1e-5,
             method="euler" # 'midpoint'
-            # method="adaptive_heun" # dopri5
         ),
         odeint_options: dict = dict(
             min_step=0.05
@@ -71,8 +68,6 @@ class CFM(nn.Module):
         self.style_drop_prob = style_drop_prob
         self.lrc_drop_prob = lrc_drop_prob
 
-        print(f"audio drop prob -> {self.audio_drop_prob}; style_drop_prob -> {self.style_drop_prob}; lrc_drop_prob: {self.lrc_drop_prob}")
-
         # transformer
         self.transformer = transformer
         dim = transformer.dim
@@ -83,7 +78,6 @@ class CFM(nn.Module):
 
         # sampling related
         self.odeint_kwargs = odeint_kwargs
-        # print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
 
         self.odeint_options = odeint_options
 
@@ -120,6 +114,7 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
+        vocal_flag=False
     ):
         self.eval()
 
@@ -151,10 +146,9 @@ class CFM(nn.Module):
 
         if exists(text):
             text_lens = (text != -1).sum(dim=-1)
-            #lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
+
 
         # duration
-        # import pdb; pdb.set_trace()
         cond_mask = lens_to_mask(lens)
         if edit_mask is not None:
             cond_mask = cond_mask & edit_mask
@@ -170,7 +164,7 @@ class CFM(nn.Module):
         if isinstance(duration, int):
             duration = torch.full((batch,), duration, device=device, dtype=torch.long)
 
-        # duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
+
         duration = duration.clamp(max=max_duration)
         max_duration = duration.amax()
 
@@ -178,12 +172,6 @@ class CFM(nn.Module):
         if duplicate_test:
             test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
 
-        # cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) # [b, t, d]
-        # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) # [b, max_duration]
-        # cond_mask = cond_mask.unsqueeze(-1) #[b, t, d]
-        # step_cond = torch.where(
-        #     cond_mask, cond, torch.zeros_like(cond)
-        # ) # allow direct control (cut cond audio) with lens passed in
 
         if batch > 1:
             mask = lens_to_mask(duration)
@@ -197,6 +185,10 @@ class CFM(nn.Module):
         start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
         _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
 
+        if vocal_flag:
+            style_prompt = negative_style_prompt
+            negative_style_prompt = torch.zeros_like(style_prompt)
+
         text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
         text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
         step_cond = torch.cat([step_cond, step_cond], 0)
@@ -242,7 +234,6 @@ class CFM(nn.Module):
 
         sampled = trajectory[-1]
         out = sampled
-        # out = torch.where(cond_mask, cond, out)
         out = torch.where(fixed_span_mask, out, cond)
 
         if exists(vocoder):
@@ -286,7 +277,6 @@ class CFM(nn.Module):
         x0 = torch.randn_like(x1)
 
         # time step
-        # time = torch.rand((batch,), dtype=dtype, device=self.device)
         time = torch.normal(mean=0, std=1, size=(batch,), device=self.device)
         time = torch.nn.functional.sigmoid(time)
         # TODO. noise_scheduler
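Apart from deleting commented-out code and debug prints, the cfm.py hunks make one behavioral change: when vocal_flag is set upstream by get_audio_style_prompt (a reference clip of one second or less), sample() swaps the guidance pair so that the embedding returned by get_negative_style_prompt (a vocal style, going by its variable name) becomes the positive style prompt, and guidance then runs against a zero embedding. A condensed restatement of that branch, with names as in the diff:

    # Condensed from the hunk above, for readability; not additional code.
    if vocal_flag:
        style_prompt = negative_style_prompt                      # vocal style as positive prompt
        negative_style_prompt = torch.zeros_like(style_prompt)    # guide against zeros instead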
diffrhythm/model/dit.py CHANGED
@@ -13,8 +13,6 @@ import torch
 from torch import nn
 import torch
 import torch.nn.functional as F
-
-from x_transformers.x_transformers import RotaryEmbedding
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
 from transformers.models.llama import LlamaConfig
 from torch.utils.checkpoint import checkpoint
@@ -32,8 +30,6 @@ from diffrhythm.model.modules import (
 # apply_liger_kernel_to_llama()
 
 # Text embedding
-
-
 class TextEmbedding(nn.Module):
     def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
         super().__init__()
@@ -50,10 +46,7 @@ class TextEmbedding(nn.Module):
         self.extra_modeling = False
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
-        #text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
-        #text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
         batch, text_len = text.shape[0], text.shape[1]
-        #text = F.pad(text, (0, seq_len - text_len), value=0)
 
         if drop_text: # cfg for text
             text = torch.zeros_like(text)
@@ -75,8 +68,6 @@ class TextEmbedding(nn.Module):
 
 
 # noised input audio and context mixing embedding
-
-
 class InputEmbedding(nn.Module):
     def __init__(self, mel_dim, text_dim, out_dim, cond_dim):
         super().__init__()
@@ -89,7 +80,6 @@ class InputEmbedding(nn.Module):
 
         style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
         time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
-        # print(x.shape, cond.shape, text_embed.shape, style_emb.shape, time_emb.shape)
         x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
         x = self.conv_pos_embed(x) + x
         return x
@@ -125,17 +115,13 @@ class DiT(nn.Module):
         self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
 
-        #self.rotary_embed = RotaryEmbedding(dim_head)
 
         self.dim = dim
         self.depth = depth
 
-        #self.transformer_blocks = nn.ModuleList(
-        #    [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, use_style_prompt=use_style_prompt) for _ in range(depth)]
-        #)
         llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
         llama_config._attn_implementation = 'sdpa'
-        #llama_config._attn_implementation = ''
+
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
@@ -157,8 +143,6 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim, cond_dim) # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
-        # if use_style_prompt:
-        #     self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)
 
     def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
         s_t = self.start_time_embed(start_time)