Surn committed on
Commit 907a484 · 1 Parent(s): 6804dbd

Add STYLE model with upgrades

app.py CHANGED
@@ -183,7 +183,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model,topp, temperatur
183
 
184
  return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
185
 
186
- def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, segment_length = 30, settings_font_size=28, settings_animate_waveform=False, video_orientation="Landscape", progress=gr.Progress(track_tqdm=True)):
187
  global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
188
  output_segments = None
189
  melody_name = "Not Used"
@@ -251,24 +251,47 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
251
 
252
 
253
  print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
254
- MODEL.set_generation_params(
255
- use_sampling=True,
256
- top_k=topk,
257
- top_p=topp,
258
- temperature=temperature,
259
- cfg_coef=cfg_coef,
260
- duration=segment_duration,
261
- two_step_cfg=False,
262
- extend_stride=10,
263
- rep_penalty=0.5
264
- )
265
  MODEL.set_custom_progress_callback(gr.Progress(track_tqdm=True))
266
 
267
  try:
268
- if melody and ("melody" in model):
269
  # return excess duration, load next model and continue in loop structure building up output_segments
270
  if duration > MODEL.duration:
271
- output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.duration, prompt_index, harmony_only, progress=gr.Progress(track_tqdm=True))
272
  else:
273
  # pure original code
274
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
@@ -487,11 +510,11 @@ def ui(**kwargs):
487
  with gr.Column():
488
  with gr.Row():
489
  with gr.Column():
490
- text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi, soft fade-in, soft fade-out", key="prompt", lines=4)
491
  autoplay_cb = gr.Checkbox(value=False, label="Autoplay?", key="autoplay_cb")
492
  with gr.Column():
493
  duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True, key="total_duration", step=1)
494
- model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large"], label="AI Model", value="medium", interactive=True, key="chosen_model")
495
  with gr.Row():
496
  submit = gr.Button("Generate", elem_id="btn-generate")
497
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
@@ -545,7 +568,7 @@ def ui(**kwargs):
545
  gr.Examples(
546
  examples=[
547
  [
548
- "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background",
549
  "./assets/bach.mp3",
550
  "melody",
551
  "80s Pop Synth",
@@ -554,7 +577,7 @@ def ui(**kwargs):
554
  3.5
555
  ],
556
  [
557
- "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars",
558
  "./assets/bolero_ravel.mp3",
559
  "stereo-melody-large",
560
  "Country Guitar",
@@ -563,7 +586,7 @@ def ui(**kwargs):
563
  4.0
564
  ],
565
  [
566
- "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums",
567
  None,
568
  "stereo-medium",
569
  "90s Rock Guitar",
@@ -572,7 +595,7 @@ def ui(**kwargs):
572
  3.75
573
  ],
574
  [
575
- "4/4 120bpm 320kbps 48khz, a light and cheery EDM track, with syncopated drums, aery pads, and strong emotions",
576
  "./assets/bach.mp3",
577
  "melody-large",
578
  "EDM my Bach",
@@ -581,7 +604,7 @@ def ui(**kwargs):
581
  3.75
582
  ],
583
  [
584
- "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples",
585
  None,
586
  "medium",
587
  "LoFi Chill",
 
183
 
184
  return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
185
 
186
+ def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, segment_length = 30, settings_font_size=28, settings_animate_waveform=False, video_orientation="Landscape", excerpt_duration=3.5, progress=gr.Progress(track_tqdm=True)):
187
  global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
188
  output_segments = None
189
  melody_name = "Not Used"
 
251
 
252
 
253
  print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
254
+ if ("style" in model) and melody:
255
+ # style and text-to-music
256
+ MODEL.set_generation_params(
257
+ use_sampling=True,
258
+ top_k=topk,
259
+ top_p=topp,
260
+ temperature=temperature,
261
+ cfg_coef=cfg_coef,
262
+ duration=segment_duration,
263
+ two_step_cfg=False,
264
+ cfg_coef_beta=5, # double CFG is only useful for text-and-style conditioning
265
+ )
266
+
267
+ MODEL.set_style_conditioner_params(
268
+ eval_q=3, # integer between 1 and 6
269
+ # eval_q is the level of quantization that passes
270
+ # through the conditioner. When low, the model adheres less to the
271
+ # audio conditioning
272
+ excerpt_length=excerpt_duration, # the length in seconds that is taken by the model in the provided excerpt, can be
273
+ # between 1.5 and 4.5 seconds but it has to be shorter than the length of the provided conditioning
274
+ )
275
+ else:
276
+ MODEL.set_generation_params(
277
+ use_sampling=True,
278
+ top_k=topk,
279
+ top_p=topp,
280
+ temperature=temperature,
281
+ cfg_coef=cfg_coef,
282
+ duration=segment_duration,
283
+ two_step_cfg=False,
284
+ extend_stride=10,
285
+ rep_penalty=0.5,
286
+ cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning
287
+ )
288
  MODEL.set_custom_progress_callback(gr.Progress(track_tqdm=True))
289
 
290
  try:
291
+ if melody and ("melody" in model or "style" in model):
292
  # return excess duration, load next model and continue in loop structure building up output_segments
293
  if duration > MODEL.duration:
294
+ output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.duration, prompt_index, harmony_only, excerpt_duration, progress=gr.Progress(track_tqdm=True))
295
  else:
296
  # pure original code
297
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
 
510
  with gr.Column():
511
  with gr.Row():
512
  with gr.Column():
513
+ text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 32khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi, soft fade-in, soft fade-out", key="prompt", lines=4)
514
  autoplay_cb = gr.Checkbox(value=False, label="Autoplay?", key="autoplay_cb")
515
  with gr.Column():
516
  duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True, key="total_duration", step=1)
517
+ model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large", "style"], label="AI Model", value="medium", interactive=True, key="chosen_model")
518
  with gr.Row():
519
  submit = gr.Button("Generate", elem_id="btn-generate")
520
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
 
568
  gr.Examples(
569
  examples=[
570
  [
571
+ "4/4 120bpm 320kbps 32khz, An 80s driving pop song with heavy drums and synth pads in the background",
572
  "./assets/bach.mp3",
573
  "melody",
574
  "80s Pop Synth",
 
577
  3.5
578
  ],
579
  [
580
+ "4/4 120bpm 320kbps 32khz, A cheerful country song with acoustic guitars",
581
  "./assets/bolero_ravel.mp3",
582
  "stereo-melody-large",
583
  "Country Guitar",
 
586
  4.0
587
  ],
588
  [
589
+ "4/4 120bpm 320kbps 32khz, 90s rock song with electric guitar and heavy drums",
590
  None,
591
  "stereo-medium",
592
  "90s Rock Guitar",
 
595
  3.75
596
  ],
597
  [
598
+ "4/4 120bpm 320kbps 32khz, a light and cheery EDM track, with syncopated drums, aery pads, and strong emotions",
599
  "./assets/bach.mp3",
600
  "melody-large",
601
  "EDM my Bach",
 
604
  3.75
605
  ],
606
  [
607
+ "4/4 320kbps 32khz, lofi slow bpm electro chill with organic samples",
608
  None,
609
  "medium",
610
  "LoFi Chill",
audiocraft/__init__.py CHANGED
@@ -7,4 +7,4 @@
7
  # flake8: noqa
8
  from . import data, modules, models
9
 
10
- __version__ = '1.2.Surn'
 
7
  # flake8: noqa
8
  from . import data, modules, models
9
 
10
+ __version__ = '1.3.Surn'
audiocraft/data/audio.py CHANGED
@@ -79,7 +79,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
79
  seek_time (float): Time at which to start reading in the file.
80
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
81
  Returns:
82
- Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
83
  """
84
  _init_av()
85
  with av.open(str(filepath)) as af:
@@ -115,7 +115,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
115
 
116
 
117
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
118
- duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
119
  """Read audio by picking the most appropriate backend tool based on the audio format.
120
 
121
  Args:
@@ -124,7 +124,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
124
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
125
  pad (bool): Pad output audio if not reaching expected duration.
126
  Returns:
127
- Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
128
  """
129
  fp = Path(filepath)
130
  if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
@@ -299,4 +299,124 @@ def audio_write2(stem_name: tp.Union[str, Path],
299
  # we do not want to leave half written files around.
300
  path.unlink()
301
  raise
302
- return path
 
79
  seek_time (float): Time at which to start reading in the file.
80
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
81
  Returns:
82
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
83
  """
84
  _init_av()
85
  with av.open(str(filepath)) as af:
 
115
 
116
 
117
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
118
+ duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
119
  """Read audio by picking the most appropriate backend tool based on the audio format.
120
 
121
  Args:
 
124
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
125
  pad (bool): Pad output audio if not reaching expected duration.
126
  Returns:
127
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
128
  """
129
  fp = Path(filepath)
130
  if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
 
299
  # we do not want to leave half written files around.
300
  path.unlink()
301
  raise
302
+ return path
303
+
304
+
305
+ def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
306
+ """Get the mel-spectrogram from the raw audio.
307
+
308
+ Args:
309
+ y (numpy array): raw input
310
+ sr (int): Sampling rate
311
+ n_fft (int): Number of samples per FFT. Default is 4096.
312
+ hop_length (int): Number of samples between successive frames. Default is 128.
313
+ dur (float): Maximum duration to get the spectrograms
314
+ Returns:
315
+ Mel-spectrogram (in dB) as a numpy array
316
+ """
317
+ import librosa
318
+ import librosa.display
319
+
320
+ spectrogram = librosa.feature.melspectrogram(
321
+ y=y, sr=sr, n_fft=n_fft, hop_length=hop_length
322
+ )
323
+ spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
324
+ return spectrogram_db
325
+
326
+
327
+ def save_spectrograms(
328
+ ys: tp.List[np.ndarray],
329
+ sr: int,
330
+ path: str,
331
+ names: tp.List[str],
332
+ n_fft: int = 4096,
333
+ hop_length: int = 128,
334
+ dur: float = 8.0,
335
+ ):
336
+ """Plot a spectrogram for an audio file.
337
+
338
+ Args:
339
+ ys: List of raw audio signals (one numpy array per plot)
340
+ sr (int): Sampling rate of the audio signals.
341
+ path (str): Path to the plot file.
342
+ names: name of each spectrogram plot
343
+ n_fft (int): Number of samples per FFT. Default is 4096.
344
+ hop_length (int): Number of samples between successive frames. Default is 128.
345
+ dur (float): Maximum duration to plot the spectrograms
346
+
347
+ Returns:
348
+ None (plots the spectrogram using matplotlib)
349
+ """
350
+ import matplotlib as mpl # type: ignore
351
+ import matplotlib.pyplot as plt # type: ignore
352
+ import librosa.display
353
+
354
+ if not names:
355
+ names = ["Ground Truth", "Audio Watermarked", "Watermark"]
356
+ ys = [wav[: int(dur * sr)] for wav in ys] # crop
357
+ assert len(names) == len(
358
+ ys
359
+ ), f"There are {len(ys)} wavs but {len(names)} names ({names})"
360
+
361
+ # Set matplotlib stuff
362
+ BIGGER_SIZE = 10
363
+ SMALLER_SIZE = 8
364
+ linewidth = 234.8775 # linewidth in pt
365
+
366
+ plt.rc("font", size=BIGGER_SIZE, family="serif") # controls default text sizes
367
+ plt.rcParams["font.family"] = "DejaVu Serif"
368
+ plt.rcParams["font.serif"] = ["Times New Roman"]
369
+
370
+ plt.rc("axes", titlesize=BIGGER_SIZE) # fontsize of the axes title
371
+ plt.rc("axes", labelsize=BIGGER_SIZE) # fontsize of the x and y labels
372
+ plt.rc("xtick", labelsize=BIGGER_SIZE) # fontsize of the tick labels
373
+ plt.rc("ytick", labelsize=SMALLER_SIZE) # fontsize of the tick labels
374
+ plt.rc("legend", fontsize=BIGGER_SIZE) # legend fontsize
375
+ plt.rc("figure", titlesize=BIGGER_SIZE)
376
+ height = 1.6 * linewidth / 72.0
377
+ fig, ax = plt.subplots(
378
+ nrows=len(ys),
379
+ ncols=1,
380
+ sharex=True,
381
+ figsize=(linewidth / 72.0, height),
382
+ )
383
+ fig.tight_layout()
384
+
385
+ # Plot the spectrogram
386
+
387
+ for i, ysi in enumerate(ys):
388
+ spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length)
389
+ if i == 0:
390
+ cax = fig.add_axes(
391
+ [
392
+ ax[0].get_position().x1 + 0.01, # type: ignore
393
+ ax[-1].get_position().y0,
394
+ 0.02,
395
+ ax[0].get_position().y1 - ax[-1].get_position().y0,
396
+ ]
397
+ )
398
+ fig.colorbar(
399
+ mpl.cm.ScalarMappable(
400
+ norm=mpl.colors.Normalize(
401
+ np.min(spectrogram_db), np.max(spectrogram_db)
402
+ ),
403
+ cmap="magma",
404
+ ),
405
+ ax=ax,
406
+ orientation="vertical",
407
+ format="%+2.0f dB",
408
+ cax=cax,
409
+ )
410
+ librosa.display.specshow(
411
+ spectrogram_db,
412
+ sr=sr,
413
+ hop_length=hop_length,
414
+ x_axis="time",
415
+ y_axis="mel",
416
+ ax=ax[i],
417
+ )
418
+ ax[i].set(title=names[i])
419
+ ax[i].yaxis.set_label_text(None)
420
+ ax[i].label_outer()
421
+ fig.savefig(path, bbox_inches="tight")
422
+ plt.close()
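
The two helpers appended to audio.py can be exercised on the bundled example clips; a minimal sketch, assuming librosa and matplotlib are installed and that both clips share a sample rate:

from audiocraft.data.audio import audio_read, get_spec, save_spectrograms

# Read up to 8 seconds from each clip; audio_read returns a [channels, time] tensor and the sample rate.
wav_a, sr = audio_read("./assets/bach.mp3", duration=8.0)
wav_b, _ = audio_read("./assets/bolero_ravel.mp3", duration=8.0)

# Collapse to mono numpy arrays, since get_spec/save_spectrograms expect raw 1-D signals.
ys = [w.mean(dim=0).numpy() for w in (wav_a, wav_b)]

spec_db = get_spec(ys[0], sr=sr)  # mel-spectrogram in dB, shape [n_mels, frames]
save_spectrograms(ys, sr=sr, path="spectrograms.png", names=["Bach", "Bolero"])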
audiocraft/models/__init__.py CHANGED
@@ -12,6 +12,10 @@ from . import builders, loaders
12
  from .encodec import (
13
  CompressionModel, EncodecModel, DAC,
14
  HFEncodecModel, HFEncodecCompressionModel)
15
- from .musicgen import MusicGen
16
  from .lm import LMModel
 
 
17
  from .encodec import CompressionModel, EncodecModel
 
12
  from .encodec import (
13
  CompressionModel, EncodecModel, DAC,
14
  HFEncodecModel, HFEncodecCompressionModel)
 
15
  from .lm import LMModel
16
+ from .lm_magnet import MagnetLMModel
17
+ from .flow_matching import FlowMatchingModel
18
  from .encodec import CompressionModel, EncodecModel
19
+ from .musicgen import MusicGen
20
+ from .magnet import MAGNeT
21
+ from .unet import DiffusionUnet
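
With these re-exports in place, the new model classes resolve from the package top level. A quick import check (assuming the added modules and their optional dependencies, such as torchdiffeq for flow_matching, are installed):

import audiocraft
from audiocraft.models import (MusicGen, MAGNeT, LMModel, MagnetLMModel,
                               FlowMatchingModel, DiffusionUnet,
                               CompressionModel, EncodecModel)

print(audiocraft.__version__)  # '1.3.Surn' after this commit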
audiocraft/models/builders.py CHANGED
@@ -11,51 +11,53 @@ from the Hydra config.
11
 
12
  import typing as tp
13
 
14
- import audiocraft
15
  import omegaconf
16
  import torch
17
 
18
- from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
19
- from .lm import LMModel
20
- from ..modules.codebooks_patterns import (
21
- CodebooksPatternProvider,
22
- DelayedPatternProvider,
23
- MusicLMPattern,
24
- ParallelPatternProvider,
25
- UnrolledPatternProvider,
26
- CoarseFirstPattern,
27
- )
28
- from ..modules.conditioners import (
29
- BaseConditioner,
30
- ChromaStemConditioner,
31
- CLAPEmbeddingConditioner,
32
- ConditionFuser,
33
- ConditioningProvider,
34
- LUTConditioner,
35
- T5Conditioner,
36
- )
37
- from .unet import DiffusionUnet
38
  from .. import quantization as qt
39
- from ..utils.utils import dict_from_config
40
  from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
 
41
 
42
 
43
- def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
44
- klass = {
45
- 'no_quant': qt.DummyQuantizer,
46
- 'rvq': qt.ResidualVectorQuantizer
47
- }[quantizer]
 
48
  kwargs = dict_from_config(getattr(cfg, quantizer))
49
- if quantizer != 'no_quant':
50
- kwargs['dimension'] = dimension
51
  return klass(**kwargs)
52
 
53
 
54
  def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
55
- if encoder_name == 'seanet':
56
- kwargs = dict_from_config(getattr(cfg, 'seanet'))
57
- encoder_override_kwargs = kwargs.pop('encoder')
58
- decoder_override_kwargs = kwargs.pop('decoder')
59
  encoder_kwargs = {**kwargs, **encoder_override_kwargs}
60
  decoder_kwargs = {**kwargs, **decoder_override_kwargs}
61
  encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
@@ -67,44 +69,98 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
67
 
68
  def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
69
  """Instantiate a compression model."""
70
- if cfg.compression_model == 'encodec':
71
- kwargs = dict_from_config(getattr(cfg, 'encodec'))
72
- encoder_name = kwargs.pop('autoencoder')
73
- quantizer_name = kwargs.pop('quantizer')
74
  encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
75
  quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
76
- frame_rate = kwargs['sample_rate'] // encoder.hop_length
77
- renormalize = kwargs.pop('renormalize', False)
78
  # deprecated params
79
- kwargs.pop('renorm', None)
80
- return EncodecModel(encoder, decoder, quantizer,
81
- frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
82
  else:
83
  raise KeyError(f"Unexpected compression model {cfg.compression_model}")
84
 
85
 
86
  def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
87
  """Instantiate a transformer LM."""
88
- if cfg.lm_model == 'transformer_lm':
89
- kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
90
- n_q = kwargs['n_q']
91
- q_modeling = kwargs.pop('q_modeling', None)
92
- codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
93
- attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
94
- cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
95
- cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
96
  fuser = get_condition_fuser(cfg)
97
  condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
98
- if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
99
- kwargs['cross_attention'] = True
100
  if codebooks_pattern_cfg.modeling is None:
101
- assert q_modeling is not None, \
102
- "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
 
103
  codebooks_pattern_cfg = omegaconf.OmegaConf.create(
104
- {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
105
  )
 
106
  pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
107
- return LMModel(
 
108
  pattern_provider=pattern_provider,
109
  condition_provider=condition_provider,
110
  fuser=fuser,
@@ -113,67 +169,84 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
113
  attribute_dropout=attribute_dropout,
114
  dtype=getattr(torch, cfg.dtype),
115
  device=cfg.device,
116
- **kwargs
117
  ).to(cfg.device)
118
  else:
119
  raise KeyError(f"Unexpected LM model {cfg.lm_model}")
120
 
121
 
122
- def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
 
 
123
  """Instantiate a conditioning model."""
124
  device = cfg.device
125
  duration = cfg.dataset.segment_duration
126
- cfg = getattr(cfg, 'conditioners')
127
  dict_cfg = {} if cfg is None else dict_from_config(cfg)
128
  conditioners: tp.Dict[str, BaseConditioner] = {}
129
- condition_provider_args = dict_cfg.pop('args', {})
130
- condition_provider_args.pop('merge_text_conditions_p', None)
131
- condition_provider_args.pop('drop_desc_p', None)
132
 
133
  for cond, cond_cfg in dict_cfg.items():
134
- model_type = cond_cfg['model']
135
  model_args = cond_cfg[model_type]
136
- if model_type == 't5':
137
- conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
138
- elif model_type == 'lut':
139
- conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
140
- elif model_type == 'chroma_stem':
141
  conditioners[str(cond)] = ChromaStemConditioner(
142
- output_dim=output_dim,
143
- duration=duration,
144
- device=device,
145
- **model_args
146
  )
147
- elif model_type == 'clap':
148
  conditioners[str(cond)] = CLAPEmbeddingConditioner(
149
  output_dim=output_dim,
150
  device=device,
151
  **model_args
152
  )
153
  else:
154
  raise ValueError(f"Unrecognized conditioning model: {model_type}")
155
- conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
 
 
156
  return conditioner
157
 
158
 
159
  def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
160
  """Instantiate a condition fuser object."""
161
- fuser_cfg = getattr(cfg, 'fuser')
162
- fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
163
- fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
164
  kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
165
  fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
166
  return fuser
167
 
168
 
169
- def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
 
 
170
  """Instantiate a codebooks pattern provider object."""
171
  pattern_providers = {
172
- 'parallel': ParallelPatternProvider,
173
- 'delay': DelayedPatternProvider,
174
- 'unroll': UnrolledPatternProvider,
175
- 'coarse_first': CoarseFirstPattern,
176
- 'musiclm': MusicLMPattern,
177
  }
178
  name = cfg.modeling
179
  kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
@@ -181,20 +254,23 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
181
  return klass(n_q, **kwargs)
182
 
183
 
184
- def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
185
  """Instantiate a debug compression model to be used for unit tests."""
186
- assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
187
  model_ratios = {
188
  16000: [10, 8, 8], # 25 Hz at 16kHz
189
- 32000: [10, 8, 16] # 25 Hz at 32kHz
190
  }
191
  ratios: tp.List[int] = model_ratios[sample_rate]
192
  frame_rate = 25
193
  seanet_kwargs: dict = {
194
- 'n_filters': 4,
195
- 'n_residual_layers': 1,
196
- 'dimension': 32,
197
- 'ratios': ratios,
198
  }
199
  encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
200
  decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
@@ -202,8 +278,13 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
202
  init_x = torch.randn(8, 32, 128)
203
  quantizer(init_x, 1) # initialize kmeans etc.
204
  compression_model = EncodecModel(
205
- encoder, decoder, quantizer,
206
- frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
207
  return compression_model.eval()
208
 
209
 
@@ -211,48 +292,60 @@ def get_diffusion_model(cfg: omegaconf.DictConfig):
211
  # TODO Find a way to infer the channels from dset
212
  channels = cfg.channels
213
  num_steps = cfg.schedule.num_steps
214
- return DiffusionUnet(
215
- chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
216
 
217
 
218
  def get_processor(cfg, sample_rate: int = 24000):
219
  sample_processor = SampleProcessor()
220
  if cfg.use:
221
  kw = dict(cfg)
222
- kw.pop('use')
223
- kw.pop('name')
224
  if cfg.name == "multi_band_processor":
225
  sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
226
  return sample_processor
227
 
228
 
229
- def get_debug_lm_model(device='cpu'):
230
  """Instantiate a debug LM to be used for unit tests."""
231
  pattern = DelayedPatternProvider(n_q=4)
232
  dim = 16
233
  providers = {
234
- 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
 
 
235
  }
236
  condition_provider = ConditioningProvider(providers)
237
  fuser = ConditionFuser(
238
- {'cross': ['description'], 'prepend': [],
239
- 'sum': [], 'input_interpolate': []})
240
  lm = LMModel(
241
- pattern, condition_provider, fuser,
242
- n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
243
- cross_attention=True, causal=True)
244
  return lm.to(device).eval()
245
 
246
 
247
  def get_wrapped_compression_model(
248
- compression_model: CompressionModel,
249
- cfg: omegaconf.DictConfig) -> CompressionModel:
250
- if hasattr(cfg, 'interleave_stereo_codebooks'):
251
  if cfg.interleave_stereo_codebooks.use:
252
  kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
253
- kwargs.pop('use')
254
- compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
255
- if hasattr(cfg, 'compression_model_n_q'):
 
 
256
  if cfg.compression_model_n_q is not None:
257
  compression_model.set_num_codebooks(cfg.compression_model_n_q)
258
  return compression_model
 
11
 
12
  import typing as tp
13
 
 
14
  import omegaconf
15
  import torch
16
 
17
+ import audiocraft
18
+
19
  from .. import quantization as qt
20
+ from ..modules.codebooks_patterns import (CoarseFirstPattern,
21
+ CodebooksPatternProvider,
22
+ DelayedPatternProvider,
23
+ MusicLMPattern,
24
+ ParallelPatternProvider,
25
+ UnrolledPatternProvider)
26
+ from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner,
27
+ CLAPEmbeddingConditioner,
28
+ ConditionFuser, JascoCondConst,
29
+ ConditioningProvider, LUTConditioner,
30
+ T5Conditioner, StyleConditioner)
31
+ from ..modules.jasco_conditioners import (JascoConditioningProvider, ChordsEmbConditioner,
32
+ DrumsConditioner, MelodyConditioner)
33
  from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
34
+ from ..utils.utils import dict_from_config
35
+ from .encodec import (CompressionModel, EncodecModel,
36
+ InterleaveStereoCompressionModel)
37
+ from .lm import LMModel
38
+ from .lm_magnet import MagnetLMModel
39
+ from .flow_matching import FlowMatchingModel
40
+ from .unet import DiffusionUnet
41
+
42
 
43
 
44
+ def get_quantizer(
45
+ quantizer: str, cfg: omegaconf.DictConfig, dimension: int
46
+ ) -> qt.BaseQuantizer:
47
+ klass = {"no_quant": qt.DummyQuantizer, "rvq": qt.ResidualVectorQuantizer}[
48
+ quantizer
49
+ ]
50
  kwargs = dict_from_config(getattr(cfg, quantizer))
51
+ if quantizer != "no_quant":
52
+ kwargs["dimension"] = dimension
53
  return klass(**kwargs)
54
 
55
 
56
  def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
57
+ if encoder_name == "seanet":
58
+ kwargs = dict_from_config(getattr(cfg, "seanet"))
59
+ encoder_override_kwargs = kwargs.pop("encoder")
60
+ decoder_override_kwargs = kwargs.pop("decoder")
61
  encoder_kwargs = {**kwargs, **encoder_override_kwargs}
62
  decoder_kwargs = {**kwargs, **decoder_override_kwargs}
63
  encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
 
69
 
70
  def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
71
  """Instantiate a compression model."""
72
+ if cfg.compression_model == "encodec":
73
+ kwargs = dict_from_config(getattr(cfg, "encodec"))
74
+ encoder_name = kwargs.pop("autoencoder")
75
+ quantizer_name = kwargs.pop("quantizer")
76
  encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
77
  quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
78
+ frame_rate = kwargs["sample_rate"] // encoder.hop_length
79
+ renormalize = kwargs.pop("renormalize", False)
80
  # deprecated params
81
+ kwargs.pop("renorm", None)
82
+ return EncodecModel(
83
+ encoder,
84
+ decoder,
85
+ quantizer,
86
+ frame_rate=frame_rate,
87
+ renormalize=renormalize,
88
+ **kwargs,
89
+ ).to(cfg.device)
90
  else:
91
  raise KeyError(f"Unexpected compression model {cfg.compression_model}")
92
 
93
 
94
+ def get_jasco_model(cfg: omegaconf.DictConfig,
95
+ compression_model: tp.Optional[CompressionModel] = None) -> FlowMatchingModel:
96
+ kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
97
+ attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
98
+ cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
99
+ cfg_prob = cls_free_guidance["training_dropout"]
100
+ cfg_coef = cls_free_guidance["inference_coef"]
101
+ fuser = get_condition_fuser(cfg)
102
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
103
+ if JascoCondConst.DRM.value in condition_provider.conditioners: # use self_wav for drums
104
+ assert compression_model is not None
105
+
106
+ # use compression model for drums conditioning
107
+ condition_provider.conditioners.self_wav.compression_model = compression_model
108
+ condition_provider.conditioners.self_wav.compression_model.requires_grad_(False)
109
+
110
+ # downcast to jasco conditioning provider
111
+ seq_len = cfg.compression_model_framerate * cfg.dataset.segment_duration
112
+ chords_card = cfg.conditioners.chords.chords_emb.card if JascoCondConst.CRD.value in cfg.conditioners else -1
113
+ condition_provider = JascoConditioningProvider(device=condition_provider.device,
114
+ conditioners=condition_provider.conditioners,
115
+ chords_card=chords_card,
116
+ sequence_length=seq_len)
117
+
118
+ if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
119
+ kwargs["cross_attention"] = True
120
+
121
+ kwargs.pop("n_q", None)
122
+ kwargs.pop("card", None)
123
+
124
+ return FlowMatchingModel(
125
+ condition_provider=condition_provider,
126
+ fuser=fuser,
127
+ cfg_dropout=cfg_prob,
128
+ cfg_coef=cfg_coef,
129
+ attribute_dropout=attribute_dropout,
130
+ dtype=getattr(torch, cfg.dtype),
131
+ device=cfg.device,
132
+ **kwargs,
133
+ ).to(cfg.device)
134
+
135
+
136
  def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
137
  """Instantiate a transformer LM."""
138
+ if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]:
139
+ kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
140
+ n_q = kwargs["n_q"]
141
+ q_modeling = kwargs.pop("q_modeling", None)
142
+ codebooks_pattern_cfg = getattr(cfg, "codebooks_pattern")
143
+ attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
144
+ cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
145
+ cfg_prob, cfg_coef = (
146
+ cls_free_guidance["training_dropout"],
147
+ cls_free_guidance["inference_coef"],
148
+ )
149
  fuser = get_condition_fuser(cfg)
150
  condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
151
+ if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
152
+ kwargs["cross_attention"] = True
153
  if codebooks_pattern_cfg.modeling is None:
154
+ assert (
155
+ q_modeling is not None
156
+ ), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
157
  codebooks_pattern_cfg = omegaconf.OmegaConf.create(
158
+ {"modeling": q_modeling, "delay": {"delays": list(range(n_q))}}
159
  )
160
+
161
  pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
162
+ lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel
163
+ return lm_class(
164
  pattern_provider=pattern_provider,
165
  condition_provider=condition_provider,
166
  fuser=fuser,
 
169
  attribute_dropout=attribute_dropout,
170
  dtype=getattr(torch, cfg.dtype),
171
  device=cfg.device,
172
+ **kwargs,
173
  ).to(cfg.device)
174
  else:
175
  raise KeyError(f"Unexpected LM model {cfg.lm_model}")
176
 
177
 
178
+ def get_conditioner_provider(
179
+ output_dim: int, cfg: omegaconf.DictConfig
180
+ ) -> ConditioningProvider:
181
  """Instantiate a conditioning model."""
182
  device = cfg.device
183
  duration = cfg.dataset.segment_duration
184
+ cfg = getattr(cfg, "conditioners")
185
  dict_cfg = {} if cfg is None else dict_from_config(cfg)
186
  conditioners: tp.Dict[str, BaseConditioner] = {}
187
+ condition_provider_args = dict_cfg.pop("args", {})
188
+ condition_provider_args.pop("merge_text_conditions_p", None)
189
+ condition_provider_args.pop("drop_desc_p", None)
190
 
191
  for cond, cond_cfg in dict_cfg.items():
192
+ model_type = cond_cfg["model"]
193
  model_args = cond_cfg[model_type]
194
+ if model_type == "t5":
195
+ conditioners[str(cond)] = T5Conditioner(
196
+ output_dim=output_dim, device=device, **model_args
197
+ )
198
+ elif model_type == "lut":
199
+ conditioners[str(cond)] = LUTConditioner(
200
+ output_dim=output_dim, **model_args
201
+ )
202
+ elif model_type == "chroma_stem":
203
  conditioners[str(cond)] = ChromaStemConditioner(
204
+ output_dim=output_dim, duration=duration, device=device, **model_args
 
 
 
205
  )
206
+ elif model_type in {"chords_emb", "drum_latents", "melody"}:
207
+ conditioners_classes = {"chords_emb": ChordsEmbConditioner,
208
+ "drum_latents": DrumsConditioner,
209
+ "melody": MelodyConditioner}
210
+ conditioner_class = conditioners_classes[model_type]
211
+ conditioners[str(cond)] = conditioner_class(device=device, **model_args)
212
+ elif model_type == "clap":
213
  conditioners[str(cond)] = CLAPEmbeddingConditioner(
214
+ output_dim=output_dim, device=device, **model_args
215
+ )
216
+ elif model_type == 'style':
217
+ conditioners[str(cond)] = StyleConditioner(
218
  output_dim=output_dim,
219
  device=device,
220
  **model_args
221
  )
222
  else:
223
  raise ValueError(f"Unrecognized conditioning model: {model_type}")
224
+ conditioner = ConditioningProvider(
225
+ conditioners, device=device, **condition_provider_args
226
+ )
227
  return conditioner
228
 
229
 
230
  def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
231
  """Instantiate a condition fuser object."""
232
+ fuser_cfg = getattr(cfg, "fuser")
233
+ fuser_methods = ["sum", "cross", "prepend", "ignore", "input_interpolate"]
234
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods if k in fuser_cfg}
235
  kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
236
  fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
237
  return fuser
238
 
239
 
240
+ def get_codebooks_pattern_provider(
241
+ n_q: int, cfg: omegaconf.DictConfig
242
+ ) -> CodebooksPatternProvider:
243
  """Instantiate a codebooks pattern provider object."""
244
  pattern_providers = {
245
+ "parallel": ParallelPatternProvider,
246
+ "delay": DelayedPatternProvider,
247
+ "unroll": UnrolledPatternProvider,
248
+ "coarse_first": CoarseFirstPattern,
249
+ "musiclm": MusicLMPattern,
250
  }
251
  name = cfg.modeling
252
  kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
 
254
  return klass(n_q, **kwargs)
255
 
256
 
257
+ def get_debug_compression_model(device="cpu", sample_rate: int = 32000):
258
  """Instantiate a debug compression model to be used for unit tests."""
259
+ assert sample_rate in [
260
+ 16000,
261
+ 32000,
262
+ ], "unsupported sample rate for debug compression model"
263
  model_ratios = {
264
  16000: [10, 8, 8], # 25 Hz at 16kHz
265
+ 32000: [10, 8, 16], # 25 Hz at 32kHz
266
  }
267
  ratios: tp.List[int] = model_ratios[sample_rate]
268
  frame_rate = 25
269
  seanet_kwargs: dict = {
270
+ "n_filters": 4,
271
+ "n_residual_layers": 1,
272
+ "dimension": 32,
273
+ "ratios": ratios,
274
  }
275
  encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
276
  decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
 
278
  init_x = torch.randn(8, 32, 128)
279
  quantizer(init_x, 1) # initialize kmeans etc.
280
  compression_model = EncodecModel(
281
+ encoder,
282
+ decoder,
283
+ quantizer,
284
+ frame_rate=frame_rate,
285
+ sample_rate=sample_rate,
286
+ channels=1,
287
+ ).to(device)
288
  return compression_model.eval()
289
 
290
 
 
292
  # TODO Find a way to infer the channels from dset
293
  channels = cfg.channels
294
  num_steps = cfg.schedule.num_steps
295
+ return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
 
296
 
297
 
298
  def get_processor(cfg, sample_rate: int = 24000):
299
  sample_processor = SampleProcessor()
300
  if cfg.use:
301
  kw = dict(cfg)
302
+ kw.pop("use")
303
+ kw.pop("name")
304
  if cfg.name == "multi_band_processor":
305
  sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
306
  return sample_processor
307
 
308
 
309
+ def get_debug_lm_model(device="cpu"):
310
  """Instantiate a debug LM to be used for unit tests."""
311
  pattern = DelayedPatternProvider(n_q=4)
312
  dim = 16
313
  providers = {
314
+ "description": LUTConditioner(
315
+ n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"
316
+ ),
317
  }
318
  condition_provider = ConditioningProvider(providers)
319
  fuser = ConditionFuser(
320
+ {"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []}
321
+ )
322
  lm = LMModel(
323
+ pattern,
324
+ condition_provider,
325
+ fuser,
326
+ n_q=4,
327
+ card=400,
328
+ dim=dim,
329
+ num_heads=4,
330
+ custom=True,
331
+ num_layers=2,
332
+ cross_attention=True,
333
+ causal=True,
334
+ )
335
  return lm.to(device).eval()
336
 
337
 
338
  def get_wrapped_compression_model(
339
+ compression_model: CompressionModel, cfg: omegaconf.DictConfig
340
+ ) -> CompressionModel:
341
+ if hasattr(cfg, "interleave_stereo_codebooks"):
342
  if cfg.interleave_stereo_codebooks.use:
343
  kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
344
+ kwargs.pop("use")
345
+ compression_model = InterleaveStereoCompressionModel(
346
+ compression_model, **kwargs
347
+ )
348
+ if hasattr(cfg, "compression_model_n_q"):
349
  if cfg.compression_model_n_q is not None:
350
  compression_model.set_num_codebooks(cfg.compression_model_n_q)
351
  return compression_model
audiocraft/models/flow_matching.py ADDED
@@ -0,0 +1,516 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ import logging
10
+ import math
11
+ import typing as tp
12
+ import torch
13
+ from torch import nn
14
+ from torchdiffeq import odeint # type: ignore
15
+ from ..modules.streaming import StreamingModule
16
+ from ..modules.transformer import create_norm_fn, StreamingTransformerLayer
17
+ from ..modules.unet_transformer import UnetTransformer
18
+ from ..modules.conditioners import (
19
+ ConditionFuser,
20
+ ClassifierFreeGuidanceDropout,
21
+ AttributeDropout,
22
+ ConditioningAttributes,
23
+ JascoCondConst
24
+ )
25
+ from ..modules.jasco_conditioners import JascoConditioningProvider
26
+ from ..modules.activations import get_activation_fn
27
+
28
+ from .lm import ConditionTensors, init_layer
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class FMOutput:
36
+ latents: torch.Tensor # [B, T, D]
37
+ mask: torch.Tensor # [B, T]
38
+
39
+
40
+ class CFGTerm:
41
+ """
42
+ Base class for Multi Source Classifier-Free Guidance (CFG) terms. This class represents a term in the CFG process,
43
+ which is used to guide the generation process by adjusting the influence of different conditions.
44
+ Attributes:
45
+ conditions (dict): A dictionary of conditions that influence the generation process.
46
+ weight (float): The weight of the CFG term, determining its influence on the generation.
47
+ """
48
+ def __init__(self, conditions, weight):
49
+ self.conditions = conditions
50
+ self.weight = weight
51
+
52
+ def drop_irrelevant_conds(self, conditions):
53
+ """
54
+ Drops irrelevant conditions from the CFG term. This method should be implemented by subclasses.
55
+ Args:
56
+ conditions (dict): The conditions to be filtered.
57
+ Raises:
58
+ NotImplementedError: If the method is not implemented in a subclass.
59
+ """
60
+ raise NotImplementedError("No base implementation for dropping irrelevant conditions.")
61
+
62
+
63
+ class AllCFGTerm(CFGTerm):
64
+ """
65
+ A CFG term that retains all conditions. This class does not drop any condition.
66
+ """
67
+ def __init__(self, conditions, weight):
68
+ super().__init__(conditions, weight)
69
+ self.drop_irrelevant_conds()
70
+
71
+ def drop_irrelevant_conds(self):
72
+ pass
73
+
74
+
75
+ class NullCFGTerm(CFGTerm):
76
+ """
77
+ A CFG term that drops all conditions, effectively nullifying their influence.
78
+ """
79
+ def __init__(self, conditions, weight):
80
+ super().__init__(conditions, weight)
81
+ self.drop_irrelevant_conds()
82
+
83
+ def drop_irrelevant_conds(self):
84
+ """
85
+ Drops all conditions by applying a dropout with probability 1.0, effectively nullifying their influence.
86
+ """
87
+ self.conditions = ClassifierFreeGuidanceDropout(p=1.0)(
88
+ samples=self.conditions,
89
+ cond_types=["wav", "text", "symbolic"])
90
+
91
+
92
+ class TextCFGTerm(CFGTerm):
93
+ """
94
+ A CFG term that selectively drops conditions based on specified dropout probabilities for different types
95
+ of conditions, such as 'symbolic' and 'wav'.
96
+ """
97
+ def __init__(self, conditions, weight, model_att_dropout):
98
+ """
99
+ Initializes a TextCFGTerm with specified conditions, weight, and model attention dropout configuration.
100
+ Args:
101
+ conditions (dict): The conditions to be used in the CFG process.
102
+ weight (float): The weight of the CFG term.
103
+ model_att_dropout (object): The attribute dropouts used by the model.
104
+ """
105
+ super().__init__(conditions, weight)
106
+ if 'symbolic' in model_att_dropout.p:
107
+ self.drop_symbolics = {k: 1.0 for k in model_att_dropout.p['symbolic'].keys()}
108
+ else:
109
+ self.drop_symbolics = {}
110
+ if 'wav' in model_att_dropout.p:
111
+ self.drop_wav = {k: 1.0 for k in model_att_dropout.p['wav'].keys()}
112
+ else:
113
+ self.drop_wav = {}
114
+ self.drop_irrelevant_conds()
115
+
116
+ def drop_irrelevant_conds(self):
117
+ self.conditions = AttributeDropout({'symbolic': self.drop_symbolics,
118
+ 'wav': self.drop_wav})(self.conditions) # drop temporal conds
119
+
120
+
121
+ class FlowMatchingModel(StreamingModule):
122
+ """
123
+ A flow matching model that inherits from StreamingModule.
124
+ This model uses a transformer architecture to process and fuse conditions, applies learned embeddings and
125
+ transformations, and predicts multi-source guided vector fields.
126
+ Attributes:
127
+ condition_provider (JascoConditioningProvider): Provider for conditioning attributes.
128
+ fuser (ConditionFuser): Fuser for combining multiple conditions.
129
+ dim (int): Dimensionality of the model's main features.
130
+ num_heads (int): Number of attention heads in the transformer.
131
+ flow_dim (int): Dimensionality of the flow features.
132
+ chords_dim (int): Dimensionality for chord embeddings, if used.
133
+ drums_dim (int): Dimensionality for drums embeddings, if used.
134
+ melody_dim (int): Dimensionality for melody embeddings, if used.
135
+ hidden_scale (int): Scaling factor for the dimensionality of the feedforward network in the transformer.
136
+ norm (str): Type of normalization to use ('layer_norm' or other supported types).
137
+ norm_first (bool): Whether to apply normalization before other operations in the transformer layers.
138
+ bias_proj (bool): Whether to include bias in the projection layers.
139
+ weight_init (Optional[str]): Method for initializing weights.
140
+ depthwise_init (Optional[str]): Method for initializing depthwise convolutional layers.
141
+ zero_bias_init (bool): Whether to initialize biases to zero.
142
+ cfg_dropout (float): Classifier-free guidance dropout rate.
143
+ cfg_coef (float): Classifier-free guidance coefficient.
144
+ attribute_dropout (Dict[str, Dict[str, float]]): Dropout rates for specific attributes.
145
+ time_embedding_dim (int): Dimensionality of time embeddings.
146
+ **kwargs: Additional keyword arguments for the transformer.
147
+ Methods:
148
+ __init__: Initializes the model with the specified attributes and configuration.
149
+ """
150
+ def __init__(self, condition_provider: JascoConditioningProvider,
151
+ fuser: ConditionFuser,
152
+ dim: int = 128,
153
+ num_heads: int = 8,
154
+ flow_dim: int = 128,
155
+ chords_dim: int = 0,
156
+ drums_dim: int = 0,
157
+ melody_dim: int = 0,
158
+ hidden_scale: int = 4,
159
+ norm: str = 'layer_norm',
160
+ norm_first: bool = False,
161
+ bias_proj: bool = True,
162
+ weight_init: tp.Optional[str] = None,
163
+ depthwise_init: tp.Optional[str] = None,
164
+ zero_bias_init: bool = False,
165
+ cfg_dropout: float = 0,
166
+ cfg_coef: float = 1.0,
167
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
168
+ time_embedding_dim: int = 128,
169
+ **kwargs):
170
+ super().__init__()
171
+ self.cfg_coef = cfg_coef
172
+
173
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
174
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
175
+ self.condition_provider = condition_provider
176
+ self.fuser = fuser
177
+ self.dim = dim # transformer dim
178
+ self.flow_dim = flow_dim
179
+ self.chords_dim = chords_dim
180
+ self.emb = nn.Linear(flow_dim + chords_dim + drums_dim + melody_dim, dim, bias=False)
181
+ if 'activation' in kwargs:
182
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
183
+
184
+ self.transformer = UnetTransformer(
185
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
186
+ norm=norm, norm_first=norm_first,
187
+ layer_class=StreamingTransformerLayer,
188
+ **kwargs)
189
+ self.out_norm: tp.Optional[nn.Module] = None
190
+ if norm_first:
191
+ self.out_norm = create_norm_fn(norm, dim)
192
+ self.linear = nn.Linear(dim, flow_dim, bias=bias_proj)
193
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
194
+ self._fsdp: tp.Optional[nn.Module]
195
+ self.__dict__['_fsdp'] = None
196
+
197
+ # init time parameter embedding
198
+ self.d_temb1 = time_embedding_dim
199
+ self.d_temb2 = 4 * time_embedding_dim
200
+ self.temb = nn.Module()
201
+ self.temb.dense = nn.ModuleList([
202
+ torch.nn.Linear(self.d_temb1,
203
+ self.d_temb2),
204
+ torch.nn.Linear(self.d_temb2,
205
+ self.d_temb2),
206
+ ])
207
+ self.temb_proj = nn.Linear(self.d_temb2, dim)
208
+
209
+ def _get_timestep_embedding(self, timesteps, embedding_dim):
210
+ """
211
+ #######################################################################################################
212
+ TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
213
+ #######################################################################################################
214
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
215
+ From Fairseq.
216
+ Build sinusoidal embeddings.
217
+ This matches the implementation in tensor2tensor, but differs slightly
218
+ from the description in Section 3.5 of "Attention Is All You Need".
219
+ """
220
+ assert len(timesteps.shape) == 1
221
+
222
+ half_dim = embedding_dim // 2
223
+ emb = math.log(10000) / (half_dim - 1)
224
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
225
+ emb = emb.to(device=timesteps.device)
226
+ emb = timesteps.float()[:, None] * emb[None, :]
227
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
228
+ if embedding_dim % 2 == 1: # zero pad
229
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
230
+ return emb
231
+
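
Written out, the embedding built above for a scalar time t and an even embedding dimension d (half_dim = d // 2) is

    ω_i = exp(−i · ln(10000) / (half_dim − 1)),  i = 0, …, half_dim − 1
    e(t) = [sin(t·ω_0), …, sin(t·ω_{half_dim−1}), cos(t·ω_0), …, cos(t·ω_{half_dim−1})]

with a single zero appended when d is odd.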
232
+ def _embed_time_parameter(self, t: torch.Tensor):
233
+ """
234
+ #######################################################################################################
235
+ TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
236
+ #######################################################################################################
237
+ """
238
+ temb = self._get_timestep_embedding(t.flatten(), self.d_temb1)
239
+ temb = self.temb.dense[0](temb)
240
+ temb = temb * torch.sigmoid(temb) # swish activation
241
+ temb = self.temb.dense[1](temb)
242
+ return temb
243
+
244
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
245
+ """Initialization of the transformer module weights.
246
+
247
+ Args:
248
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
249
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
250
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
251
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
252
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
253
+ """
254
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
255
+ assert depthwise_init is None or weight_init is not None, \
256
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
257
+ assert not zero_bias_init or weight_init is not None, \
258
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
259
+
260
+ if weight_init is None:
261
+ return
262
+
263
+ init_layer(self.emb, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
264
+
265
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
266
+ depth = None
267
+ if depthwise_init == 'current':
268
+ depth = layer_idx + 1
269
+ elif depthwise_init == 'global':
270
+ depth = len(self.transformer.layers)
271
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
272
+ tr_layer.apply(init_fn)
273
+
274
+ init_layer(self.linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
275
+
276
+ def _align_seq_length(self,
277
+ cond: torch.Tensor,
278
+ seq_len: int = 500):
279
+ # trim if needed
280
+ cond = cond[:, :seq_len, :]
281
+
282
+ # pad if needed
283
+ B, T, C = cond.shape
284
+ if T < seq_len:
285
+ cond = torch.cat((cond, torch.zeros((B, seq_len - T, C), dtype=cond.dtype, device=cond.device)), dim=1)
286
+
287
+ return cond
288
+
289
+ def forward(self,
290
+ latents: torch.Tensor,
291
+ t: torch.Tensor,
292
+ conditions: tp.List[ConditioningAttributes],
293
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
294
+ """Apply flow matching forward pass on latents and conditions.
295
+ Given a tensor of noisy latents of shape [B, T, D] with D the flow dim and T the sequence steps,
296
+ and a time parameter tensor t, return the vector field with shape [B, T, D].
297
+
298
+ Args:
299
+ latents (torch.Tensor): noisy latents.
300
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
301
+ the given codes. Note that when evaluating multiple times with the same conditioning
302
+ you should pre-compute those and pass them as `condition_tensors`.
303
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
304
+ tensors, see `conditions`.
305
+ Returns:
306
+ torch.Tensor: estimated vector field v_theta.
307
+ """
308
+ assert condition_tensors is not None, "FlowMatchingModel requires pre-calculation of condition tensors"
309
+ assert not conditions, "Shouldn't pass unprocessed conditions to FlowMatchingModel."
310
+
311
+ B, T, D = latents.shape
312
+ x = latents
313
+
314
+ # concat temporal conditions on the feature dimension
315
+ temporal_conds = JascoCondConst.ALL.value
316
+ for cond in temporal_conds:
317
+ if cond not in condition_tensors:
318
+ continue
319
+ c = self._align_seq_length(condition_tensors[cond][0], seq_len=T)
320
+ x = torch.concat((x, c), dim=-1)
321
+
322
+ # project to transformer dimension
323
+ input_ = self.emb(x)
324
+
325
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
326
+
327
+ # embed time parameter
328
+ t_embs = self._embed_time_parameter(t)
329
+
330
+ # add it to cross_attention_input
331
+ cross_attention_input = cross_attention_input + self.temb_proj(t_embs[:, None, :])
332
+
333
+ out = self.transformer(input_, cross_attention_src=cross_attention_input)
334
+
335
+ if self.out_norm:
336
+ out = self.out_norm(out)
337
+ v_theta = self.linear(out) # [B, T, D]
338
+
339
+ # remove the prefix from the model outputs
340
+ if len(self.fuser.fuse2cond['prepend']) > 0:
341
+ v_theta = v_theta[:, :, -T:]
342
+
343
+ return v_theta # [B, T, D]
344
+
345
+ def _multi_source_cfg_preprocess(self,
346
+ conditions: tp.List[ConditioningAttributes],
347
+ cfg_coef_all: float,
348
+ cfg_coef_txt: float,
349
+ min_weight: float = 1e-6):
350
+ """
351
+ Preprocesses the CFG terms for multi-source conditional generation.
352
+ Args:
353
+ conditions (list): A list of conditions to be applied.
354
+ cfg_coef_all (float): The coefficient for all conditions.
355
+ cfg_coef_txt (float): The coefficient for text conditions.
356
+ min_weight (float): The minimal absolute weight for calculating a CFG term.
357
+ Returns:
358
+ tuple: A tuple containing condition_tensors and cfg_terms.
359
+ condition_tensors is a dictionary or ConditionTensors object with tokenized conditions.
360
+ cfg_terms is a list of CFGTerm objects with weights adjusted based on the coefficients.
361
+ """
362
+ condition_tensors: tp.Optional[ConditionTensors]
363
+ cfg_terms = []
364
+ if conditions:
365
+ # conditional terms
366
+ cfg_terms = [AllCFGTerm(conditions=conditions, weight=cfg_coef_all),
367
+ TextCFGTerm(conditions=conditions, weight=cfg_coef_txt,
368
+ model_att_dropout=self.att_dropout)]
369
+
370
+ # add null term
371
+ cfg_terms.append(NullCFGTerm(conditions=conditions, weight=1 - sum([ct.weight for ct in cfg_terms])))
372
+
373
+ # remove terms with negligible weight
374
+ for ct in cfg_terms:
375
+ if abs(ct.weight) < min_weight:
376
+ cfg_terms.remove(ct)
377
+
378
+ conds: tp.List[ConditioningAttributes] = sum([ct.conditions for ct in cfg_terms], [])
379
+ tokenized = self.condition_provider.tokenize(conds)
380
+ condition_tensors = self.condition_provider(tokenized)
381
+ else:
382
+ condition_tensors = {}
383
+
384
+ return condition_tensors, cfg_terms
385
+
386
+ def estimated_vector_field(self, z, t, condition_tensors=None, cfg_terms=[]):
387
+ """
388
+ Estimates the vector field for the given latent variables and time parameter,
389
+ conditioned on the provided conditions.
390
+ Args:
391
+ z (Tensor): The latent variables.
392
+ t (float): The time variable.
393
+ condition_tensors (ConditionTensors, optional): The condition tensors. Defaults to None.
394
+ cfg_terms (list, optional): The list of CFG terms. Defaults to an empty list.
395
+ Returns:
396
+ Tensor: The estimated vector field.
397
+ """
398
+ if len(cfg_terms) > 1:
399
+ z = z.repeat(len(cfg_terms), 1, 1) # duplicate noisy latents for multi-source CFG
400
+ v_thetas = self(latents=z, t=t, conditions=[], condition_tensors=condition_tensors)
401
+ return self._multi_source_cfg_postprocess(v_thetas, cfg_terms)
402
+
403
+ def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms):
404
+ """
405
+ Postprocesses the vector fields generated for each CFG term to combine them into a single vector field.
406
+ Multi source guidance occurs here.
407
+ Args:
408
+ v_thetas (Tensor): The vector fields for each CFG term.
409
+ cfg_terms (list): The CFG terms used.
410
+ Returns:
411
+ Tensor: The combined vector field.
412
+ """
413
+ if len(cfg_terms) <= 1:
414
+ return v_thetas
415
+ v_theta_per_term = v_thetas.chunk(len(cfg_terms))
416
+ return sum([ct.weight * term_vf for ct, term_vf in zip(cfg_terms, v_theta_per_term)])
417
+
418
+ @torch.no_grad()
419
+ def generate(self,
420
+ prompt: tp.Optional[torch.Tensor] = None,
421
+ conditions: tp.List[ConditioningAttributes] = [],
422
+ num_samples: tp.Optional[int] = None,
423
+ max_gen_len: int = 256,
424
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
425
+ cfg_coef_all: float = 3.0,
426
+ cfg_coef_txt: float = 1.0,
427
+ euler: bool = False,
428
+ euler_steps: int = 100,
429
+ ode_rtol: float = 1e-5,
430
+ ode_atol: float = 1e-5,
431
+ ) -> torch.Tensor:
432
+ """
433
+ Generate audio latents given a prompt or unconditionally. This method supports both Euler integration
434
+ and adaptive ODE solving to generate sequences based on the specified conditions and configuration coefficients.
435
+
436
+ Args:
437
+ prompt (torch.Tensor, optional): Initial prompt to condition the generation. Defaults to None.
438
+ conditions (List[ConditioningAttributes]): List of conditioning attributes - text, symbolic or audio.
439
+ num_samples (int, optional): Number of samples to generate.
440
+ If None, it is inferred from the number of conditions.
441
+ max_gen_len (int): Maximum length of the generated sequence.
442
+ callback (Callable[[int, int], None], optional): Callback function to monitor the generation process.
443
+ cfg_coef_all (float): Coefficient for the fully conditional CFG term.
444
+ cfg_coef_txt (float): Coefficient for text CFG term.
445
+ euler (bool): If True, use Euler integration, otherwise use adaptive ODE solver.
446
+ euler_steps (int): Number of Euler steps to perform if Euler integration is used.
447
+ ode_rtol (float): ODE solver rtol threshold.
448
+ ode_atol (float): ODE solver atol threshold.
449
+
450
+ Returns:
451
+ torch.Tensor: Generated latents, shaped as (num_samples, max_gen_len, feature_dim).
452
+ """
453
+
454
+ assert not self.training, "generation shouldn't be used in training mode."
455
+ first_param = next(iter(self.parameters()))
456
+ device = first_param.device
457
+
458
+ # Checking all input shapes are consistent.
459
+ possible_num_samples = []
460
+ if num_samples is not None:
461
+ possible_num_samples.append(num_samples)
462
+ elif prompt is not None:
463
+ possible_num_samples.append(prompt.shape[0])
464
+ elif conditions:
465
+ possible_num_samples.append(len(conditions))
466
+ else:
467
+ possible_num_samples.append(1)
468
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
469
+ num_samples = possible_num_samples[0]
470
+
471
+ condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(conditions, cfg_coef_all, cfg_coef_txt)
472
+
473
+ # flow matching inference
474
+ B, T, D = num_samples, max_gen_len, self.flow_dim
475
+
476
+ z_0 = torch.randn((B, T, D), device=device)
477
+
478
+ if euler:
479
+ # vanilla Euler integration
480
+ dt = (1 / euler_steps)
481
+ z = z_0
482
+ t = torch.zeros((1, ), device=device)
483
+ for _ in range(euler_steps):
484
+ v_theta = self.estimated_vector_field(z, t,
485
+ condition_tensors=condition_tensors,
486
+ cfg_terms=cfg_terms)
487
+ z = z + dt * v_theta
488
+ t = t + dt
489
+ z_1 = z
490
+ else:
491
+ # solve with an adaptive ODE integrator (dopri5)
492
+ t = torch.tensor([0, 1.0 - 1e-5], device=device)
493
+ num_evals = 0
494
+
495
+ # define ode vector field function
496
+ def inner_ode_func(t, z):
497
+ nonlocal num_evals
498
+ num_evals += 1
499
+ if callback is not None:
500
+ ESTIMATED_ODE_SOLVER_STEPS = 300
501
+ callback(num_evals, ESTIMATED_ODE_SOLVER_STEPS)
502
+ return self.estimated_vector_field(z, t,
503
+ condition_tensors=condition_tensors,
504
+ cfg_terms=cfg_terms)
505
+
506
+ ode_opts: dict = {"options": {}}
507
+ z = odeint(
508
+ inner_ode_func,
509
+ z_0,
510
+ t,
511
+ **{"atol": ode_atol, "rtol": ode_rtol, **ode_opts},
512
+ )
513
+ logger.info("Generated in %d steps", num_evals)
514
+ z_1 = z[-1]
515
+
516
+ return z_1
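For reference, a minimal sketch of the Euler branch above, with a generic vector_field(z, t) callable standing in for estimated_vector_field (the helper name and the toy field are illustrative only, not part of this commit):

import torch

def euler_integrate(vector_field, z_0: torch.Tensor, steps: int = 100) -> torch.Tensor:
    """Fixed-step Euler integration of dz/dt = v(z, t) from t=0 to t=1."""
    dt = 1.0 / steps
    z = z_0
    t = torch.zeros((1,), device=z_0.device)
    for _ in range(steps):
        z = z + dt * vector_field(z, t)  # z_{t+dt} = z_t + dt * v_theta(z_t, t)
        t = t + dt
    return z

# toy usage: a linear field that contracts the latents toward zero
z_1 = euler_integrate(lambda z, t: -z, torch.randn(2, 16, 8), steps=50)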
audiocraft/models/genmodel.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Base implementation for audio generative models. This base implementation
9
+ combines all the required components to run inference with pretrained audio
10
+ generative models. It can be easily inherited by downstream model classes to
11
+ provide easy access to the generation API.
12
+ """
13
+
14
+ from abc import ABC, abstractmethod
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+ import torch
19
+
20
+ from .encodec import CompressionModel
21
+ from .lm import LMModel
22
+ from .builders import get_wrapped_compression_model
23
+ from ..data.audio_utils import convert_audio
24
+ from ..modules.conditioners import ConditioningAttributes
25
+ from ..utils.autocast import TorchAutocast
26
+
27
+
28
+ class BaseGenModel(ABC):
29
+ """Base generative model with convenient generation API.
30
+
31
+ Args:
32
+ name (str): name of the model.
33
+ compression_model (CompressionModel): Compression model
34
+ used to map audio to invertible discrete representations.
35
+ lm (LMModel): Language model over discrete representations.
36
+ max_duration (float, optional): maximum duration the model can produce,
37
+ otherwise it is inferred from the training params.
38
+ """
39
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
40
+ max_duration: tp.Optional[float] = None):
41
+ self.name = name
42
+ self.compression_model = compression_model
43
+ self.lm = lm
44
+ self.cfg: tp.Optional[omegaconf.DictConfig] = None
45
+ # Just to be safe, let's put everything in eval mode.
46
+ self.compression_model.eval()
47
+ self.lm.eval()
48
+
49
+ if hasattr(lm, 'cfg'):
50
+ cfg = lm.cfg
51
+ assert isinstance(cfg, omegaconf.DictConfig)
52
+ self.cfg = cfg
53
+
54
+ if self.cfg is not None:
55
+ self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
56
+
57
+ if max_duration is None:
58
+ if self.cfg is not None:
59
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
60
+ else:
61
+ raise ValueError("You must provide max_duration when building directly your GenModel")
62
+ assert max_duration is not None
63
+
64
+ self.max_duration: float = max_duration
65
+ self.duration = self.max_duration
66
+
67
+ # self.extend_stride is the length of audio extension when generating samples longer
68
+ # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
69
+ # positive float value when generating with self.duration > self.max_duration.
70
+ self.extend_stride: tp.Optional[float] = None
71
+ self.device = next(iter(lm.parameters())).device
72
+ self.generation_params: dict = {}
73
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
74
+ if self.device.type == 'cpu':
75
+ self.autocast = TorchAutocast(enabled=False)
76
+ else:
77
+ self.autocast = TorchAutocast(
78
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
79
+
80
+ @property
81
+ def frame_rate(self) -> float:
82
+ """Roughly the number of AR steps per seconds."""
83
+ return self.compression_model.frame_rate
84
+
85
+ @property
86
+ def sample_rate(self) -> int:
87
+ """Sample rate of the generated audio."""
88
+ return self.compression_model.sample_rate
89
+
90
+ @property
91
+ def audio_channels(self) -> int:
92
+ """Audio channels of the generated audio."""
93
+ return self.compression_model.channels
94
+
95
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
96
+ """Override the default progress callback."""
97
+ self._progress_callback = progress_callback
98
+
99
+ @abstractmethod
100
+ def set_generation_params(self, *args, **kwargs):
101
+ """Set the generation parameters."""
102
+ raise NotImplementedError("No base implementation for setting generation params.")
103
+
104
+ @staticmethod
105
+ @abstractmethod
106
+ def get_pretrained(name: str, device=None):
107
+ raise NotImplementedError("No base implementation for getting pretrained model")
108
+
109
+ @torch.no_grad()
110
+ def _prepare_tokens_and_attributes(
111
+ self,
112
+ descriptions: tp.Sequence[tp.Optional[str]],
113
+ prompt: tp.Optional[torch.Tensor],
114
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
115
+ """Prepare model inputs.
116
+
117
+ Args:
118
+ descriptions (list of str): A list of strings used as text conditioning.
119
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
120
+ """
121
+ attributes = [
122
+ ConditioningAttributes(text={'description': description})
123
+ for description in descriptions]
124
+
125
+ if prompt is not None:
126
+ if descriptions is not None:
127
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions don't match"
128
+ prompt = prompt.to(self.device)
129
+ prompt_tokens, scale = self.compression_model.encode(prompt)
130
+ assert scale is None
131
+ else:
132
+ prompt_tokens = None
133
+ return attributes, prompt_tokens
134
+
135
+ def generate_unconditional(self, num_samples: int, progress: bool = False,
136
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
137
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
138
+ """Generate samples in an unconditional manner.
139
+
140
+ Args:
141
+ num_samples (int): Number of samples to be generated.
142
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
143
+ """
144
+ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
145
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
146
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
147
+ if return_tokens:
148
+ return self.generate_audio(tokens), tokens
149
+ return self.generate_audio(tokens)
150
+
151
+ def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
152
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
153
+ """Generate samples conditioned on text.
154
+
155
+ Args:
156
+ descriptions (list of str): A list of strings used as text conditioning.
157
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
158
+ """
159
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
160
+ assert prompt_tokens is None
161
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
162
+ if return_tokens:
163
+ return self.generate_audio(tokens), tokens
164
+ return self.generate_audio(tokens)
165
+
166
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
167
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
168
+ progress: bool = False, return_tokens: bool = False) \
169
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
170
+ """Generate samples conditioned on audio prompts and an optional text description.
171
+
172
+ Args:
173
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
174
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
175
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
176
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
177
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
178
+ """
179
+ if prompt.dim() == 2:
180
+ prompt = prompt[None]
181
+ if prompt.dim() != 3:
182
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
183
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
184
+ if descriptions is None:
185
+ descriptions = [None] * len(prompt)
186
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
187
+ assert prompt_tokens is not None
188
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
189
+ if return_tokens:
190
+ return self.generate_audio(tokens), tokens
191
+ return self.generate_audio(tokens)
192
+
193
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
194
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
195
+ """Generate discrete audio tokens given audio prompt and/or conditions.
196
+
197
+ Args:
198
+ attributes (list of ConditioningAttributes): Conditions used for generation (here text).
199
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
200
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
201
+ Returns:
202
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
203
+ """
204
+ total_gen_len = int(self.duration * self.frame_rate)
205
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
206
+ current_gen_offset: int = 0
207
+
208
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
209
+ generated_tokens += current_gen_offset
210
+ if self._progress_callback is not None:
211
+ # Note that total_gen_len might be quite wrong depending on the
212
+ # codebook pattern used, but with delay it is almost accurate.
213
+ self._progress_callback(generated_tokens, tokens_to_generate)
214
+ else:
215
+ print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
216
+
217
+ if prompt_tokens is not None:
218
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
219
+ "Prompt is longer than audio to generate"
220
+
221
+ callback = None
222
+ if progress:
223
+ callback = _progress_callback
224
+
225
+ if self.duration <= self.max_duration:
226
+ # generate by sampling from LM, simple case.
227
+ with self.autocast:
228
+ gen_tokens = self.lm.generate(
229
+ prompt_tokens, attributes,
230
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
231
+
232
+ else:
233
+ assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
234
+ assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
235
+ all_tokens = []
236
+ if prompt_tokens is None:
237
+ prompt_length = 0
238
+ else:
239
+ all_tokens.append(prompt_tokens)
240
+ prompt_length = prompt_tokens.shape[-1]
241
+
242
+ stride_tokens = int(self.frame_rate * self.extend_stride)
243
+ while current_gen_offset + prompt_length < total_gen_len:
244
+ time_offset = current_gen_offset / self.frame_rate
245
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
246
+ max_gen_len = int(chunk_duration * self.frame_rate)
247
+ with self.autocast:
248
+ gen_tokens = self.lm.generate(
249
+ prompt_tokens, attributes,
250
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
251
+ if prompt_tokens is None:
252
+ all_tokens.append(gen_tokens)
253
+ else:
254
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
255
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
256
+ prompt_length = prompt_tokens.shape[-1]
257
+ current_gen_offset += stride_tokens
258
+
259
+ gen_tokens = torch.cat(all_tokens, dim=-1)
260
+ return gen_tokens
261
+
262
+ def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
263
+ """Generate Audio from tokens."""
264
+ assert gen_tokens.dim() == 3
265
+ with torch.no_grad():
266
+ gen_audio = self.compression_model.decode(gen_tokens, None)
267
+ return gen_audio
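As a rough illustration of the stride loop in _generate_tokens above, the window arithmetic for a request longer than max_duration can be traced with plain numbers (all values below are hypothetical, not taken from any model config):

frame_rate = 50          # tokens per second
max_duration = 30.0      # model's native window, in seconds
extend_stride = 10.0     # how far each new window advances, in seconds
duration = 75.0          # requested total length, in seconds

total_gen_len = int(duration * frame_rate)
stride_tokens = int(frame_rate * extend_stride)
current_gen_offset, prompt_length = 0, 0
windows = []
while current_gen_offset + prompt_length < total_gen_len:
    time_offset = current_gen_offset / frame_rate
    chunk_duration = min(duration - time_offset, max_duration)
    windows.append((time_offset, time_offset + chunk_duration))
    # tokens kept from this chunk and re-fed as the next prompt
    prompt_length = int(chunk_duration * frame_rate) - stride_tokens
    current_gen_offset += stride_tokens

print(windows)  # [(0.0, 30.0), (10.0, 40.0), (20.0, 50.0), (30.0, 60.0), (40.0, 70.0), (50.0, 75.0)]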
audiocraft/models/lm.py CHANGED
@@ -23,6 +23,7 @@ from ..modules.conditioners import (
23
  ConditioningProvider,
24
  ConditioningAttributes,
25
  ConditionType,
 
26
  )
27
  from ..modules.codebooks_patterns import CodebooksPatternProvider
28
  from ..modules.activations import get_activation_fn
@@ -219,7 +220,8 @@ class LMModel(StreamingModule):
219
 
220
  def forward(self, sequence: torch.Tensor,
221
  conditions: tp.List[ConditioningAttributes],
222
- condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
 
223
  """Apply language model on sequence and conditions.
224
  Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
225
  S the sequence steps, return the logits with shape [B, card, K, S].
@@ -231,6 +233,9 @@ class LMModel(StreamingModule):
231
  you should pre-compute those and pass them as `condition_tensors`.
232
  condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
233
  tensors, see `conditions`.
 
 
 
234
  Returns:
235
  torch.Tensor: Logits.
236
  """
@@ -250,7 +255,8 @@ class LMModel(StreamingModule):
250
 
251
  input_, cross_attention_input = self.fuser(input_, condition_tensors)
252
 
253
- out = self.transformer(input_, cross_attention_src=cross_attention_input)
 
254
  if self.out_norm:
255
  out = self.out_norm(out)
256
  logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
@@ -264,7 +270,9 @@ class LMModel(StreamingModule):
264
  def compute_predictions(
265
  self, codes: torch.Tensor,
266
  conditions: tp.List[ConditioningAttributes],
267
- condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
 
 
268
  """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
269
  forward using the specified codes interleaving pattern.
270
 
@@ -276,6 +284,11 @@ class LMModel(StreamingModule):
276
  you should pre-compute those and pass them as `condition_tensors`.
277
  condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
278
  tensors, see `conditions`.
 
 
 
 
 
279
  Returns:
280
  LMOutput: Language model outputs
281
  logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
@@ -290,17 +303,18 @@ class LMModel(StreamingModule):
290
  # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
291
  pattern = self.pattern_provider.get_pattern(T)
292
  sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
293
- codes, self.special_token_id, keep_only_valid_steps=True
294
  )
 
295
  # apply model on pattern sequence
296
  model = self if self._fsdp is None else self._fsdp
297
- logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
298
  # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
299
  # and provide the corresponding mask over invalid positions of tokens
300
  logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
301
  # note: we use nans as special token to make it obvious if we feed unexpected logits
302
  logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
303
- logits, float('nan'), keep_only_valid_steps=True
304
  )
305
  logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
306
  logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
@@ -315,6 +329,7 @@ class LMModel(StreamingModule):
315
  top_k: int = 0,
316
  top_p: float = 0.0,
317
  cfg_coef: tp.Optional[float] = None,
 
318
  two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
319
  """Sample next token from the model given a sequence and a set of conditions. The model supports
320
  multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
@@ -330,6 +345,13 @@ class LMModel(StreamingModule):
330
  top_k (int): K for "top-k" sampling.
331
  top_p (float): P for "top-p" sampling.
332
  cfg_coef (float, optional): classifier free guidance coefficient
 
 
 
 
 
 
 
333
  Returns:
334
  next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
335
  """
@@ -337,7 +359,23 @@ class LMModel(StreamingModule):
337
  cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
338
  model = self if self._fsdp is None else self._fsdp
339
  two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
340
- if two_step_cfg and cfg_conditions != {}:
341
  assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
342
  condition_tensors, null_condition_tensors = cfg_conditions
343
  cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
@@ -390,23 +428,30 @@ class LMModel(StreamingModule):
390
  top_k: int = 250,
391
  top_p: float = 0.0,
392
  cfg_coef: tp.Optional[float] = None,
 
393
  two_step_cfg: tp.Optional[bool] = None,
394
  remove_prompts: bool = False,
395
  check: bool = False,
396
- callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
 
397
  """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
398
- be perform in a greedy fashion or using sampling with top K and top P strategies.
399
 
400
  Args:
401
  prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
402
- conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
403
  num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
404
  max_gen_len (int): Maximum generation length.
405
  use_sampling (bool): Whether to use a sampling strategy or not.
406
  temp (float): Sampling temperature.
407
  top_k (int): K for "top-k" sampling.
408
  top_p (float): P for "top-p" sampling.
409
- cfg_coeff (float, optional): Classifier-free guidance coefficient.
 
 
 
 
 
410
  two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
411
  remove_prompts (bool): Whether to remove prompts from generation or not.
412
  check (bool): Whether to apply further checks on generated sequence.
@@ -441,18 +486,27 @@ class LMModel(StreamingModule):
441
  # the padding structure is exactly the same between train and test.
442
  # With a batch size of 1, this can be slower though.
443
  cfg_conditions: CFGConditions
444
- two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
445
- if conditions:
446
- null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
447
- if two_step_cfg:
448
- cfg_conditions = (
449
- self.condition_provider(self.condition_provider.tokenize(conditions)),
450
- self.condition_provider(self.condition_provider.tokenize(null_conditions)),
451
- )
452
- else:
453
- conditions = conditions + null_conditions
454
  tokenized = self.condition_provider.tokenize(conditions)
455
 cfg_conditions = self.condition_provider(tokenized)
456
  else:
457
  cfg_conditions = {}
458
 
@@ -463,8 +517,8 @@ class LMModel(StreamingModule):
463
  B, K, T = prompt.shape
464
  start_offset = T
465
  print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
466
- assert start_offset <= max_gen_len
467
-
468
  pattern = self.pattern_provider.get_pattern(max_gen_len)
469
  # this token is used as default value for codes that are not generated yet
470
  unknown_token = -1
@@ -496,7 +550,7 @@ class LMModel(StreamingModule):
496
  # sample next token from the model, next token shape is [B, K, 1]
497
  next_token = self._sample_next_token(
498
  curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
499
- cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
500
  # ensure the tokens that should be masked are properly set to special_token_id
501
  # as the model never output special_token_id
502
  valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
 
23
  ConditioningProvider,
24
  ConditioningAttributes,
25
  ConditionType,
26
+ _drop_description_condition
27
  )
28
  from ..modules.codebooks_patterns import CodebooksPatternProvider
29
  from ..modules.activations import get_activation_fn
 
220
 
221
  def forward(self, sequence: torch.Tensor,
222
  conditions: tp.List[ConditioningAttributes],
223
+ condition_tensors: tp.Optional[ConditionTensors] = None,
224
+ stage: int = -1) -> torch.Tensor:
225
  """Apply language model on sequence and conditions.
226
  Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
227
  S the sequence steps, return the logits with shape [B, card, K, S].
 
233
  you should pre-compute those and pass them as `condition_tensors`.
234
  condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
235
  tensors, see `conditions`.
236
+ stage (int): The codebook level that is being predicted. Relevant for MAGNeT
237
+ in which prediction is done in a codebook-by-codebook manner.
238
+ Takes values in range(n_q), and ignored by default.
239
  Returns:
240
  torch.Tensor: Logits.
241
  """
 
255
 
256
  input_, cross_attention_input = self.fuser(input_, condition_tensors)
257
 
258
+ out = self.transformer(input_, cross_attention_src=cross_attention_input,
259
+ src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) # type: ignore
260
  if self.out_norm:
261
  out = self.out_norm(out)
262
  logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
 
270
  def compute_predictions(
271
  self, codes: torch.Tensor,
272
  conditions: tp.List[ConditioningAttributes],
273
+ condition_tensors: tp.Optional[ConditionTensors] = None,
274
+ stage: int = -1,
275
+ keep_only_valid_steps: bool = True) -> LMOutput:
276
  """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
277
  forward using the specified codes interleaving pattern.
278
 
 
284
  you should pre-compute those and pass them as `condition_tensors`.
285
  condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
286
  tensors, see `conditions`.
287
+ stage (int): The codebook level that is being predicted. Relevant for MAGNeT
288
+ in which prediction is done in a codebook-by-codebook manner.
289
+ Takes values in range(n_q), and ignored by default.
290
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
291
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
292
  Returns:
293
  LMOutput: Language model outputs
294
  logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
 
303
  # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
304
  pattern = self.pattern_provider.get_pattern(T)
305
  sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
306
+ codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
307
  )
308
+
309
  # apply model on pattern sequence
310
  model = self if self._fsdp is None else self._fsdp
311
+ logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card]
312
  # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
313
  # and provide the corresponding mask over invalid positions of tokens
314
  logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
315
  # note: we use nans as special token to make it obvious if we feed unexpected logits
316
  logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
317
+ logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
318
  )
319
  logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
320
  logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
 
329
  top_k: int = 0,
330
  top_p: float = 0.0,
331
  cfg_coef: tp.Optional[float] = None,
332
+ cfg_coef_beta: tp.Optional[float] = None,
333
  two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
334
  """Sample next token from the model given a sequence and a set of conditions. The model supports
335
  multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
 
345
  top_k (int): K for "top-k" sampling.
346
  top_p (float): P for "top-p" sampling.
347
  cfg_coef (float, optional): classifier free guidance coefficient
348
+ cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
349
+ If not None, we apply double classifier free guidance as introduced in MusicGen-Style
350
+ in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
351
+ push the text condition more than the style condition in the case where both text and style
352
+ conditions are being used.
353
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
354
+
355
  Returns:
356
  next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
357
  """
 
359
  cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
360
  model = self if self._fsdp is None else self._fsdp
361
  two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
362
+ if cfg_coef_beta is not None:
363
+ assert isinstance(cfg_conditions, dict)
364
+ condition_tensors = cfg_conditions
365
+ if condition_tensors:
366
+ # Preparing for CFG, predicting conditional text and style, conditional style
367
+ # and unconditional
368
+ sequence = torch.cat([sequence, sequence, sequence], dim=0)
369
+ all_logits = model(
370
+ sequence,
371
+ conditions=[], condition_tensors=condition_tensors)
372
+ if condition_tensors:
373
+ cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
374
+ logits = uncond_logits + cfg_coef * (
375
+ wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
376
+ )
377
+
378
+ elif two_step_cfg and cfg_conditions != {}:
379
  assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
380
  condition_tensors, null_condition_tensors = cfg_conditions
381
  cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
 
428
  top_k: int = 250,
429
  top_p: float = 0.0,
430
  cfg_coef: tp.Optional[float] = None,
431
+ cfg_coef_beta: tp.Optional[float] = None,
432
  two_step_cfg: tp.Optional[bool] = None,
433
  remove_prompts: bool = False,
434
  check: bool = False,
435
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
436
+ ) -> torch.Tensor:
437
  """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
438
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
439
 
440
  Args:
441
  prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
442
+ conditions (list of ConditioningAttributes, optional): List of conditions.
443
  num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
444
  max_gen_len (int): Maximum generation length.
445
  use_sampling (bool): Whether to use a sampling strategy or not.
446
  temp (float): Sampling temperature.
447
  top_k (int): K for "top-k" sampling.
448
  top_p (float): P for "top-p" sampling.
449
+ cfg_coef (float, optional): Classifier-free guidance coefficient.
450
+ cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
451
+ If not None, we apply double classifier free guidance as introduced in MusicGen-Style
452
+ in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
453
+ push the text condition more than the style condition in the case where both text and style
454
+ conditions are being used.
455
  two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
456
  remove_prompts (bool): Whether to remove prompts from generation or not.
457
  check (bool): Whether to apply further checks on generated sequence.
 
486
  # the padding structure is exactly the same between train and test.
487
  # With a batch size of 1, this can be slower though.
488
  cfg_conditions: CFGConditions
489
+ cfg_conditions = {}
490
+ if cfg_coef_beta is not None:
491
+ if conditions:
492
+ wav_conditions = _drop_description_condition(conditions)
493
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
494
+ conditions = conditions + wav_conditions + null_conditions
 
 
 
 
495
  tokenized = self.condition_provider.tokenize(conditions)
496
  cfg_conditions = self.condition_provider(tokenized)
497
+ elif conditions:
498
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
499
+ if conditions:
500
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
501
+ if two_step_cfg:
502
+ cfg_conditions = (
503
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
504
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
505
+ )
506
+ else:
507
+ conditions = conditions + null_conditions
508
+ tokenized = self.condition_provider.tokenize(conditions)
509
+ cfg_conditions = self.condition_provider(tokenized)
510
  else:
511
  cfg_conditions = {}
512
 
 
517
  B, K, T = prompt.shape
518
  start_offset = T
519
  print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
520
+ assert start_offset < max_gen_len
521
+
522
  pattern = self.pattern_provider.get_pattern(max_gen_len)
523
  # this token is used as default value for codes that are not generated yet
524
  unknown_token = -1
 
550
  # sample next token from the model, next token shape is [B, K, 1]
551
  next_token = self._sample_next_token(
552
  curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
553
+ cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg)
554
  # ensure the tokens that should be masked are properly set to special_token_id
555
  # as the model never output special_token_id
556
  valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
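For reference, a minimal sketch of the double classifier-free-guidance blend applied in _sample_next_token when cfg_coef_beta is set; random tensors stand in for the three logit batches split out of the single batched forward pass:

import torch

B, K, T, card = 1, 4, 8, 16
cond_logits = torch.randn(B, K, T, card)    # text + style conditioned
wav_logits = torch.randn(B, K, T, card)     # style only (description dropped)
uncond_logits = torch.randn(B, K, T, card)  # fully unconditional
cfg_coef, cfg_coef_beta = 3.0, 5.0

logits = uncond_logits + cfg_coef * (
    wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
)
# cfg_coef_beta > 1 pushes the prediction toward the text condition relative to the style
# condition; with cfg_coef_beta == 1 this reduces to standard CFG on the full condition.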
audiocraft/models/lm_magnet.py ADDED
@@ -0,0 +1,500 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import math
9
+ import typing as tp
10
+ import torch
11
+ import numpy as np
12
+
13
+ from ..utils import utils
14
+ from ..modules.conditioners import (
15
+ ClassifierFreeGuidanceDropout,
16
+ ConditioningAttributes,
17
+ ConditionType,
18
+ )
19
+ from .lm import LMModel
20
+
21
+ logger = logging.getLogger(__name__)
22
+ ConditionTensors = tp.Dict[str, ConditionType]
23
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
24
+
25
+
26
+ class MagnetLMModel(LMModel):
27
+ """Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).
28
+ Args:
29
+ subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
30
+ When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
31
+ compression_model_framerate (int): frame rate of the audio tokenizer.
32
+ segment_duration (int): Sample length in seconds.
33
+ span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
34
+ for both training and inference. Defaults to 3.
35
+ **kwargs: Additional parameters for the LMModel.
36
+ """
37
+ def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
38
+ segment_duration: int = 10, span_len: int = 3, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.causal = kwargs['causal']
41
+ self.subcodes_context = subcodes_context
42
+ self.span_len = span_len
43
+ self._build_attn_masks(compression_model_framerate=compression_model_framerate,
44
+ segment_duration=segment_duration,
45
+ num_heads=kwargs['num_heads'],
46
+ device=kwargs['device'], dtype=kwargs['dtype'])
47
+
48
+ def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
49
+ """Creates a restricted attention mask (local attention map) where the context
50
+ is determined by self.subcodes_context.
51
+ Args:
52
+ seq_len (int): token sequence length.
53
+ device (torch.device): device of the output tensor.
54
+ dtype (torch.dtype): data type of the output tensor.
55
+ Returns:
56
+ torch.Tensor: The restricted attention mask.
57
+ """
58
+ # Return a context restricted non-causal att mask
59
+ queries_pos = torch.arange(seq_len, device=device).view(-1, 1)
60
+ keys_pos = torch.arange(seq_len, device=device).view(1, -1)
61
+
62
+ delta = queries_pos - keys_pos
63
+ valid = torch.abs(delta) <= self.subcodes_context
64
+ return torch.where(
65
+ valid,
66
+ torch.zeros([], device=device, dtype=dtype),
67
+ torch.full([], float('-inf'), device=device, dtype=dtype))
68
+
69
+ def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
70
+ device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
71
+ """Creates a restricted attention mask given the stage (codebook index).
72
+ Args:
73
+ stage (int): The codebook index. Takes values in [0, n_q].
74
+ seq_len (int): Token sequence length.
75
+ num_heads (int): Num transformer attention heads.
76
+ device (torch.device): device of the output tensor.
77
+ dtype (torch.dtype): data type of the output tensor.
78
+ Returns:
79
+ torch.Tensor: Either a restricted attention mask or None if stage attention is unrestricted.
80
+ """
81
+ sa_mask = None
82
+
83
+ if stage > 0 and self.subcodes_context > -1:
84
+ # parallel - non-causal - with restricted subcodes context
85
+ sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
86
+
87
+ if sa_mask is not None:
88
+ # Repeat for each attention head
89
+ sa_mask = sa_mask.repeat((1, num_heads, 1, 1))
90
+
91
+ # align8 to enable memory efficient attention
92
+ MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
93
+ seq_len_aligned = \
94
+ int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR
95
+
96
+ sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
97
+ sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
98
+ sa_mask = sa_mask_aligned
99
+
100
+ return sa_mask
101
+
102
+ def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
103
+ device: torch.device, dtype: torch.dtype):
104
+ """Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
105
+ either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
106
+ Args:
107
+ compression_model_framerate (int): The frame rate of the tokenizer.
108
+ segment_duration (int): Sample length in seconds.
109
+ num_heads (int): Num transformer attention heads.
110
+ device (torch.device): device of the output tensor.
111
+ dtype (torch.dtype): data type of the output tensor.
112
+ """
113
+ seq_len = compression_model_framerate * segment_duration
114
+ self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
115
+ device, dtype) for stage in range(self.n_q)]
116
+
117
+ @torch.no_grad()
118
+ def generate(self,
119
+ prompt: tp.Optional[torch.Tensor] = None,
120
+ conditions: tp.List[ConditioningAttributes] = [],
121
+ num_samples: tp.Optional[int] = None,
122
+ max_gen_len: int = 256,
123
+ use_sampling: bool = True,
124
+ temp: float = 1.0,
125
+ top_k: int = 250,
126
+ top_p: float = 0.0,
127
+ cfg_coef: tp.Optional[float] = None,
128
+ cfg_coef_beta: tp.Optional[float] = None,
129
+ two_step_cfg: tp.Optional[bool] = None,
130
+ remove_prompts: bool = False,
131
+ check: bool = False,
132
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
133
+ **kwargs) -> torch.Tensor:
134
+
135
+ assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
136
+ assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
137
+ assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
138
+ assert check is False, "MAGNeT currently doesn't support the check arg."
139
+ assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg."
140
+ # Call the MAGNeT-specific generation method
141
+ return self._generate_magnet(prompt=prompt,
142
+ conditions=conditions,
143
+ num_samples=num_samples,
144
+ max_gen_len=max_gen_len,
145
+ use_sampling=use_sampling,
146
+ temp=temp,
147
+ top_k=top_k,
148
+ top_p=top_p,
149
+ callback=callback, **kwargs)
150
+
151
+ @torch.no_grad()
152
+ def _generate_magnet(self,
153
+ prompt: tp.Optional[torch.Tensor] = None,
154
+ conditions: tp.List[ConditioningAttributes] = [],
155
+ num_samples: tp.Optional[int] = None,
156
+ max_gen_len: int = 256,
157
+ use_sampling: bool = True,
158
+ temp: float = 3.0,
159
+ top_k: int = 0,
160
+ top_p: float = 0.9,
161
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
162
+ max_cfg_coef: float = 10.0,
163
+ min_cfg_coef: float = 1.0,
164
+ decoding_steps: tp.List[int] = [20, 10, 10, 10],
165
+ anneal_temp: bool = True,
166
+ span_scoring='max',
167
+ span_arrangement='nonoverlap') -> torch.Tensor:
168
+ """Generate audio tokens given textual conditions, and optionally given audio prompts,
169
+ by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.
170
+ Args:
171
+ prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
172
+ conditions (list of ConditioningAttributes): List of conditions.
173
+ num_samples (int): Number of samples to generate when no prompt and no conditions are given.
174
+ max_gen_len (int): Maximum generation length.
175
+ use_sampling (bool): Whether to use a sampling strategy or not.
176
+ temp (float): Initial sampling temperature.
177
+ top_k (int): k for "top-k" sampling.
178
+ top_p (float): p for "top-p" sampling.
179
+ callback (Callback): Callback function to report generation progress.
180
+ max_cfg_coef (float): Initial coefficient used for classifier free guidance.
181
+ min_cfg_coef (float): Final coefficient used for classifier free guidance.
182
+ decoding_steps (list of n_q ints): The number of iterative decoding steps,
183
+ for each of the n_q RVQ codebooks.
184
+ anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
185
+ span_scoring (str): Use the maximum probability of each span ('max')
186
+ or the product of probabilities ('prod').
187
+ span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
188
+ in the masking scheme.
189
+ Returns:
190
+ torch.Tensor: Generated tokens.
191
+ """
192
+ assert not self.training, "generation shouldn't be used in training mode."
193
+ first_param = next(iter(self.parameters()))
194
+ device = first_param.device
195
+
196
+ # Checking all input shapes are consistent.
197
+ possible_num_samples = []
198
+ if num_samples is not None:
199
+ possible_num_samples.append(num_samples)
200
+ elif prompt is not None:
201
+ possible_num_samples.append(prompt.shape[0])
202
+ elif conditions:
203
+ possible_num_samples.append(len(conditions))
204
+ else:
205
+ possible_num_samples.append(1)
206
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
207
+ num_samples = possible_num_samples[0]
208
+
209
+ # below we create set of conditions: one conditional and one unconditional
210
+ # to do that we merge the regular condition together with the null condition
211
+ # we then do 1 forward pass instead of 2.
212
+ cfg_conditions: tp.Optional[ConditionTensors]
213
+ if conditions:
214
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
215
+ conditions = conditions + null_conditions
216
+ tokenized = self.condition_provider.tokenize(conditions)
217
+ cfg_conditions = self.condition_provider(tokenized)
218
+ else:
219
+ cfg_conditions = {}
220
+
221
+ if prompt is None:
222
+ assert num_samples > 0
223
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
224
+
225
+ B, K, prompt_length = prompt.shape
226
+ start_offset = prompt_length
227
+ assert start_offset < max_gen_len
228
+
229
+ mask_id = self.special_token_id
230
+
231
+ # we generate codes with a fixed sequence length
232
+ shape = (B, K, max_gen_len)
233
+
234
+ gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
235
+ # filling the gen_codes with the prompt if needed
236
+ gen_codes[..., :start_offset] = prompt
237
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
238
+ gen_sequence = gen_codes
239
+
240
+ curr_step = 0
241
+ for stage, n_steps in zip(range(self.n_q), decoding_steps):
242
+ gen_sequence, curr_step = self._generate_stage(gen_sequence,
243
+ cfg_conditions,
244
+ stage=stage,
245
+ device=device,
246
+ prompt_length=prompt_length,
247
+ prompt=prompt,
248
+ temp=temp,
249
+ max_cfg_coef=max_cfg_coef,
250
+ min_cfg_coef=min_cfg_coef,
251
+ top_k=top_k,
252
+ top_p=top_p,
253
+ timesteps=n_steps,
254
+ anneal_temp=anneal_temp,
255
+ span_scoring=span_scoring,
256
+ use_sampling=use_sampling,
257
+ span_arrangement=span_arrangement,
258
+ curr_step=curr_step,
259
+ total_steps=sum(decoding_steps),
260
+ callback=callback)
261
+
262
+ return gen_sequence
263
+
264
+ @torch.no_grad()
265
+ def _generate_stage(self,
266
+ gen_sequence: torch.Tensor,
267
+ condition_tensors: tp.Optional[ConditionTensors],
268
+ stage: int,
269
+ device: torch.device,
270
+ prompt_length: int = 0,
271
+ prompt: tp.Optional[torch.Tensor] = None,
272
+ use_sampling: bool = True,
273
+ temp: float = 3.0,
274
+ max_cfg_coef: float = 10.0,
275
+ min_cfg_coef: float = 1.0,
276
+ top_k: int = 0,
277
+ top_p: float = 0.0,
278
+ timesteps: int = 10,
279
+ anneal_temp: bool = True,
280
+ span_scoring: str = 'max',
281
+ span_arrangement: str = 'nonoverlap',
282
+ curr_step: int = 0,
283
+ total_steps: int = 0,
284
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
285
+ """Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
286
+ and the textual conditions.
287
+ Args:
288
+ gen_sequence (torch.Tensor): Previously generated tokens.
289
+ condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors.
290
+ stage (int): RVQ level to generate.
291
+ device (torch.device): device of the output tensor.
292
+ prompt_length (int): Temporal length of the audio prompt.
293
+ prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
294
+ use_sampling (bool): Whether to use a sampling strategy or not.
295
+ temp (float): Initial sampling temperature.
296
+ max_cfg_coef (float): Initial coefficient used for classifier free guidance.
297
+ min_cfg_coef (float): Final coefficient used for classifier free guidance.
298
+ top_k (int): k for "top-k" sampling.
299
+ top_p (float): p for "top-p" sampling.
300
+ timesteps (int): Number of iterative decoding steps.
301
+ anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
302
+ span_scoring (str): Use the maximum probability of each span ('max')
303
+ or the product of probabilities ('prod').
304
+ span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
305
+ in the masking scheme.
306
+ curr_step (int): Global iterative decoding step counter.
307
+ total_steps (int): Total decoding steps.
308
+ callback (Callback): Callback function to report generation progress.
309
+ Returns:
310
+ tuple(torch.Tensor, int): Generated tokens and the current decoding step counter.
311
+ """
312
+ B, K, T = gen_sequence.shape
313
+ shape = (B, 1, T) # generating a single codebook per stage
314
+
315
+ mask_id = self.special_token_id
316
+ stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)
317
+
318
+ assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1'
319
+ chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
320
+
321
+ DONT_REMASK_ME_SCORE = -1e4
322
+
323
+ model = self if self._fsdp is None else self._fsdp
324
+
325
+ if chunk_masking:
326
+ # span-wise scores
327
+ n_chunks = T // self.span_len
328
+ if T % self.span_len != 0:
329
+ # trim sequence ending to achieve a multiple of span_len
330
+ T = self.span_len * n_chunks
331
+ gen_sequence = gen_sequence[..., :T]
332
+ stage_gen_seq = stage_gen_seq[..., :T]
333
+
334
+ chunked_shape = (B, 1, n_chunks)
335
+ n_prompt_chunks = prompt_length // self.span_len
336
+ scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
337
+ scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
338
+ num_chunks_to_gen = n_chunks - n_prompt_chunks
339
+ else:
340
+ # token-wise scores
341
+ scores = torch.zeros(shape, dtype=torch.float32, device=device)
342
+ scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
343
+ gen_T = T - prompt_length
344
+
345
+ # run MAGNeT iterative decoding for "timesteps" iterations
346
+ for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
347
+
348
+ mask_p = torch.cos(timestep * math.pi * 0.5)
349
+
350
+ if chunk_masking:
351
+ num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
352
+ else:
353
+ num_masked = max(int((mask_p * gen_T).item()), 1)
354
+
355
+ # masking
356
+ run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1
357
+ if run_lps_masking:
358
+ # masking of the k least probable overlapping (stride 1) spans
359
+ mask = torch.concat((
360
+ [self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
361
+ for i in range(B)]), dim=0)
362
+ stage_gen_seq[mask] = mask_id
363
+ else:
364
+ # masking of the k least probable non-overlapping spans
365
+ masked = scores.topk(num_masked, dim=-1).indices
366
+ if chunk_masking:
367
+ chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
368
+ chunks_mask = chunks_mask.scatter(2, masked, True)
369
+ mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
370
+ stage_gen_seq[mask] = mask_id
371
+ else:
372
+ stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)
373
+
374
+ if prompt is not None:
375
+ stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)
376
+
377
+ gen_sequence[:, [stage], :] = stage_gen_seq
378
+ if condition_tensors:
379
+ # duplicate input for classifier free guidance
380
+ sequence = torch.cat([gen_sequence, gen_sequence], dim=0)
381
+
382
+ all_logits = model(sequence, [], condition_tensors, stage=stage)
383
+
384
+ if condition_tensors:
385
+ # classifier free guidance with annealing
386
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
387
+ clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
388
+ logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
389
+ else:
390
+ logits = all_logits
391
+
392
+ # temperature annealing - linear
393
+ t = temp * (steps_left / timesteps) if anneal_temp else temp
394
+
395
+ # sampling
396
+ logits = logits[:, stage, :, :].unsqueeze(1)
397
+ probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
398
+ if use_sampling:
399
+ if top_p > 0.0:
400
+ sampled_tokens = utils.sample_top_p(probs, p=top_p)
401
+ elif top_k > 0:
402
+ sampled_tokens = utils.sample_top_k(probs, k=top_k)
403
+ else:
404
+ sampled_tokens = utils.multinomial(probs, num_samples=1)
405
+ else:
406
+ sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)
407
+
408
+ # place mask_id token in each of the masked positions
409
+ mask = stage_gen_seq == mask_id
410
+ stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
411
+ gen_sequence[:, [stage], :] = stage_gen_seq
412
+
413
+ # get probs of sampled tokens
414
+ sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]
415
+
416
+ # span scoring
417
+ if chunk_masking:
418
+ if span_scoring == 'max':
419
+ # max in linear space
420
+ scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
421
+ elif span_scoring == 'prod':
422
+ # prod in log space
423
+ scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
424
+ else:
425
+ raise NotImplementedError
426
+ else:
427
+ # prod in log space for lps masking (stride1)
428
+ scores = -torch.log(sampled_probs)
429
+
430
+ # Fix unmasked tokens by placing inf probs (-inf scores)
431
+ if chunk_masking:
432
+ scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
433
+ else:
434
+ scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)
435
+
436
+ if callback is not None:
437
+ curr_step += 1
438
+ callback(curr_step, total_steps)
439
+
440
+ return gen_sequence, curr_step
441
+
442
+ def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
443
+ """Build a [1x1xT] boolean mask consists of overlapping spans of True values, where
444
+ span_starts defines the initial index of each span, and the span length is
445
+ defined by self.span_len.
446
+ Args:
447
+ span_starts (torch.Tensor): Boolean mask determines the temporal location of each span start.
448
+ T (int): Sequence length.
449
+ device (torch.device): device of the output tensor.
450
+ Returns:
451
+ torch.Tensor: Spans mask of shape [1x1xT]
452
+ """
453
+ mask = torch.full((1, 1, T), False, device=device)
454
+ mask[:, :, span_starts] = True
455
+ shifted_mask = mask.clone()
456
+ for _ in range(self.span_len - 1):
457
+ shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1)
458
+ mask = torch.logical_or(mask, shifted_mask)
459
+ return mask
460
+
461
+ def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
462
+ """Construct a [1x1xT] boolean mask, consists of the u least probable spans,
463
+ where the token probability is determined by -scores, and the total
464
+ number of masked tokens is as closest as possible to num_masked_trg.
465
+ Find u using binary search.
466
+ Args:
467
+ scores (torch.Tensor): Per token score [-log(prob)]
468
+ num_masked_trg (int): The desired number of tokens to be masked.
469
+ Returns:
470
+ torch.Tensor: Spans mask of shape [1x1xT]
471
+ """
472
+ T = scores.shape[-1]
473
+ device = scores.device
474
+ scores_unfolded = scores.unfold(2, self.span_len, 1)
475
+ # Span score is the product of probs (sum in log space)
476
+ span_scores = scores_unfolded.sum(dim=-1)
477
+ spans_by_scores = torch.argsort(span_scores[0, 0], descending=True)
478
+
479
+ num_masked_trg = max(num_masked_trg, self.span_len)
480
+
481
+ # Binary search for u - the number of least probable overlapping masked spans s.t.
482
+ # the total masking rate is the closest to num_masked_trg / T.
483
+ min_u = num_masked_trg // self.span_len
484
+ max_u = num_masked_trg - self.span_len + 1
485
+ mid = round(0.5 * (min_u + max_u))
486
+
487
+ if mid == min_u or mid == max_u:
488
+ return self._construct_spans_mask(spans_by_scores[:mid], T, device)
489
+
490
+ while mid > min_u and mid < max_u:
491
+ mask = self._construct_spans_mask(spans_by_scores[:mid], T, device)
492
+ n_masked = mask.sum()
493
+ if n_masked > num_masked_trg:
494
+ max_u = mid
495
+ mid = round(0.5 * (min_u + max_u))
496
+ else:
497
+ min_u = mid
498
+ mid = round(0.5 * (min_u + max_u))
499
+
500
+ return mask
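For reference, the decoding step above can be summarized outside the class. The following is a minimal, illustrative sketch (not part of the commit) of the annealed classifier-free guidance and temperature schedule, run on dummy logits; the function name and shapes are assumptions.

import torch

def annealed_cfg_step(cond_logits, uncond_logits, mask_p, steps_left, timesteps,
                      max_cfg_coef=10.0, min_cfg_coef=1.0, temp=3.0):
    # guidance coefficient decays from max_cfg_coef to min_cfg_coef as the masking rate drops
    cfg = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
    logits = uncond_logits + (cond_logits - uncond_logits) * cfg
    # linear temperature annealing toward zero as decoding progresses
    t = temp * (steps_left / timesteps)
    return torch.softmax(logits / max(t, 1e-2), dim=-1)

# dummy [B, K, T, card] logits: batch 1, 4 codebooks, 8 steps, cardinality 2048
probs = annealed_cfg_step(torch.randn(1, 4, 8, 2048), torch.randn(1, 4, 8, 2048),
                          mask_p=0.7, steps_left=12, timesteps=20)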
audiocraft/models/loaders.py CHANGED
@@ -28,6 +28,7 @@ from omegaconf import OmegaConf, DictConfig
28
  import torch
29
 
30
  import audiocraft
 
31
  from . import builders
32
  from .encodec import CompressionModel
33
 
@@ -47,6 +48,7 @@ HF_MODEL_CHECKPOINTS_MAP = {
47
  "stereo-large": "facebook/musicgen-stereo-large",
48
  "stereo-melody": "facebook/musicgen-stereo-melody",
49
  "stereo-melody-large": "facebook/musicgen-stereo-melody-large",
 
50
  }
51
 
52
 
@@ -156,7 +158,7 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
156
  # Handle newer model formats that might not have xp.cfg
157
  if 'xp.cfg' not in pkg:
158
  if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
159
- 'stereo-small', 'stereo-large', 'stereo-melody-large']:
160
  print(f"Using fallback configuration for {file_or_url_or_id}")
161
  # Create a default configuration based on the model type
162
  # This is where you'd need to add model-specific configurations
@@ -212,6 +214,52 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
212
  return model
213
 
214
 
215
  def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
216
  filename: tp.Optional[str] = None,
217
  cache_dir: tp.Optional[str] = None):
 
28
  import torch
29
 
30
  import audiocraft
31
+
32
  from . import builders
33
  from .encodec import CompressionModel
34
 
 
48
  "stereo-large": "facebook/musicgen-stereo-large",
49
  "stereo-melody": "facebook/musicgen-stereo-melody",
50
  "stereo-melody-large": "facebook/musicgen-stereo-melody-large",
51
+ "style": "facebook/musicgen-style",
52
  }
53
 
54
 
 
158
  # Handle newer model formats that might not have xp.cfg
159
  if 'xp.cfg' not in pkg:
160
  if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
161
+ 'stereo-small', 'stereo-large', 'stereo-melody-large', 'style']:
162
  print(f"Using fallback configuration for {file_or_url_or_id}")
163
  # Create a default configuration based on the model type
164
  # This is where you'd need to add model-specific configurations
 
214
  return model
215
 
216
 
217
+ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
218
+ device='cpu', cache_dir: tp.Optional[str] = None):
219
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
220
+ cfg = OmegaConf.create(pkg['xp.cfg'])
221
+ cfg.device = str(device)
222
+ if cfg.device == 'cpu':
223
+ cfg.dtype = 'float32'
224
+ else:
225
+ cfg.dtype = 'float16'
226
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
227
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
228
+
229
+ cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
230
+ cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
231
+ cfg.transformer_lm.span_len = cfg.masking.span_len
232
+
233
+ # MAGNeT models v1 support only xformers backend.
234
+ from audiocraft.modules.transformer import set_efficient_attention_backend
235
+
236
+ if cfg.transformer_lm.memory_efficient:
237
+ set_efficient_attention_backend("xformers")
238
+
239
+ model = builders.get_lm_model(cfg)
240
+ model.load_state_dict(pkg['best_state'])
241
+ model.eval()
242
+ model.cfg = cfg
243
+ return model
244
+
245
+
246
+ def load_jasco_model(file_or_url_or_id: tp.Union[Path, str],
247
+ compression_model: CompressionModel,
248
+ device='cpu', cache_dir: tp.Optional[str] = None):
249
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
250
+ cfg = OmegaConf.create(pkg['xp.cfg'])
251
+ cfg.device = str(device)
252
+ if cfg.device == 'cpu':
253
+ cfg.dtype = 'float32'
254
+ else:
255
+ cfg.dtype = 'float16'
256
+ model = builders.get_jasco_model(cfg, compression_model)
257
+ model.load_state_dict(pkg['best_state'])
258
+ model.eval()
259
+ model.cfg = cfg
260
+ return model
261
+
262
+
263
  def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
264
  filename: tp.Optional[str] = None,
265
  cache_dir: tp.Optional[str] = None):
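A quick, hedged usage sketch (not part of the diff) of the new "style" entry, assuming the short name is resolved through HF_MODEL_CHECKPOINTS_MAP as for the other checkpoints:

from audiocraft.models.loaders import HF_MODEL_CHECKPOINTS_MAP
from audiocraft.models.musicgen import MusicGen

assert HF_MODEL_CHECKPOINTS_MAP["style"] == "facebook/musicgen-style"
model = MusicGen.get_pretrained("style")  # downloads and loads facebook/musicgen-style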
audiocraft/models/magnet.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using MAGNeT. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+ import typing as tp
12
+ import torch
13
+
14
+ from .genmodel import BaseGenModel
15
+ from .loaders import load_compression_model, load_lm_model_magnet
16
+
17
+
18
+ class MAGNeT(BaseGenModel):
19
+ """MAGNeT main model with convenient generation API.
20
+ Args:
21
+ See MusicGen class.
22
+ """
23
+ def __init__(self, **kwargs):
24
+ super().__init__(**kwargs)
25
+ # MAGNeT operates over a fixed sequence length defined in its config.
26
+ self.duration = self.lm.cfg.dataset.segment_duration
27
+ self.set_generation_params()
28
+
29
+ @staticmethod
30
+ def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
31
+ """Return pretrained model, we provide six models:
32
+ - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
33
+ # see: https://huggingface.co/facebook/magnet-small-10secs
34
+ - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
35
+ # see: https://huggingface.co/facebook/magnet-medium-10secs
36
+ - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
37
+ # see: https://huggingface.co/facebook/magnet-small-30secs
38
+ - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
39
+ # see: https://huggingface.co/facebook/magnet-medium-30secs
40
+ - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
41
+ # see: https://huggingface.co/facebook/audio-magnet-small
42
+ - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
43
+ # see: https://huggingface.co/facebook/audio-magnet-medium
44
+ """
45
+ if device is None:
46
+ if torch.cuda.device_count():
47
+ device = 'cuda'
48
+ else:
49
+ device = 'cpu'
50
+
51
+ compression_model = load_compression_model(name, device=device)
52
+ lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)
53
+
54
+ if 'self_wav' in lm.condition_provider.conditioners:
55
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
56
+
57
+ kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
58
+ return MAGNeT(**kwargs)
59
+
60
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
61
+ top_p: float = 0.9, temperature: float = 3.0,
62
+ max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
63
+ decoding_steps: tp.List[int] = [20, 10, 10, 10],
64
+ span_arrangement: str = 'nonoverlap'):
65
+ """Set the generation parameters for MAGNeT.
66
+
67
+ Args:
68
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
69
+ top_k (int, optional): top_k used for sampling. Defaults to 0.
70
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
71
+ temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
72
+ max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
73
+ min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
74
+ decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
75
+ for each of the n_q RVQ codebooks.
76
+ span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
77
+ or overlapping spans ('stride1') in the masking scheme.
78
+ """
79
+ self.generation_params = {
80
+ 'use_sampling': use_sampling,
81
+ 'temp': temperature,
82
+ 'top_k': top_k,
83
+ 'top_p': top_p,
84
+ 'max_cfg_coef': max_cfg_coef,
85
+ 'min_cfg_coef': min_cfg_coef,
86
+ 'decoding_steps': [int(s) for s in decoding_steps],
87
+ 'span_arrangement': span_arrangement
88
+ }
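A minimal usage sketch for the new MAGNeT wrapper (illustrative only; it assumes generate() and sample_rate are provided by BaseGenModel, which is not shown in this diff):

import torchaudio
from audiocraft.models.magnet import MAGNeT

model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')
model.set_generation_params(top_p=0.9, temperature=3.0, decoding_steps=[20, 10, 10, 10])
wav = model.generate(['industrial electronic soundtrack, dark and intense'])  # [B, C, T]
torchaudio.save('magnet_sample.wav', wav[0].cpu(), model.sample_rate)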
audiocraft/models/musicgen.py CHANGED
@@ -18,11 +18,12 @@ import torch
18
  import gradio as gr
19
 
20
  from .encodec import CompressionModel
 
21
  from .lm import LMModel
22
  from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
23
  from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
24
  from ..data.audio_utils import convert_audio
25
- from ..modules.conditioners import ConditioningAttributes, WavCondition
26
  from ..utils.autocast import TorchAutocast
27
 
28
  MelodyList = tp.List[tp.Optional[torch.Tensor]]
@@ -108,6 +109,7 @@ class MusicGen:
108
  - stereo-melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-stereo-melody
109
  - stereo-large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-large
110
  - stereo-melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-stereo-melody-large
 
111
  """
112
 
113
  if device is None:
@@ -120,7 +122,7 @@ class MusicGen:
120
  # used only for unit tests
121
  compression_model = get_debug_compression_model(device)
122
  lm = get_debug_lm_model(device)
123
- return MusicGen(name, compression_model, lm)
124
 
125
  if name not in HF_MODEL_CHECKPOINTS_MAP:
126
  if not os.path.isfile(name) and not os.path.isdir(name):
@@ -143,6 +145,7 @@ class MusicGen:
143
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
144
  top_p: float = 0.0, temperature: float = 1.0,
145
  duration: float = 30.0, cfg_coef: float = 3.0,
 
146
  two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
147
  """Set the generation parameters for MusicGen.
148
 
@@ -153,6 +156,10 @@ class MusicGen:
153
  temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
154
  duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
155
  cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
 
 
 
 
156
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
157
  instead of batching together the two. This has some impact on how things
158
  are padded but seems to have little impact in practice.
@@ -172,8 +179,30 @@ class MusicGen:
172
  'top_p': top_p,
173
  'cfg_coef': cfg_coef,
174
  'two_step_cfg': two_step_cfg,
 
175
  }
176
 
177
  def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
178
  """Override the default progress callback."""
179
  self._progress_callback = progress_callback
@@ -399,8 +428,8 @@ class MusicGen:
399
  """Generate discrete audio tokens given audio prompt and/or conditions.
400
 
401
  Args:
402
- attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
403
- prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
404
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
405
  Returns:
406
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
 
18
  import gradio as gr
19
 
20
  from .encodec import CompressionModel
21
+ from .genmodel import BaseGenModel
22
  from .lm import LMModel
23
  from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
24
  from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
25
  from ..data.audio_utils import convert_audio
26
+ from ..modules.conditioners import ConditioningAttributes, WavCondition, StyleConditioner
27
  from ..utils.autocast import TorchAutocast
28
 
29
  MelodyList = tp.List[tp.Optional[torch.Tensor]]
 
109
  - stereo-melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-stereo-melody
110
  - stereo-large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-large
111
  - stereo-melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-stereo-melody-large
112
+ - musicgen-style (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-style
113
  """
114
 
115
  if device is None:
 
122
  # used only for unit tests
123
  compression_model = get_debug_compression_model(device)
124
  lm = get_debug_lm_model(device)
125
+ return MusicGen(name, compression_model, lm, max_duration=30)
126
 
127
  if name not in HF_MODEL_CHECKPOINTS_MAP:
128
  if not os.path.isfile(name) and not os.path.isdir(name):
 
145
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
146
  top_p: float = 0.0, temperature: float = 1.0,
147
  duration: float = 30.0, cfg_coef: float = 3.0,
148
+ cfg_coef_beta: tp.Optional[float] = None,
149
  two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
150
  """Set the generation parameters for MusicGen.
151
 
 
156
  temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
157
  duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
158
  cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
159
+ cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance.
160
+ Should only be used for MusicGen melody if we want to push the text condition more than
161
+ the audio conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
162
+ double CFG.
163
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
164
  instead of batching together the two. This has some impact on how things
165
  are padded but seems to have little impact in practice.
 
179
  'top_p': top_p,
180
  'cfg_coef': cfg_coef,
181
  'two_step_cfg': two_step_cfg,
182
+ 'cfg_coef_beta': cfg_coef_beta,
183
  }
184
 
185
+ def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0,
186
+ ds_factor: tp.Optional[int] = None,
187
+ encodec_n_q: tp.Optional[int] = None) -> None:
188
+ """Set the parameters of the style conditioner
189
+ Args:
190
+ eval_q (int): the number of residual quantization streams used to quantize the style condition
191
+ the smaller it is, the narrower is the information bottleneck
192
+ excerpt_length (float): the excerpt length in seconds that is extracted from the audio
193
+ conditioning
194
+ ds_factor: (int): the downsampling factor used to downsample the style tokens before
195
+ using them as a prefix
196
+ encodec_n_q: (int, optional): if encodec is used as a feature extractor, sets the number
197
+ of streams that is used to extract features
198
+ """
199
+ assert isinstance(self.lm.condition_provider.conditioners.self_wav, StyleConditioner), \
200
+ "Only use this function if you model is MusicGen-Style"
201
+ self.lm.condition_provider.conditioners.self_wav.set_params(eval_q=eval_q,
202
+ excerpt_length=excerpt_length,
203
+ ds_factor=ds_factor,
204
+ encodec_n_q=encodec_n_q)
205
+
206
  def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
207
  """Override the default progress callback."""
208
  self._progress_callback = progress_callback
 
428
  """Generate discrete audio tokens given audio prompt and/or conditions.
429
 
430
  Args:
431
+ attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
432
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
433
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
434
  Returns:
435
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
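A hedged sketch of how the new style parameters are meant to be combined (not part of the diff; the call that feeds the style excerpt is assumed to follow the existing melody-conditioning path and may differ in this app):

import torchaudio
from audiocraft.models.musicgen import MusicGen

model = MusicGen.get_pretrained('style')                   # facebook/musicgen-style
model.set_generation_params(duration=10, cfg_coef=3.0, cfg_coef_beta=5.0)
model.set_style_conditioner_params(eval_q=1, excerpt_length=3.0)
style_wav, sr = torchaudio.load('./assets/bach.mp3')       # audio excerpt used as style
wav = model.generate_with_chroma(['80s driving pop with heavy drums'], style_wav[None], sr)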
audiocraft/modules/codebooks_patterns.py CHANGED
@@ -30,7 +30,7 @@ class Pattern:
30
 
31
  The pattern provides convenient methods to build and revert interleaved sequences from it:
32
  ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
33
- to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
34
  K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
35
  for the output sequence. The unfilled positions are replaced with a special token and the built sequence
36
  is returned along with a mask indicating valid tokens.
@@ -49,7 +49,6 @@ class Pattern:
49
 
50
  def __post_init__(self):
51
  assert len(self.layout) > 0
52
- assert self.layout[0] == []
53
  self._validate_layout()
54
  self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
55
  self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
@@ -93,6 +92,9 @@ class Pattern:
93
  valid_step = len(self.layout) - self.max_delay
94
  return self.layout[:valid_step]
95
 
 
 
 
96
  def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
97
  """Get codebook coordinates in the layout that corresponds to the specified timestep t
98
  and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
@@ -202,7 +204,7 @@ class Pattern:
202
  f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
203
 
204
  # ensure we take the appropriate indexes to keep the model output from the first special token as well
205
- if is_model_output:
206
  ref_layout = ref_layout[1:]
207
 
208
  # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
@@ -335,7 +337,8 @@ class DelayedPatternProvider(CodebooksPatternProvider):
335
  assert sorted(self.delays) == self.delays
336
 
337
  def get_pattern(self, timesteps: int) -> Pattern:
338
- out: PatternLayout = [[]]
 
339
  max_delay = max(self.delays)
340
  if self.empty_initial:
341
  out += [[] for _ in range(self.empty_initial)]
@@ -360,9 +363,10 @@ class ParallelPatternProvider(DelayedPatternProvider):
360
 
361
  Args:
362
  n_q (int): Number of codebooks.
 
363
  """
364
- def __init__(self, n_q: int):
365
- super().__init__(n_q, [0] * n_q)
366
 
367
 
368
  class UnrolledPatternProvider(CodebooksPatternProvider):
 
30
 
31
  The pattern provides convenient methods to build and revert interleaved sequences from it:
32
  ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
33
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
34
  K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
35
  for the output sequence. The unfilled positions are replaced with a special token and the built sequence
36
  is returned along with a mask indicating valid tokens.
 
49
 
50
  def __post_init__(self):
51
  assert len(self.layout) > 0
 
52
  self._validate_layout()
53
  self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
54
  self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
 
92
  valid_step = len(self.layout) - self.max_delay
93
  return self.layout[:valid_step]
94
 
95
+ def starts_with_special_token(self):
96
+ return self.layout[0] == []
97
+
98
  def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
99
  """Get codebook coordinates in the layout that corresponds to the specified timestep t
100
  and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
 
204
  f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
205
 
206
  # ensure we take the appropriate indexes to keep the model output from the first special token as well
207
+ if is_model_output and self.starts_with_special_token():
208
  ref_layout = ref_layout[1:]
209
 
210
  # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
 
337
  assert sorted(self.delays) == self.delays
338
 
339
  def get_pattern(self, timesteps: int) -> Pattern:
340
+ omit_special_token = self.empty_initial < 0
341
+ out: PatternLayout = [] if omit_special_token else [[]]
342
  max_delay = max(self.delays)
343
  if self.empty_initial:
344
  out += [[] for _ in range(self.empty_initial)]
 
363
 
364
  Args:
365
  n_q (int): Number of codebooks.
366
+ empty_initial (int): Prepend N empty lists of coordinates.
367
  """
368
+ def __init__(self, n_q: int, empty_initial: int = 0):
369
+ super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
370
 
371
 
372
  class UnrolledPatternProvider(CodebooksPatternProvider):
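A small sketch (assumptions, not part of the diff) of what the empty_initial change means for the produced layouts:

from audiocraft.modules.codebooks_patterns import DelayedPatternProvider, ParallelPatternProvider

delayed = DelayedPatternProvider(n_q=4, delays=[0, 1, 2, 3]).get_pattern(timesteps=8)
print(delayed.starts_with_special_token())   # True: layout still begins with an empty step

parallel = ParallelPatternProvider(n_q=4, empty_initial=-1).get_pattern(timesteps=8)
print(parallel.starts_with_special_token())  # False: a negative empty_initial omits the special token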
audiocraft/modules/conditioners.py CHANGED
@@ -15,8 +15,8 @@ import random
15
  import re
16
  import typing as tp
17
  import warnings
18
-
19
  import einops
 
20
  from num2words import num2words
21
  import spacy
22
  from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
@@ -24,10 +24,10 @@ import torch
24
  from torch import nn
25
  import torch.nn.functional as F
26
  from torch.nn.utils.rnn import pad_sequence
27
-
28
  from .chroma import ChromaExtractor
29
  from .streaming import StreamingModule
30
- from .transformer import create_sin_embedding
31
  from ..data.audio import audio_read
32
  from ..data.audio_dataset import SegmentInfo
33
  from ..data.audio_utils import convert_audio
@@ -43,6 +43,15 @@ TextCondition = tp.Optional[str] # a text condition can be a string or None (if
43
  ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
44
 
45
 
46
  class WavCondition(tp.NamedTuple):
47
  wav: torch.Tensor
48
  length: torch.Tensor
@@ -60,11 +69,17 @@ class JointEmbedCondition(tp.NamedTuple):
60
  seek_time: tp.List[tp.Optional[float]] = []
61
 
62
 
 
 
 
 
 
63
  @dataclass
64
  class ConditioningAttributes:
65
  text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
66
  wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
67
  joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
 
68
 
69
  def __getitem__(self, item):
70
  return getattr(self, item)
@@ -81,19 +96,25 @@ class ConditioningAttributes:
81
  def joint_embed_attributes(self):
82
  return self.joint_embed.keys()
83
 
 
 
 
 
84
  @property
85
  def attributes(self):
86
  return {
87
  "text": self.text_attributes,
88
  "wav": self.wav_attributes,
89
  "joint_embed": self.joint_embed_attributes,
 
90
  }
91
 
92
  def to_flat_dict(self):
93
  return {
94
  **{f"text.{k}": v for k, v in self.text.items()},
95
  **{f"wav.{k}": v for k, v in self.wav.items()},
96
- **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
 
97
  }
98
 
99
  @classmethod
@@ -177,6 +198,44 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
177
  )
178
 
179
 
180
  class Tokenizer:
181
  """Base tokenizer implementation
182
  (in case we want to introduce more advanced tokenizers in the future).
@@ -297,7 +356,8 @@ class BaseConditioner(nn.Module):
297
  super().__init__()
298
  self.dim = dim
299
  self.output_dim = output_dim
300
- self.output_proj = nn.Linear(dim, output_dim)
 
301
 
302
  def tokenize(self, *args, **kwargs) -> tp.Any:
303
  """Should be any part of the processing that will lead to a synchronization
@@ -495,8 +555,9 @@ class WaveformConditioner(BaseConditioner):
495
  wav, lengths, *_ = x
496
  with torch.no_grad():
497
  embeds = self._get_wav_embedding(x)
498
- embeds = embeds.to(self.output_proj.weight)
499
- embeds = self.output_proj(embeds)
 
500
 
501
  if lengths is not None and self._use_masking:
502
  lengths = lengths / self._downsampling_factor()
@@ -607,7 +668,7 @@ class ChromaStemConditioner(WaveformConditioner):
607
  with self.autocast:
608
  wav = convert_audio(
609
  wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
610
- stems = apply_model(self.demucs, wav, device=self.device)
611
  stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
612
  mix_wav = stems.sum(1) # merge extracted stems to single waveform
613
  mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
@@ -698,6 +759,250 @@ class ChromaStemConditioner(WaveformConditioner):
698
  return x
699
 
700
 
701
  class JointEmbeddingConditioner(BaseConditioner):
702
  """Joint embedding conditioning supporting both audio or text conditioning.
703
 
@@ -996,13 +1301,48 @@ class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
996
  return embed, empty_idx
997
 
998
 
999
- def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
1000
  """Utility function for nullifying an attribute inside an ConditioningAttributes object.
1001
  If the condition is of type "wav", then nullify it using `nullify_condition` function.
1002
  If the condition is of any other type, set its value to None.
1003
  Works in-place.
1004
  """
1005
- if condition_type not in ['text', 'wav', 'joint_embed']:
1006
  raise ValueError(
1007
  "dropout_condition got an unexpected condition type!"
1008
  f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
@@ -1021,6 +1361,8 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi
1021
  elif condition_type == 'joint_embed':
1022
  embed = sample.joint_embed[condition]
1023
  sample.joint_embed[condition] = nullify_joint_embed(embed)
 
 
1024
  else:
1025
  sample.text[condition] = None
1026
 
@@ -1071,7 +1413,7 @@ class AttributeDropout(DropoutModule):
1071
  return samples
1072
 
1073
  samples = deepcopy(samples)
1074
- for condition_type, ps in self.p.items(): # for condition types [text, wav]
1075
  for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
1076
  if torch.rand(1, generator=self.rng).item() < p:
1077
  for sample in samples:
@@ -1094,7 +1436,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
1094
  super().__init__(seed=seed)
1095
  self.p = p
1096
 
1097
- def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
 
 
1098
  """
1099
  Args:
1100
  samples (list[ConditioningAttributes]): List of conditions.
@@ -1111,10 +1455,11 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
1111
 
1112
  # nullify conditions of all attributes
1113
  samples = deepcopy(samples)
1114
- for condition_type in ["wav", "text"]:
1115
  for sample in samples:
1116
  for condition in sample.attributes[condition_type]:
1117
- dropout_condition(sample, condition_type, condition)
 
1118
  return samples
1119
 
1120
  def __repr__(self):
@@ -1339,7 +1684,7 @@ class ConditionFuser(StreamingModule):
1339
  cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
1340
  cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
1341
  """
1342
- FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
1343
 
1344
  def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
1345
  cross_attention_pos_emb_scale: float = 1.0):
@@ -1399,6 +1744,8 @@ class ConditionFuser(StreamingModule):
1399
  cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
1400
  else:
1401
  cross_attention_output = cond
 
 
1402
  else:
1403
  raise ValueError(f"unknown op ({op})")
1404
 
 
15
  import re
16
  import typing as tp
17
  import warnings
 
18
  import einops
19
+ import flashy
20
  from num2words import num2words
21
  import spacy
22
  from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
 
24
  from torch import nn
25
  import torch.nn.functional as F
26
  from torch.nn.utils.rnn import pad_sequence
27
+ from enum import Enum
28
  from .chroma import ChromaExtractor
29
  from .streaming import StreamingModule
30
+ from .transformer import create_sin_embedding, StreamingTransformer
31
  from ..data.audio import audio_read
32
  from ..data.audio_dataset import SegmentInfo
33
  from ..data.audio_utils import convert_audio
 
43
  ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
44
 
45
 
46
+ class JascoCondConst(Enum):
47
+ DRM = 'self_wav'
48
+ CRD = 'chords'
49
+ MLD = 'melody'
50
+ SYM = {'chords', 'melody'}
51
+ LAT = {'self_wav'}
52
+ ALL = ['chords', 'self_wav', 'melody'] # order matters
53
+
54
+
55
  class WavCondition(tp.NamedTuple):
56
  wav: torch.Tensor
57
  length: torch.Tensor
 
69
  seek_time: tp.List[tp.Optional[float]] = []
70
 
71
 
72
+ class SymbolicCondition(tp.NamedTuple):
73
+ frame_chords: tp.Optional[torch.Tensor] = None
74
+ melody: tp.Optional[torch.Tensor] = None
75
+
76
+
77
  @dataclass
78
  class ConditioningAttributes:
79
  text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
80
  wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
81
  joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
82
+ symbolic: tp.Dict[str, SymbolicCondition] = field(default_factory=dict)
83
 
84
  def __getitem__(self, item):
85
  return getattr(self, item)
 
96
  def joint_embed_attributes(self):
97
  return self.joint_embed.keys()
98
 
99
+ @property
100
+ def symbolic_attributes(self):
101
+ return self.symbolic.keys()
102
+
103
  @property
104
  def attributes(self):
105
  return {
106
  "text": self.text_attributes,
107
  "wav": self.wav_attributes,
108
  "joint_embed": self.joint_embed_attributes,
109
+ "symbolic": self.symbolic_attributes,
110
  }
111
 
112
  def to_flat_dict(self):
113
  return {
114
  **{f"text.{k}": v for k, v in self.text.items()},
115
  **{f"wav.{k}": v for k, v in self.wav.items()},
116
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()},
117
+ **{f"symbolic.{k}": v for k, v in self.symbolic.items()}
118
  }
119
 
120
  @classmethod
 
198
  )
199
 
200
 
201
+ def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 194) -> SymbolicCondition:
202
+ """Nullify the symbolic condition by setting all frame chords to a specified null chord index.
203
+ Args:
204
+ sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
205
+ null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
206
+ Returns:
207
+ SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
208
+ """
209
+ return SymbolicCondition(frame_chords=torch.ones_like(sym_cond.frame_chords) * null_chord_idx) # type: ignore
210
+
211
+
212
+ def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition:
213
+ """Nullify the symbolic condition by replacing the melody matrix with zeros matrix.
214
+ Args:
215
+ sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
216
+ null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
217
+ Returns:
218
+ SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
219
+ """
220
+ return SymbolicCondition(melody=torch.zeros_like(sym_cond.melody)) # type: ignore
221
+
222
+
223
+ def _drop_description_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
224
+ """Drop the text condition but keep the wav conditon on a list of ConditioningAttributes.
225
+ This is useful to calculate l_style in the double classifier free guidance formula.
226
+ See paragraph 4.3 in https://arxiv.org/pdf/2407.12563
227
+
228
+ Args:
229
+ conditions (tp.List[ConditioningAttributes]): List of conditions.
230
+ """
231
+ # We assert that description and self_wav are in the conditions
232
+ for condition in conditions:
233
+ assert 'description' in condition.text.keys()
234
+ assert 'self_wav' in condition.wav.keys()
235
+ return AttributeDropout(p={'text': {'description': 1.0},
236
+ 'wav': {'self_wav': 0.0}})(conditions)
237
+
238
+
239
  class Tokenizer:
240
  """Base tokenizer implementation
241
  (in case we want to introduce more advances tokenizers in the future).
 
356
  super().__init__()
357
  self.dim = dim
358
  self.output_dim = output_dim
359
+ if self.output_dim > -1: # omit projection when output_dim <= 0
360
+ self.output_proj = nn.Linear(dim, output_dim)
361
 
362
  def tokenize(self, *args, **kwargs) -> tp.Any:
363
  """Should be any part of the processing that will lead to a synchronization
 
555
  wav, lengths, *_ = x
556
  with torch.no_grad():
557
  embeds = self._get_wav_embedding(x)
558
+ if hasattr(self, 'output_proj'):
559
+ embeds = embeds.to(self.output_proj.weight)
560
+ embeds = self.output_proj(embeds)
561
 
562
  if lengths is not None and self._use_masking:
563
  lengths = lengths / self._downsampling_factor()
 
668
  with self.autocast:
669
  wav = convert_audio(
670
  wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
671
+ stems = apply_model(self.demucs, wav, device=self.device) # type: ignore
672
  stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
673
  mix_wav = stems.sum(1) # merge extracted stems to single waveform
674
  mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
 
759
  return x
760
 
761
 
762
+ class FeatureExtractor(WaveformConditioner):
763
+ """
764
+ Feature Extractor used for the style conditioner of the paper AUDIO CONDITIONING
765
+ FOR MUSIC GENERATION VIA DISCRETE BOTTLENECK FEATURES.
766
+
767
+ Given a waveform, we extract an excerpt of defined length randomly subsampled.
768
+ Then, we feed this excerpt to a feature extractor.
769
+
770
+ Args:
771
+ model_name (str): 'encodec' or 'mert'.
772
+ sample_rate (int): sample rate of the input audio (32000).
773
+ encodec_checkpoint (str): if encodec is used as a feature extractor, checkpoint
774
+ of the model. ('//pretrained/facebook/encodec_32khz' is the default)
775
+ encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
776
+ quantization streams used in it.
777
+ length (float): length in seconds of the random subsampled excerpt that is used
778
+ for conditioning.
779
+ dim (int): The internal representation dimension.
780
+ output_dim (int): Output dimension for the conditioner.
781
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
782
+ compute_mask (bool): whether to mask the tokens corresponding to the subsampled
783
+ excerpt in the computation of the music language model cross-entropy loss.
784
+ use_middle_of_segment (bool): if True, always take the middle of the input
785
+ instead of a random subsampled excerpt.
786
+ ds_rate_compression (int): downsampling parameter of the compression model used
787
+ for the music language model. (640 for encodec_32khz)
788
+ num_codebooks_lm (int): the number of codebooks used by the music language model.
789
+ """
790
+ def __init__(
791
+ self, model_name: str,
792
+ sample_rate: int, encodec_checkpoint: str, encodec_n_q: int, length: float,
793
+ dim: int, output_dim: int, device: tp.Union[torch.device, str],
794
+ compute_mask: bool = True,
795
+ use_middle_of_segment: bool = False, ds_rate_compression: int = 640,
796
+ num_codebooks_lm: int = 4
797
+ ):
798
+ assert model_name in ['encodec', 'mert']
799
+ if model_name == 'encodec':
800
+ from ..solvers.compression import CompressionSolver
801
+ feat_extractor = CompressionSolver.model_from_checkpoint(encodec_checkpoint, device)
802
+ elif model_name == 'mert':
803
+ from transformers import AutoModel
804
+ feat_extractor = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
805
+ super().__init__(
806
+ dim=dim,
807
+ output_dim=output_dim,
808
+ device=device
809
+ )
810
+ self.sample_rate = sample_rate
811
+ self.compute_mask = compute_mask
812
+ self.feat_extractor: nn.Module
813
+ self.embed: tp.Union[nn.ModuleList, nn.Linear]
814
+ if model_name == 'encodec':
815
+ self.__dict__["feat_extractor"] = feat_extractor.to(device)
816
+ self.encodec_n_q = encodec_n_q
817
+ self.embed = nn.ModuleList([nn.Embedding(feat_extractor.cardinality, dim) for _ in range(encodec_n_q)])
818
+ if model_name == 'mert':
819
+ self.__dict__["feat_extractor"] = feat_extractor.eval().to(device)
820
+ self.embed = nn.Linear(768, dim) # hardcoded
821
+ self.length_subwav = int(length * sample_rate)
822
+ self.ds_rate_compression = ds_rate_compression
823
+ self.model_name = model_name
824
+ self.use_middle_of_segment = use_middle_of_segment
825
+ self.num_codebooks_lm = num_codebooks_lm
826
+
827
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
828
+ if x.wav.shape[-1] == 1:
829
+ self.temp_mask = None
830
+ return torch.zeros(x.wav.shape[0], 1, self.dim, device=self.device)
831
+ else:
832
+ with torch.no_grad():
833
+ if self.use_middle_of_segment:
834
+ start = int((x.wav.shape[-1] - self.length_subwav) / 2)
835
+ wav = x.wav[:, :, start:start+self.length_subwav]
836
+ else:
837
+ start = random.randint(0, x.wav.shape[-1] - self.length_subwav)
838
+ wav = x.wav[:, :, start:start+self.length_subwav]
839
+ if self.compute_mask:
840
+ self.temp_mask = self._get_mask_wav(x, start)
841
+ if self.model_name == 'encodec':
842
+ tokens = self.feat_extractor.encode(wav)[0] # type: ignore
843
+ elif self.model_name == 'mert':
844
+ wav = convert_audio(wav, from_rate=x.sample_rate[0], to_rate=24000, to_channels=1)
845
+ embeds = self.feat_extractor(wav.squeeze(-2)).last_hidden_state
846
+ if self.model_name == 'encodec':
847
+ tokens = tokens[:, :self.encodec_n_q]
848
+ embeds = sum([self.embed[k](tokens[:, k]) for k in range(self.encodec_n_q)]) # type: ignore
849
+ else:
850
+ embeds = self.embed(embeds)
851
+
852
+ return embeds # [B, T, dim]
853
+
854
+ def _downsampling_factor(self):
855
+ if self.model_name == 'encodec':
856
+ return self.sample_rate / self.feat_extractor.frame_rate
857
+ elif self.model_name == 'mert':
858
+ return self.sample_rate / 75
859
+
860
+ def _get_mask_wav(self, x: WavCondition, start: int) -> tp.Union[torch.Tensor, None]:
861
+ if x.wav.shape[-1] == 1:
862
+ return None
863
+ total_length = int(x.wav.shape[-1] / self.ds_rate_compression)
864
+ mask_length = int(self.length_subwav / self.ds_rate_compression)
865
+ start = int(start / self.ds_rate_compression)
866
+ mask = torch.ones(x.wav.shape[0], self.num_codebooks_lm,
867
+ total_length, device=self.device, dtype=torch.bool)
868
+ mask[:, :, start:start+mask_length] = 0
869
+ return mask
870
+
871
+
872
+ class StyleConditioner(FeatureExtractor):
873
+ """Conditioner from the paper AUDIO CONDITIONING FOR MUSIC GENERATION VIA
874
+ DISCRETE BOTTLENECK FEATURES.
875
+ Given an audio input, it is passed through a Feature Extractor and a
876
+ transformer encoder. Then it is quantized through RVQ.
877
+
878
+ Args:
879
+ transformer_scale (str): size of the transformer. See __init__ for more details.
880
+ ds_factor (int): the downsampling factor applied to the representation after quantization.
881
+ encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
882
+ quantization streams used in it.
883
+ n_q_out (int): the number of quantization streams used for the RVQ. If increased, there
884
+ is more information passing as a conditioning.
885
+ eval_q (int): the number of quantization streams used for the RVQ at evaluation time.
886
+ q_dropout (bool): if True, at training time, a random number of streams is sampled
887
+ at each step in the interval [1, n_q_out].
888
+ bins (int): the codebook size used for each quantization stream.
889
+ varying_lengths (List[float]): list of the min and max duration in seconds for the
890
+ randomly subsampled excerpt at training time. For each step a length is sampled
891
+ in this interval.
892
+ batch_norm (bool): use of batch normalization after the transformer. Stabilizes the
893
+ training.
894
+ rvq_threshold_ema_dead_code (float): threshold for dropping dead codes in the
895
+ RVQ.
896
+ """
897
+ def __init__(self, transformer_scale: str = 'default', ds_factor: int = 15, encodec_n_q: int = 4,
898
+ n_q_out: int = 6, eval_q: int = 3, q_dropout: bool = True, bins: int = 1024,
899
+ varying_lengths: tp.List[float] = [1.5, 4.5],
900
+ batch_norm: bool = True, rvq_threshold_ema_dead_code: float = 0.1,
901
+ **kwargs):
902
+ tr_args: tp.Dict[str, tp.Any]
903
+ if transformer_scale == 'xsmall':
904
+ tr_args = {'d_model': 256, 'num_heads': 8, 'num_layers': 4}
905
+ elif transformer_scale == 'large':
906
+ tr_args = {'d_model': 1024, 'num_heads': 16, 'num_layers': 24}
907
+ elif transformer_scale == 'default':
908
+ tr_args = {'d_model': 512, 'num_heads': 8, 'num_layers': 8}
909
+ elif transformer_scale == 'none':
910
+ tr_args = {'d_model': 512}
911
+ tr_args.update({
912
+ 'memory_efficient': True, 'activation': 'gelu',
913
+ 'norm_first': True, 'causal': False, 'layer_scale': None,
914
+ 'bias_ff': False, 'bias_attn': False,
915
+ })
916
+ dim = tr_args['d_model']
917
+ super().__init__(dim=dim, encodec_n_q=encodec_n_q, **kwargs)
918
+
919
+ self.ds_factor = ds_factor
920
+ if transformer_scale == 'none':
921
+ self.transformer = None
922
+ else:
923
+ self.transformer = StreamingTransformer(dim_feedforward=int(4 * dim), **tr_args)
924
+ self.n_q_out = n_q_out
925
+ self.eval_q = eval_q
926
+ self.rvq = None
927
+ if n_q_out > 0:
928
+ self.rvq = ResidualVectorQuantizer(dim, n_q=n_q_out, q_dropout=q_dropout, bins=bins,
929
+ threshold_ema_dead_code=rvq_threshold_ema_dead_code)
930
+ self.autocast = TorchAutocast(enabled=self.device != 'cpu', device_type=self.device, dtype=torch.float32)
931
+ self.varying_lengths = varying_lengths
932
+ self.batch_norm = None
933
+ if batch_norm:
934
+ self.batch_norm = nn.BatchNorm1d(dim, affine=False)
935
+ self.mask = None
936
+
937
+ def _get_wav_embedding(self, wav: WavCondition) -> torch.Tensor:
938
+ with self.autocast:
939
+ # Sample the length of the excerpts
940
+ if self.varying_lengths and self.training:
941
+ assert len(self.varying_lengths) == 2
942
+ length = random.uniform(self.varying_lengths[0], self.varying_lengths[1])
943
+ self.length_subwav = int(length * self.sample_rate)
944
+ z1 = super()._get_wav_embedding(wav)
945
+ if self.compute_mask:
946
+ self.mask = self.temp_mask # type: ignore
947
+ self.temp_mask = None
948
+
949
+ if self.transformer is not None:
950
+ out1 = self.transformer(z1)
951
+ else:
952
+ out1 = z1
953
+ if self.batch_norm:
954
+ out1 = self.batch_norm(out1.transpose(1, 2)).transpose(1, 2)
955
+ # Apply quantization
956
+ if self.rvq:
957
+ if self.training:
958
+ self.rvq.set_num_codebooks(self.n_q_out)
959
+ else:
960
+ self.rvq.set_num_codebooks(self.eval_q)
961
+ out1 = self.rvq(out1.transpose(1, 2), frame_rate=1.)
962
+ if self.training:
963
+ flashy.distrib.average_tensors(self.rvq.buffers())
964
+ out1 = out1.x.transpose(1, 2)
965
+ # Apply fix downsample
966
+ out1 = out1[:, ::self.ds_factor]
967
+
968
+ return out1
969
+
970
+ def set_params(self, eval_q: int = 3,
971
+ excerpt_length: float = 3.0,
972
+ ds_factor: tp.Optional[int] = None, encodec_n_q: tp.Optional[int] = None):
973
+ """Modify the parameters of the SSL or introduce new parameters to add noise to
974
+ the conditioning or to downsample it
975
+
976
+ Args:
977
+ eval_q (int): number of codebooks used when evaluating the model
978
+ excerpt_length (float): the length of the excerpts used to condition the model
979
+ """
980
+ self.eval_q = eval_q
981
+ self.length_subwav = int(excerpt_length * self.sample_rate)
982
+ if ds_factor is not None:
983
+ self.ds_factor = ds_factor
984
+ if encodec_n_q is not None:
985
+ self.encodec_n_q = encodec_n_q
986
+
987
+ def _downsampling_factor(self):
988
+ df = super()._downsampling_factor()
989
+ return df * self.ds_factor
990
+
991
+ def forward(self, x: WavCondition) -> ConditionType:
992
+ wav, lengths, *_ = x
993
+
994
+ embeds = self._get_wav_embedding(x)
995
+ embeds = embeds.to(self.output_proj.weight)
996
+ embeds = self.output_proj(embeds)
997
+
998
+ lengths = lengths / self._downsampling_factor()
999
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
1000
+
1001
+ embeds = (embeds * mask.unsqueeze(2).to(self.device))
1002
+
1003
+ return embeds, mask
1004
+
1005
+
1006
  class JointEmbeddingConditioner(BaseConditioner):
1007
  """Joint embedding conditioning supporting both audio or text conditioning.
1008
 
 
1301
  return embed, empty_idx
1302
 
1303
 
1304
+ def dropout_symbolic_conditions(sample: ConditioningAttributes,
1305
+ condition: str, null_chord_idx: int = 194) -> ConditioningAttributes:
1306
+ """
1307
+ Applies dropout to symbolic conditions within the sample based on the specified condition by setting the condition
1308
+ value to a null index.
1309
+ Args:
1310
+ sample (ConditioningAttributes): The sample containing symbolic attributes to potentially dropout.
1311
+ condition (str): The specific condition within the symbolic attributes to apply dropout.
1312
+ null_chord_idx (int, optional): The index used to represent a null chord. Defaults to 194.
1313
+ Returns:
1314
+ ConditioningAttributes: The modified sample with dropout applied to the specified condition.
1315
+ Raises:
1316
+ ValueError: If the specified condition is not present in the sample's symbolic attributes.
1317
+ """
1318
+ if sample.symbolic == {} or sample.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1: # type: ignore
1319
+ # nothing to drop
1320
+ return sample
1321
+
1322
+ if condition not in getattr(sample, 'symbolic'):
1323
+ raise ValueError(
1324
+ "dropout_symbolic_condition received an unexpected condition!"
1325
+ f" expected {sample.symbolic.keys()}"
1326
+ f" but got '{condition}'!"
1327
+ )
1328
+
1329
+ if condition == JascoCondConst.CRD.value:
1330
+ sample.symbolic[condition] = nullify_chords(sample.symbolic[condition], null_chord_idx=null_chord_idx)
1331
+ elif condition == JascoCondConst.MLD.value:
1332
+ sample.symbolic[condition] = nullify_melody(sample.symbolic[condition])
1333
+
1334
+ return sample
1335
+
1336
+
1337
+ def dropout_condition(sample: ConditioningAttributes,
1338
+ condition_type: str, condition: str,
1339
+ **kwargs) -> ConditioningAttributes:
1340
  """Utility function for nullifying an attribute inside an ConditioningAttributes object.
1341
  If the condition is of type "wav", then nullify it using `nullify_condition` function.
1342
  If the condition is of any other type, set its value to None.
1343
  Works in-place.
1344
  """
1345
+ if condition_type not in ['text', 'wav', 'joint_embed', 'symbolic']:
1346
  raise ValueError(
1347
  "dropout_condition got an unexpected condition type!"
1348
  f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
 
1361
  elif condition_type == 'joint_embed':
1362
  embed = sample.joint_embed[condition]
1363
  sample.joint_embed[condition] = nullify_joint_embed(embed)
1364
+ elif condition_type == 'symbolic':
1365
+ sample = dropout_symbolic_conditions(sample=sample, condition=condition, **kwargs)
1366
  else:
1367
  sample.text[condition] = None
1368
 
 
1413
  return samples
1414
 
1415
  samples = deepcopy(samples)
1416
+ for condition_type, ps in self.p.items(): # for condition types [text, wav, symbolic]
1417
  for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
1418
  if torch.rand(1, generator=self.rng).item() < p:
1419
  for sample in samples:
 
1436
  super().__init__(seed=seed)
1437
  self.p = p
1438
 
1439
+ def forward(self, samples: tp.List[ConditioningAttributes],
1440
+ cond_types: tp.List[str] = ["wav", "text"],
1441
+ **kwargs) -> tp.List[ConditioningAttributes]:
1442
  """
1443
  Args:
1444
  samples (list[ConditioningAttributes]): List of conditions.
 
1455
 
1456
  # nullify conditions of all attributes
1457
  samples = deepcopy(samples)
1458
+ for condition_type in cond_types:
1459
  for sample in samples:
1460
  for condition in sample.attributes[condition_type]:
1461
+ dropout_condition(sample, condition_type, condition,
1462
+ **kwargs)
1463
  return samples
1464
 
1465
  def __repr__(self):
 
1684
  cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
1685
  cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
1686
  """
1687
+ FUSING_METHODS = ["sum", "prepend", "cross", "ignore", "input_interpolate"]
1688
 
1689
  def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
1690
  cross_attention_pos_emb_scale: float = 1.0):
 
1744
  cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
1745
  else:
1746
  cross_attention_output = cond
1747
+ elif op == 'ignore':
1748
+ continue
1749
  else:
1750
  raise ValueError(f"unknown op ({op})")
1751
 
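To make the excerpt masking in FeatureExtractor._get_mask_wav concrete, here is a small arithmetic sketch with values assumed for the 32 kHz EnCodec setup (ds_rate_compression=640, four codebooks); numbers are illustrative only:

import torch

sample_rate, ds_rate, n_codebooks = 32000, 640, 4
wav_len = 10 * sample_rate               # 10 s of conditioning audio
excerpt_len = int(3.0 * sample_rate)     # 3 s excerpt (length_subwav)
start = 2 * sample_rate                  # excerpt starts at 2 s

total_tokens = wav_len // ds_rate        # 500 frames at 50 Hz
mask_tokens = excerpt_len // ds_rate     # 150 frames
start_tok = start // ds_rate             # frame 100
mask = torch.ones(1, n_codebooks, total_tokens, dtype=torch.bool)
mask[:, :, start_tok:start_tok + mask_tokens] = 0   # excerpt frames excluded from the CE loss
print(mask.sum().item())                 # 1400 of 2000 positions kept for the loss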
audiocraft/modules/jasco_conditioners.py ADDED
@@ -0,0 +1,300 @@
1
+ import torch
2
+ import typing as tp
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from torch import nn
6
+ from .conditioners import (ConditioningAttributes, BaseConditioner, ConditionType,
7
+ ConditioningProvider, JascoCondConst,
8
+ WaveformConditioner, WavCondition, SymbolicCondition)
9
+ from ..data.audio import audio_read
10
+ from ..data.audio_utils import convert_audio
11
+ from ..utils.autocast import TorchAutocast
12
+ from ..utils.cache import EmbeddingCache
13
+
14
+
15
+ class MelodyConditioner(BaseConditioner):
16
+ """
17
+ A conditioner that handles melody conditioning from a pre-computed salience matrix.
18
+ Attributes:
19
+ card (int): The cardinality of the melody matrix.
20
+ out_dim (int): The dimensionality of the output projection.
21
+ device (Union[torch.device, str]): The device on which the embeddings are stored.
22
+ """
23
+ def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
24
+ super().__init__(dim=card, output_dim=out_dim)
25
+ self.device = device
26
+
27
+ def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
28
+ return SymbolicCondition(melody=x.melody.to(self.device)) # type: ignore
29
+
30
+ def forward(self, x: SymbolicCondition) -> ConditionType:
31
+ embeds = self.output_proj(x.melody.permute(0, 2, 1)) # type: ignore
32
+ mask = torch.ones_like(embeds[..., 0])
33
+ return embeds, mask
34
+
35
+
36
+ class ChordsEmbConditioner(BaseConditioner):
37
+ """
38
+ A conditioner that embeds chord symbols into a continuous vector space.
39
+ Attributes:
40
+ card (int): The cardinality of the chord vocabulary.
41
+ out_dim (int): The dimensionality of the output embeddings.
42
+ device (Union[torch.device, str]): The device on which the embeddings are stored.
43
+ """
44
+ def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
45
+ vocab_size = card + 1 # card + 1 - for null chord used during dropout
46
+ super().__init__(dim=vocab_size, output_dim=-1) # out_dim=-1 to avoid another projection
47
+ self.emb = nn.Embedding(vocab_size, out_dim, device=device)
48
+ self.device = device
49
+
50
+ def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
51
+ return SymbolicCondition(frame_chords=x.frame_chords.to(self.device)) # type: ignore
52
+
53
+ def forward(self, x: SymbolicCondition) -> ConditionType:
54
+ embeds = self.emb(x.frame_chords)
55
+ mask = torch.ones_like(embeds[..., 0])
56
+ return embeds, mask
57
+
58
+
59
+ class DrumsConditioner(WaveformConditioner):
60
+ def __init__(self, out_dim: int, sample_rate: int, blurring_factor: int = 3,
61
+ cache_path: tp.Optional[tp.Union[str, Path]] = None,
62
+ compression_model_latent_dim: int = 128,
63
+ compression_model_framerate: float = 50,
64
+ segment_duration: float = 10.0,
65
+ device: tp.Union[torch.device, str] = 'cpu',
66
+ **kwargs):
67
+ """Drum condition conditioner
68
+
69
+ Args:
70
+ out_dim (int): output dimension of the conditioner.
71
+ sample_rate (int): sample rate of the input audio.
72
+ blurring_factor (int, optional): number of frames averaged together by the temporal blurring. Defaults to 3.
73
+ cache_path (tp.Optional[tp.Union[str, Path]], optional): path to precomputed cache. Defaults to None.
74
+ compression_model_latent_dim (int, optional): latent dimension. Defaults to 128.
75
+ compression_model_framerate (float, optional): frame rate of the representation model. Defaults to 50.
76
+ segment_duration (float, optional): duration in sec for each audio segment. Defaults to 10.0.
77
+ device (tp.Union[torch.device, str], optional): device. Defaults to 'cpu'.
78
+ """
79
+ from demucs import pretrained
80
+ self.sample_rate = sample_rate
81
+ self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
82
+ stem_sources: list = self.demucs.sources # type: ignore
83
+ self.stem_idx = stem_sources.index('drums')
84
+ self.compression_model = None
85
+ self.latent_dim = compression_model_latent_dim
86
+ super().__init__(dim=self.latent_dim, output_dim=out_dim, device=device)
87
+ self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
88
+ self._use_masking = False
89
+ self.blurring_factor = blurring_factor
90
+ self.seq_len = int(segment_duration * compression_model_framerate)
91
+ self.cache = None # If you wish to train with EmbeddingCache, call self.create_embedding_cache(cache_path)
92
+
93
+ def create_embedding_cache(self, cache_path):
94
+ if cache_path is not None:
95
+ self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
96
+ compute_embed_fn=self._calc_coarse_drum_codes_for_cache,
97
+ extract_embed_fn=self._load_drum_codes_chunk)
98
+
99
+ @torch.no_grad()
100
+ def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
101
+ """Get parts of the wav that holds the drums, extracting the main stems from the wav."""
102
+ from demucs.apply import apply_model
103
+ from demucs.audio import convert_audio
104
+ with self.autocast:
105
+ wav = convert_audio(
106
+ wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
107
+ stems = apply_model(self.demucs, wav, device=self.device)
108
+ drum_stem = stems[:, self.stem_idx] # extract relevant stems for drums conditioning
109
+ return convert_audio(drum_stem, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
110
+
111
+ def _temporal_blur(self, z: torch.Tensor):
112
+ # z: (B, T, C)
113
+ B, T, C = z.shape
114
+ if T % self.blurring_factor != 0:
115
+ # reflect-pad on the right in dim=1 so T becomes a multiple of self.blurring_factor
116
+ pad_val = self.blurring_factor - T % self.blurring_factor
117
+ z = torch.nn.functional.pad(z, (0, 0, 0, pad_val), mode='reflect')
118
+ z = z.reshape(B, -1, self.blurring_factor, C).sum(dim=2) / self.blurring_factor
119
+ z = z.unsqueeze(2).repeat(1, 1, self.blurring_factor, 1).reshape(B, -1, C)
120
+ z = z[:, :T]
121
+ assert z.shape == (B, T, C)
122
+ return z
123
+
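The _temporal_blur method above keeps only coarse rhythmic information by averaging the drum latents over short blocks and broadcasting each block mean back over its frames. A self-contained sketch of the same block-average-and-broadcast trick, with hypothetical shapes:

import torch

def temporal_blur(z: torch.Tensor, blurring_factor: int = 3) -> torch.Tensor:
    # z: (B, T, C). Average non-overlapping blocks of `blurring_factor` frames,
    # then repeat each block mean so the output keeps shape (B, T, C).
    B, T, C = z.shape
    pad = (-T) % blurring_factor
    if pad:
        z = torch.nn.functional.pad(z, (0, 0, 0, pad), mode='reflect')
    z = z.reshape(B, -1, blurring_factor, C).mean(dim=2)
    z = z.repeat_interleave(blurring_factor, dim=1)[:, :T]
    return z

blurred = temporal_blur(torch.randn(2, 500, 128))   # toy batch of latents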
124
+ @torch.no_grad()
125
+ def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
126
+ assert self.compression_model is not None
127
+
128
+ # stem separation of drums
129
+ drums = self._get_drums_stem(wav, sample_rate)
130
+
131
+ # continuous encoding with compression model
132
+ latents = self.compression_model.model.encoder(drums)
133
+
134
+ # quantization to coarsest codebook
135
+ coarsest_quantizer = self.compression_model.model.quantizer.layers[0]
136
+ drums = coarsest_quantizer.encode(latents).to(torch.int16)
137
+ return drums
138
+
139
+ @torch.no_grad()
140
+ def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path],
141
+ x: WavCondition, idx: int,
142
+ max_duration_to_process: float = 600) -> torch.Tensor:
143
+ """Extract blurred drum latents from the whole audio waveform at the given path."""
144
+ wav, sr = audio_read(path)
145
+ wav = wav[None].to(self.device)
146
+ wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
147
+
148
+ max_frames_to_process = int(max_duration_to_process * self.sample_rate)
149
+ if wav.shape[-1] > max_frames_to_process:
150
+ # process very long tracks in chunks
151
+ start = 0
152
+ codes = []
153
+ while start < wav.shape[-1] - 1:
154
+ wav_chunk = wav[..., start: start + max_frames_to_process]
155
+ codes.append(self._extract_coarse_drum_codes(wav_chunk, self.sample_rate)[0])
156
+ start += max_frames_to_process
157
+ return torch.cat(codes)
158
+
159
+ return self._extract_coarse_drum_codes(wav, self.sample_rate)[0]
160
+
161
+ def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
162
+ """Extract a chunk of coarse drum codes from the full coarse drum codes derived from the full waveform."""
163
+ wav_length = x.wav.shape[-1]
164
+ seek_time = x.seek_time[idx]
165
+ assert seek_time is not None, (
166
+ "WavCondition seek_time is required "
167
+ "when extracting chunks from pre-computed drum codes.")
168
+ assert self.compression_model is not None
169
+ frame_rate = self.compression_model.frame_rate
170
+ target_length = int(frame_rate * wav_length / self.sample_rate)
171
+ target_length = max(target_length, self.seq_len)
172
+ index = int(frame_rate * seek_time)
173
+ out = full_coarse_drum_codes[index: index + target_length]
174
+ # pad
175
+ out = torch.cat((out, torch.zeros(target_length - out.shape[0], dtype=out.dtype, device=out.device)))
176
+ return out.to(self.device)
177
+
178
+ @torch.no_grad()
179
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
180
+ bs = x.wav.shape[0]
181
+ if x.wav.shape[-1] <= 1:
182
+ # null condition
183
+ return torch.zeros((bs, self.seq_len, self.latent_dim), device=x.wav.device, dtype=x.wav.dtype)
184
+
185
+ # extract coarse drum codes
186
+ no_undefined_paths = all(p is not None for p in x.path)
187
+ no_nullified_cond = x.wav.shape[-1] > 1
188
+ if self.cache is not None and no_undefined_paths and no_nullified_cond:
189
+ paths = [Path(p) for p in x.path if p is not None]
190
+ codes = self.cache.get_embed_from_cache(paths, x)
191
+ else:
192
+ assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
193
+ codes = self._extract_coarse_drum_codes(x.wav, x.sample_rate[0])
194
+
195
+ assert self.compression_model is not None
196
+ # decode back to the continuous representation of compression model
197
+ codes = codes.unsqueeze(1).permute(1, 0, 2) # (B, T) -> (1, B, T)
198
+ codes = codes.to(torch.int64)
199
+ latents = self.compression_model.model.quantizer.decode(codes)
200
+
201
+ latents = latents.permute(0, 2, 1) # [B, C, T] -> [B, T, C]
202
+
203
+ # temporal blurring
204
+ return self._temporal_blur(latents)
205
+
206
+ def tokenize(self, x: WavCondition) -> WavCondition:
207
+ """Apply WavConditioner tokenization and populate cache if needed."""
208
+ x = super().tokenize(x)
209
+ no_undefined_paths = all(p is not None for p in x.path)
210
+ if self.cache is not None and no_undefined_paths:
211
+ paths = [Path(p) for p in x.path if p is not None]
212
+ self.cache.populate_embed_cache(paths, x)
213
+ return x
214
+
215
+
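A quick worked example of the index arithmetic used by _load_drum_codes_chunk when slicing pre-computed codes; the rates below are assumptions for illustration, not the model's actual configuration:

sample_rate = 32000                      # assumed waveform sample rate
frame_rate = 50                          # assumed compression-model frame rate
wav_length = 10 * sample_rate            # a 10 s waveform chunk
seek_time = 42.0                         # chunk start inside the full track, in seconds

target_length = int(frame_rate * wav_length / sample_rate)   # 500 code frames for 10 s
index = int(frame_rate * seek_time)                          # 2100: first code frame of the chunk
# full_coarse_drum_codes[index : index + target_length] is then right-padded with
# zeros if the track ends before target_length frames are available.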
216
+ class JascoConditioningProvider(ConditioningProvider):
217
+ """
218
+ A conditioning provider that manages and tokenizes various types of conditioning attributes for Jasco models.
219
+ Attributes:
220
+ chords_card (int): The cardinality of the chord vocabulary.
221
+ sequence_length (int): The length of the sequence for padding purposes.
222
+ melody_dim (int): The dimensionality of the melody matrix.
223
+ """
224
+ def __init__(self, *args,
225
+ chords_card: int = 194,
226
+ sequence_length: int = 500,
227
+ melody_dim: int = 53, **kwargs):
228
+ self.null_chord = chords_card
229
+ self.sequence_len = sequence_length
230
+ self.melody_dim = melody_dim
231
+ super().__init__(*args, **kwargs)
232
+
233
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
234
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
235
+ This should be called before starting any real GPU work to avoid synchronization points.
236
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
237
+
238
+ Args:
239
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
240
+ text and wav conditions.
241
+ """
242
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
243
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
244
+ f" but types were {set([type(x) for x in inputs])}"
245
+ )
246
+
247
+ output = {}
248
+ text = self._collate_text(inputs)
249
+ wavs = self._collate_wavs(inputs)
250
+
251
+ symbolic = self._collate_symbolic(inputs, self.conditioners.keys())
252
+
253
+ assert set(text.keys() | wavs.keys() | symbolic.keys()).issubset(set(self.conditioners.keys())), (
254
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
255
+ f"got {text.keys(), wavs.keys(), symbolic.keys()}"
256
+ )
257
+
258
+ for attribute, batch in chain(text.items(), wavs.items(), symbolic.items()):
259
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
260
+ return output
261
+
262
+ def _collate_symbolic(self, samples: tp.List[ConditioningAttributes],
263
+ conditioner_keys: tp.Set) -> tp.Dict[str, SymbolicCondition]:
264
+ output = {}
265
+
266
+ # collate if symbolic cond exists
267
+ if any(x in conditioner_keys for x in JascoCondConst.SYM.value):
268
+
269
+ for s in samples:
270
+ # hydrate with the null chord if chords do not exist - for inference support
271
+ if (s.symbolic == {} or
272
+ s.symbolic[JascoCondConst.CRD.value].frame_chords is None or
273
+ s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1): # type: ignore
274
+ # no chords conditioning - fill with null chord token
275
+ s.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
276
+ frame_chords=torch.ones(self.sequence_len, dtype=torch.int32) * self.null_chord)
277
+
278
+ if (s.symbolic == {} or
279
+ s.symbolic[JascoCondConst.MLD.value].melody is None or
280
+ s.symbolic[JascoCondConst.MLD.value].melody.shape[-1] <= 1): # type: ignore
281
+ # no melody conditioning - fill with a zero melody matrix
282
+ s.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(
283
+ melody=torch.zeros((self.melody_dim, self.sequence_len)))
284
+
285
+ if JascoCondConst.CRD.value in conditioner_keys:
286
+ # pad to max
287
+ max_seq_len = max(
288
+ [s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] for s in samples]) # type: ignore
289
+ padded_chords = [
290
+ torch.cat((x.symbolic[JascoCondConst.CRD.value].frame_chords, # type: ignore
291
+ torch.ones(max_seq_len -
292
+ x.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1], # type: ignore
293
+ dtype=torch.int32) * self.null_chord))
294
+ for x in samples
295
+ ]
296
+ output[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.stack(padded_chords))
297
+ if JascoCondConst.MLD.value in conditioner_keys:
298
+ melodies = torch.stack([x.symbolic[JascoCondConst.MLD.value].melody for x in samples]) # type: ignore
299
+ output[JascoCondConst.MLD.value] = SymbolicCondition(melody=melodies)
300
+ return output
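The chord-padding step in _collate_symbolic simply right-pads each per-sample chord sequence with the null-chord index up to the longest sequence in the batch. A hedged sketch with toy values; the cardinality and lengths below are assumptions:

import torch

null_chord = 194                                       # assumed chord cardinality / null index
samples = [torch.randint(0, null_chord, (n,), dtype=torch.int32) for n in (500, 430)]

max_len = max(s.shape[-1] for s in samples)
padded = torch.stack([
    torch.cat((s, torch.full((max_len - s.shape[-1],), null_chord, dtype=torch.int32)))
    for s in samples
])                                                     # (B, max_len), short rows padded with the null chord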
audiocraft/modules/transformer.py CHANGED
@@ -315,7 +315,6 @@ class StreamingMultiheadAttention(StreamingModule):
315
  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
316
  key_padding_mask=None, need_weights=False, attn_mask=None,
317
  average_attn_weights=True, is_causal=False):
318
- assert attn_mask is None
319
  assert not is_causal, ("New param added in torch 2.0.1 not supported, "
320
  "use the causal args in the constructor.")
321
 
@@ -329,7 +328,10 @@ class StreamingMultiheadAttention(StreamingModule):
329
  assert self.causal or self.cross_attention, \
330
  "Streaming only available for causal or cross attention"
331
 
 
 
332
  if self.causal:
 
333
  # At the moment we specialize only for the self-attention case.
334
  assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
335
  assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
@@ -398,6 +400,14 @@ class StreamingMultiheadAttention(StreamingModule):
398
  if self.attention_as_float32:
399
  q, k, v = [x.float() for x in [q, k, v]]
400
  if self.memory_efficient:
 
 
 
 
 
 
 
 
401
  p = self.dropout if self.training else 0
402
  if _efficient_attention_backend == 'torch':
403
  x = torch.nn.functional.scaled_dot_product_attention(
 
315
  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
316
  key_padding_mask=None, need_weights=False, attn_mask=None,
317
  average_attn_weights=True, is_causal=False):
 
318
  assert not is_causal, ("New param added in torch 2.0.1 not supported, "
319
  "use the causal args in the constructor.")
320
 
 
328
  assert self.causal or self.cross_attention, \
329
  "Streaming only available for causal or cross attention"
330
 
331
+ custom_attn_mask = attn_mask is not None
332
+
333
  if self.causal:
334
+ assert attn_mask is None
335
  # At the moment we specialize only for the self-attention case.
336
  assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
337
  assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
 
400
  if self.attention_as_float32:
401
  q, k, v = [x.float() for x in [q, k, v]]
402
  if self.memory_efficient:
403
+ if custom_attn_mask:
404
+ # When using a custom attn mask:
405
+ # Move to query's device, repeat for each sample, remove align8 padding
406
+ seq_len = query.shape[1]
407
+ attn_mask = attn_mask.to(q.dtype)
408
+ attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
409
+ attn_mask = attn_mask[..., :seq_len, :seq_len]
410
+
411
  p = self.dropout if self.training else 0
412
  if _efficient_attention_backend == 'torch':
413
  x = torch.nn.functional.scaled_dot_product_attention(
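To illustrate the custom-mask branch added above: the mask arrives without a batch dimension and with align-8 padding, so before the memory-efficient kernel it is cast to the query dtype, repeated per sample, and trimmed back to the true sequence length. A minimal sketch with assumed shapes, using the plain torch SDPA backend:

import torch

B, H, seq_len, d = 2, 8, 100, 64
q = torch.randn(B, H, seq_len, d)
k = torch.randn(B, H, seq_len, d)
v = torch.randn(B, H, seq_len, d)

padded_len = 104                                        # seq_len rounded up to a multiple of 8
attn_mask = torch.zeros(1, 1, padded_len, padded_len)   # shared additive mask (0 = no masking)

attn_mask = attn_mask.to(q.dtype)
attn_mask = attn_mask.repeat(B, 1, 1, 1)                # one copy per sample
attn_mask = attn_mask[..., :seq_len, :seq_len]          # drop the align-8 padding

x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)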
audiocraft/modules/unet_transformer.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import typing as tp
3
+ from .transformer import StreamingTransformer, create_sin_embedding
4
+
5
+
6
+ class UnetTransformer(StreamingTransformer):
7
+ """U-net Transformer for processing sequences with optional skip connections.
8
+ This transformer architecture incorporates U-net style skip connections
9
+ between layers, which can be optionally enabled. It inherits from a
10
+ StreamingTransformer.
11
+
12
+ Args:
13
+ d_model (int): Dimension of the model, typically the number of expected features in the input.
14
+ num_layers (int): Total number of layers in the transformer.
15
+ skip_connections (bool, optional): Flag to determine whether skip connections should be used.
16
+ Defaults to False.
17
+ layer_dropout_p (float, optional): if given, defines the Bernoulli probability of dropping a skip connection during training.
18
+ **kwargs: Additional keyword arguments forwarded to `StreamingTransformer`.
19
+ """
20
+ def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
21
+ layer_dropout_p: tp.Optional[float] = None, **kwargs):
22
+ super().__init__(d_model=d_model,
23
+ num_layers=num_layers,
24
+ **kwargs)
25
+ self.skip_connect = skip_connections
26
+ if self.skip_connect:
27
+ self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
28
+ for _ in range(num_layers // 2)])
29
+ self.num_layers = num_layers
30
+ self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0
31
+
32
+ def forward(self, x: torch.Tensor, *args, **kwargs):
33
+ B, T, C = x.shape
34
+
35
+ if 'offsets' in self._streaming_state:
36
+ offsets = self._streaming_state['offsets']
37
+ else:
38
+ offsets = torch.zeros(B, dtype=torch.long, device=x.device)
39
+
40
+ if self.positional_embedding in ['sin', 'sin_rope']:
41
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
42
+ positions = positions + offsets.view(-1, 1, 1)
43
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
44
+ x = x + self.positional_scale * pos_emb
45
+
46
+ skip_connections: tp.List[torch.Tensor] = []
47
+
48
+ for i, layer in enumerate(self.layers):
49
+ if self.skip_connect and i >= self.num_layers // 2:
50
+
51
+ # in the second half of the layers, add residual connection
52
+ # and linearly project the concatenated features back to d_model
53
+ x = torch.cat([x, skip_connections.pop()], dim=-1)
54
+ x = self.skip_projections[i % len(self.skip_projections)](x)
55
+
56
+ x = self._apply_layer(layer, x, *args, **kwargs)
57
+
58
+ if self.skip_connect and i < self.num_layers // 2:
59
+ if self.training and torch.rand(1,) < self.layer_drop_p: # drop skip
60
+ skip_connections.append(torch.zeros_like(x))
61
+ else:
62
+ skip_connections.append(x)
63
+
64
+ if self._is_streaming:
65
+ self._streaming_state['offsets'] = offsets + T
66
+
67
+ return x
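The control flow of the skip connections above can be hard to follow in diff form: activations from the first half of the layers are pushed onto a stack, and each layer in the second half pops one, concatenates it with its input, and projects back to d_model. A toy sketch with stand-in linear layers, not the real transformer blocks:

import torch
import torch.nn as nn

d_model, num_layers = 16, 6
layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
skip_proj = nn.ModuleList([nn.Linear(2 * d_model, d_model) for _ in range(num_layers // 2)])

x = torch.randn(2, 10, d_model)
skips = []
for i, layer in enumerate(layers):
    if i >= num_layers // 2:
        # second half: pop the most recent early activation, concatenate, project back
        x = skip_proj[i % len(skip_proj)](torch.cat([x, skips.pop()], dim=-1))
    x = layer(x)
    if i < num_layers // 2:
        skips.append(x)                  # first half: remember the activation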
audiocraft/utils/extend.py CHANGED
@@ -51,7 +51,7 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
51
  print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
52
  return segments
53
 
54
- def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, progress= gr.Progress(track_tqdm=True)):
55
  # generate audio segments
56
  melody_segments = separate_audio_segments(melody, segment_duration, 0)
57
 
@@ -96,7 +96,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
96
  pbar.update(1)
97
  print(f"melody_segments: {len(melody_segments)} fixed")
98
 
99
- # Iterate over the segments to create list of Meldoy tensors
100
  for segment_idx in range(total_segments):
101
  if INTERRUPTING:
102
  return [], duration
@@ -119,6 +119,10 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
119
  verse = verse[None]
120
  verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
121
 
 
 
 
 
122
  # Append the segment to the melodys list
123
  melodys.append(verse)
124
  pbar.update(1)
@@ -139,10 +143,17 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
139
  top_p=MODEL.generation_params["top_p"],
140
  temperature=MODEL.generation_params["temp"],
141
  cfg_coef=MODEL.generation_params["cfg_coef"],
 
142
  duration=segment_duration,
143
  two_step_cfg=False,
144
- rep_penalty=0.5
145
  )
 
 
 
 
 
 
146
  # Generate a new prompt segment. This will be applied to all segments for consistency
147
  print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
148
  prompt_segment = MODEL.generate_with_all(
@@ -168,10 +179,18 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
168
  top_p=MODEL.generation_params["top_p"],
169
  temperature=MODEL.generation_params["temp"],
170
  cfg_coef=MODEL.generation_params["cfg_coef"],
 
171
  duration=mod_duration,
172
  two_step_cfg=False,
173
- rep_penalty=0.5
174
  )
 
 
 
 
 
 
 
175
  try:
176
  # get last chunk
177
  verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
 
51
  print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
52
  return segments
53
 
54
+ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, excerpt_duration:float=3.5, progress= gr.Progress(track_tqdm=True)):
55
  # generate audio segments
56
  melody_segments = separate_audio_segments(melody, segment_duration, 0)
57
 
 
96
  pbar.update(1)
97
  print(f"melody_segments: {len(melody_segments)} fixed")
98
 
99
+ # Iterate over the segments to create list of Melody tensors
100
  for segment_idx in range(total_segments):
101
  if INTERRUPTING:
102
  return [], duration
 
119
  verse = verse[None]
120
  verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
121
 
122
+ # Reduce the length of verse to sr * excerpt_duration
123
+ if ("style" in MODEL.name):
124
+ verse = verse[:, :, :int(sr * excerpt_duration)]
125
+
126
  # Append the segment to the melodys list
127
  melodys.append(verse)
128
  pbar.update(1)
 
143
  top_p=MODEL.generation_params["top_p"],
144
  temperature=MODEL.generation_params["temp"],
145
  cfg_coef=MODEL.generation_params["cfg_coef"],
146
+ cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
147
  duration=segment_duration,
148
  two_step_cfg=False,
149
+ rep_penalty=0.5,
150
  )
151
+ if ("style" in MODEL.name):
152
+ MODEL.set_style_conditioner_params(
153
+ eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
154
+ excerpt_length=excerpt_duration, # length in seconds of the excerpt the model takes from the provided conditioning; can be between 1.5 and 4.5 seconds but must not exceed the length of the provided conditioning
155
+ )
156
+
157
  # Generate a new prompt segment. This will be applied to all segments for consistency
158
  print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
159
  prompt_segment = MODEL.generate_with_all(
 
179
  top_p=MODEL.generation_params["top_p"],
180
  temperature=MODEL.generation_params["temp"],
181
  cfg_coef=MODEL.generation_params["cfg_coef"],
182
+ cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
183
  duration=mod_duration,
184
  two_step_cfg=False,
185
+ rep_penalty=0.5,
186
  )
187
+
188
+ if ("style" in MODEL.name):
189
+ MODEL.set_style_conditioner_params(
190
+ eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
191
+ excerpt_length=min(excerpt_duration, mod_duration), # the length in seconds that is taken by the model in the provided excerpt, can be between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
192
+ )
193
+
194
  try:
195
  # get last chunk
196
  verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
audiocraft/utils/utils.py CHANGED
@@ -298,3 +298,31 @@ def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
298
  pkg = load_state_dict(path)
299
  pkg.pop('text_branch.embeddings.position_ids', None)
300
  clap_model.model.load_state_dict(pkg)
 
298
  pkg = load_state_dict(path)
299
  pkg.pop('text_branch.embeddings.position_ids', None)
300
  clap_model.model.load_state_dict(pkg)
301
+
302
+ def construct_frame_chords(
303
+ min_timestamp: int,
304
+ chord_changes: tp.List[tp.Tuple[float, str]],
305
+ mapping_dict: tp.Dict,
306
+ prev_chord: str,
307
+ frame_rate: float,
308
+ segment_duration: float,
309
+ ) -> tp.List[int]:
310
+ """ Translate symbolic chords [(start_time, tuples),...] into a frame-level int sequence"""
311
+
312
+ frames = [
313
+ frame / frame_rate
314
+ for frame in range(
315
+ min_timestamp, int(min_timestamp + segment_duration * frame_rate)
316
+ )
317
+ ]
318
+
319
+ frame_chords = []
320
+ current_chord = prev_chord
321
+
322
+ for frame in frames:
323
+ while chord_changes and frame >= chord_changes[0][0]:
324
+ current_chord = chord_changes.pop(0)[1]
325
+ current_chord = 'N' if current_chord in {None, ''} else current_chord
326
+ frame_chords.append(mapping_dict[current_chord])
327
+
328
+ return frame_chords
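A hypothetical call to construct_frame_chords, assuming a toy mapping_dict and a 1-second segment at 5 frames per second; the mapping values below are made up for illustration:

from audiocraft.utils.utils import construct_frame_chords

mapping_dict = {'N': 0, 'C': 1, 'G': 2}           # toy chord-to-index mapping
chord_changes = [(0.2, 'C'), (0.6, 'G')]          # (start_time_sec, chord)

frame_chords = construct_frame_chords(
    min_timestamp=0,
    chord_changes=list(chord_changes),            # note: consumed destructively via pop(0)
    mapping_dict=mapping_dict,
    prev_chord='N',
    frame_rate=5.0,
    segment_duration=1.0,
)
# frames at t = 0.0, 0.2, 0.4, 0.6, 0.8  ->  [0, 1, 1, 2, 2]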
requirements.txt CHANGED
@@ -10,14 +10,16 @@ soundfile
10
  huggingface_hub
11
  hf_xet
12
  tqdm
13
- transformers>=4.48.0 # need Encodec there.
14
  xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
15
  demucs
16
  librosa==0.11.0
17
  soundfile
18
  gradio[oauth]
19
  pillow
 
20
  torchmetrics
 
21
  encodec
22
  protobuf>=3.20.1
23
  filetype
 
10
  huggingface_hub
11
  hf_xet
12
  tqdm
13
+ transformers==4.43.4 # need Encodec there.
14
  xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
15
  demucs
16
  librosa==0.11.0
17
  soundfile
18
  gradio[oauth]
19
  pillow
20
+ torchdiffeq
21
  torchmetrics
22
+ nnAudio
23
  encodec
24
  protobuf>=3.20.1
25
  filetype