Spaces:
Running
on
T4
Running
on
T4
Add STYLE model with upgrades
Browse files- app.py +44 -21
- audiocraft/__init__.py +1 -1
- audiocraft/data/audio.py +124 -4
- audiocraft/models/__init__.py +5 -1
- audiocraft/models/builders.py +202 -109
- audiocraft/models/flow_matching.py +516 -0
- audiocraft/models/genmodel.py +267 -0
- audiocraft/models/lm.py +78 -24
- audiocraft/models/lm_magnet.py +500 -0
- audiocraft/models/loaders.py +49 -1
- audiocraft/models/magnet.py +88 -0
- audiocraft/models/musicgen.py +33 -4
- audiocraft/modules/codebooks_patterns.py +10 -6
- audiocraft/modules/conditioners.py +362 -15
- audiocraft/modules/jasco_conditioners.py +300 -0
- audiocraft/modules/transformer.py +11 -1
- audiocraft/modules/unet_transformer.py +67 -0
- audiocraft/utils/extend.py +23 -4
- audiocraft/utils/utils.py +28 -0
- requirements.txt +3 -1
app.py
CHANGED
@@ -183,7 +183,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model,topp, temperatur
|
|
183 |
|
184 |
return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
|
185 |
|
186 |
-
def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, segment_length = 30, settings_font_size=28, settings_animate_waveform=False, video_orientation="Landscape", progress=gr.Progress(track_tqdm=True)):
|
187 |
global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
|
188 |
output_segments = None
|
189 |
melody_name = "Not Used"
|
@@ -251,24 +251,47 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
|
|
251 |
|
252 |
|
253 |
print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
MODEL.set_custom_progress_callback(gr.Progress(track_tqdm=True))
|
266 |
|
267 |
try:
|
268 |
-
if melody and ("melody" in model):
|
269 |
# return excess duration, load next model and continue in loop structure building up output_segments
|
270 |
if duration > MODEL.duration:
|
271 |
-
output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.duration, prompt_index, harmony_only, progress=gr.Progress(track_tqdm=True))
|
272 |
else:
|
273 |
# pure original code
|
274 |
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
|
@@ -487,11 +510,11 @@ def ui(**kwargs):
|
|
487 |
with gr.Column():
|
488 |
with gr.Row():
|
489 |
with gr.Column():
|
490 |
-
text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps
|
491 |
autoplay_cb = gr.Checkbox(value=False, label="Autoplay?", key="autoplay_cb")
|
492 |
with gr.Column():
|
493 |
duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True, key="total_duration", step=1)
|
494 |
-
model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large"], label="AI Model", value="medium", interactive=True, key="chosen_model")
|
495 |
with gr.Row():
|
496 |
submit = gr.Button("Generate", elem_id="btn-generate")
|
497 |
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
|
@@ -545,7 +568,7 @@ def ui(**kwargs):
|
|
545 |
gr.Examples(
|
546 |
examples=[
|
547 |
[
|
548 |
-
"4/4 120bpm 320kbps
|
549 |
"./assets/bach.mp3",
|
550 |
"melody",
|
551 |
"80s Pop Synth",
|
@@ -554,7 +577,7 @@ def ui(**kwargs):
|
|
554 |
3.5
|
555 |
],
|
556 |
[
|
557 |
-
"4/4 120bpm 320kbps
|
558 |
"./assets/bolero_ravel.mp3",
|
559 |
"stereo-melody-large",
|
560 |
"Country Guitar",
|
@@ -563,7 +586,7 @@ def ui(**kwargs):
|
|
563 |
4.0
|
564 |
],
|
565 |
[
|
566 |
-
"4/4 120bpm 320kbps
|
567 |
None,
|
568 |
"stereo-medium",
|
569 |
"90s Rock Guitar",
|
@@ -572,7 +595,7 @@ def ui(**kwargs):
|
|
572 |
3.75
|
573 |
],
|
574 |
[
|
575 |
-
"4/4 120bpm 320kbps
|
576 |
"./assets/bach.mp3",
|
577 |
"melody-large",
|
578 |
"EDM my Bach",
|
@@ -581,7 +604,7 @@ def ui(**kwargs):
|
|
581 |
3.75
|
582 |
],
|
583 |
[
|
584 |
-
"4/4 320kbps
|
585 |
None,
|
586 |
"medium",
|
587 |
"LoFi Chill",
|
|
|
183 |
|
184 |
return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
|
185 |
|
186 |
+
def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, segment_length = 30, settings_font_size=28, settings_animate_waveform=False, video_orientation="Landscape", excerpt_duration=3.5, progress=gr.Progress(track_tqdm=True)):
|
187 |
global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
|
188 |
output_segments = None
|
189 |
melody_name = "Not Used"
|
|
|
251 |
|
252 |
|
253 |
print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
|
254 |
+
if ("style" in model) and melody:
|
255 |
+
# style and text-to-music
|
256 |
+
MODEL.set_generation_params(
|
257 |
+
use_sampling=True,
|
258 |
+
top_k=topk,
|
259 |
+
top_p=topp,
|
260 |
+
temperature=temperature,
|
261 |
+
cfg_coef=cfg_coef,
|
262 |
+
duration=segment_duration,
|
263 |
+
two_step_cfg=False,
|
264 |
+
cfg_coef_beta=5, # double CFG is only useful for text-and-style conditioning
|
265 |
+
)
|
266 |
+
|
267 |
+
MODEL.set_style_conditioner_params(
|
268 |
+
eval_q=3, # integer between 1 and 6
|
269 |
+
# eval_q is the level of quantization that passes
|
270 |
+
# through the conditioner. When low, the models adheres less to the
|
271 |
+
# audio conditioning
|
272 |
+
excerpt_length=excerpt_duration, # the length in seconds that is taken by the model in the provided excerpt, can be
|
273 |
+
# between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
MODEL.set_generation_params(
|
277 |
+
use_sampling=True,
|
278 |
+
top_k=topk,
|
279 |
+
top_p=topp,
|
280 |
+
temperature=temperature,
|
281 |
+
cfg_coef=cfg_coef,
|
282 |
+
duration=segment_duration,
|
283 |
+
two_step_cfg=False,
|
284 |
+
extend_stride=10,
|
285 |
+
rep_penalty=0.5,
|
286 |
+
cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning
|
287 |
+
)
|
288 |
MODEL.set_custom_progress_callback(gr.Progress(track_tqdm=True))
|
289 |
|
290 |
try:
|
291 |
+
if melody and ("melody" or "style" in model):
|
292 |
# return excess duration, load next model and continue in loop structure building up output_segments
|
293 |
if duration > MODEL.duration:
|
294 |
+
output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.duration, prompt_index, harmony_only, excerpt_duration, progress=gr.Progress(track_tqdm=True))
|
295 |
else:
|
296 |
# pure original code
|
297 |
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
|
|
|
510 |
with gr.Column():
|
511 |
with gr.Row():
|
512 |
with gr.Column():
|
513 |
+
text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 32khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi, soft fade-in, soft fade-out", key="prompt", lines=4)
|
514 |
autoplay_cb = gr.Checkbox(value=False, label="Autoplay?", key="autoplay_cb")
|
515 |
with gr.Column():
|
516 |
duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True, key="total_duration", step=1)
|
517 |
+
model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large", "style"], label="AI Model", value="medium", interactive=True, key="chosen_model")
|
518 |
with gr.Row():
|
519 |
submit = gr.Button("Generate", elem_id="btn-generate")
|
520 |
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
|
|
|
568 |
gr.Examples(
|
569 |
examples=[
|
570 |
[
|
571 |
+
"4/4 120bpm 320kbps 32khz, An 80s driving pop song with heavy drums and synth pads in the background",
|
572 |
"./assets/bach.mp3",
|
573 |
"melody",
|
574 |
"80s Pop Synth",
|
|
|
577 |
3.5
|
578 |
],
|
579 |
[
|
580 |
+
"4/4 120bpm 320kbps 32khz, A cheerful country song with acoustic guitars",
|
581 |
"./assets/bolero_ravel.mp3",
|
582 |
"stereo-melody-large",
|
583 |
"Country Guitar",
|
|
|
586 |
4.0
|
587 |
],
|
588 |
[
|
589 |
+
"4/4 120bpm 320kbps 32khz, 90s rock song with electric guitar and heavy drums",
|
590 |
None,
|
591 |
"stereo-medium",
|
592 |
"90s Rock Guitar",
|
|
|
595 |
3.75
|
596 |
],
|
597 |
[
|
598 |
+
"4/4 120bpm 320kbps 32khz, a light and cheery EDM track, with syncopated drums, aery pads, and strong emotions",
|
599 |
"./assets/bach.mp3",
|
600 |
"melody-large",
|
601 |
"EDM my Bach",
|
|
|
604 |
3.75
|
605 |
],
|
606 |
[
|
607 |
+
"4/4 320kbps 32khz, lofi slow bpm electro chill with organic samples",
|
608 |
None,
|
609 |
"medium",
|
610 |
"LoFi Chill",
|
audiocraft/__init__.py
CHANGED
@@ -7,4 +7,4 @@
|
|
7 |
# flake8: noqa
|
8 |
from . import data, modules, models
|
9 |
|
10 |
-
__version__ = '1.
|
|
|
7 |
# flake8: noqa
|
8 |
from . import data, modules, models
|
9 |
|
10 |
+
__version__ = '1.3.Surn'
|
audiocraft/data/audio.py
CHANGED
@@ -79,7 +79,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
|
|
79 |
seek_time (float): Time at which to start reading in the file.
|
80 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
81 |
Returns:
|
82 |
-
|
83 |
"""
|
84 |
_init_av()
|
85 |
with av.open(str(filepath)) as af:
|
@@ -115,7 +115,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
|
|
115 |
|
116 |
|
117 |
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
118 |
-
duration: float = -1
|
119 |
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
120 |
|
121 |
Args:
|
@@ -124,7 +124,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
|
124 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
125 |
pad (bool): Pad output audio if not reaching expected duration.
|
126 |
Returns:
|
127 |
-
|
128 |
"""
|
129 |
fp = Path(filepath)
|
130 |
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
@@ -299,4 +299,124 @@ def audio_write2(stem_name: tp.Union[str, Path],
|
|
299 |
# we do not want to leave half written files around.
|
300 |
path.unlink()
|
301 |
raise
|
302 |
-
return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
seek_time (float): Time at which to start reading in the file.
|
80 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
81 |
Returns:
|
82 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate
|
83 |
"""
|
84 |
_init_av()
|
85 |
with av.open(str(filepath)) as af:
|
|
|
115 |
|
116 |
|
117 |
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
118 |
+
duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
|
119 |
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
120 |
|
121 |
Args:
|
|
|
124 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
125 |
pad (bool): Pad output audio if not reaching expected duration.
|
126 |
Returns:
|
127 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
128 |
"""
|
129 |
fp = Path(filepath)
|
130 |
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
|
|
299 |
# we do not want to leave half written files around.
|
300 |
path.unlink()
|
301 |
raise
|
302 |
+
return path
|
303 |
+
|
304 |
+
|
305 |
+
def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
|
306 |
+
"""Get the mel-spectrogram from the raw audio.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
y (numpy array): raw input
|
310 |
+
sr (int): Sampling rate
|
311 |
+
n_fft (int): Number of samples per FFT. Default is 2048.
|
312 |
+
hop_length (int): Number of samples between successive frames. Default is 512.
|
313 |
+
dur (float): Maxium duration to get the spectrograms
|
314 |
+
Returns:
|
315 |
+
spectro histogram as a numpy array
|
316 |
+
"""
|
317 |
+
import librosa
|
318 |
+
import librosa.display
|
319 |
+
|
320 |
+
spectrogram = librosa.feature.melspectrogram(
|
321 |
+
y=y, sr=sr, n_fft=n_fft, hop_length=hop_length
|
322 |
+
)
|
323 |
+
spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
|
324 |
+
return spectrogram_db
|
325 |
+
|
326 |
+
|
327 |
+
def save_spectrograms(
|
328 |
+
ys: tp.List[np.ndarray],
|
329 |
+
sr: int,
|
330 |
+
path: str,
|
331 |
+
names: tp.List[str],
|
332 |
+
n_fft: int = 4096,
|
333 |
+
hop_length: int = 128,
|
334 |
+
dur: float = 8.0,
|
335 |
+
):
|
336 |
+
"""Plot a spectrogram for an audio file.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
ys: List of audio spectrograms
|
340 |
+
sr (int): Sampling rate of the audio file. Default is 22050 Hz.
|
341 |
+
path (str): Path to the plot file.
|
342 |
+
names: name of each spectrogram plot
|
343 |
+
n_fft (int): Number of samples per FFT. Default is 2048.
|
344 |
+
hop_length (int): Number of samples between successive frames. Default is 512.
|
345 |
+
dur (float): Maxium duration to plot the spectrograms
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
None (plots the spectrogram using matplotlib)
|
349 |
+
"""
|
350 |
+
import matplotlib as mpl # type: ignore
|
351 |
+
import matplotlib.pyplot as plt # type: ignore
|
352 |
+
import librosa.display
|
353 |
+
|
354 |
+
if not names:
|
355 |
+
names = ["Ground Truth", "Audio Watermarked", "Watermark"]
|
356 |
+
ys = [wav[: int(dur * sr)] for wav in ys] # crop
|
357 |
+
assert len(names) == len(
|
358 |
+
ys
|
359 |
+
), f"There are {len(ys)} wavs but {len(names)} names ({names})"
|
360 |
+
|
361 |
+
# Set matplotlib stuff
|
362 |
+
BIGGER_SIZE = 10
|
363 |
+
SMALLER_SIZE = 8
|
364 |
+
linewidth = 234.8775 # linewidth in pt
|
365 |
+
|
366 |
+
plt.rc("font", size=BIGGER_SIZE, family="serif") # controls default text sizes
|
367 |
+
plt.rcParams["font.family"] = "DeJavu Serif"
|
368 |
+
plt.rcParams["font.serif"] = ["Times New Roman"]
|
369 |
+
|
370 |
+
plt.rc("axes", titlesize=BIGGER_SIZE) # fontsize of the axes title
|
371 |
+
plt.rc("axes", labelsize=BIGGER_SIZE) # fontsize of the x and y labels
|
372 |
+
plt.rc("xtick", labelsize=BIGGER_SIZE) # fontsize of the tick labels
|
373 |
+
plt.rc("ytick", labelsize=SMALLER_SIZE) # fontsize of the tick labels
|
374 |
+
plt.rc("legend", fontsize=BIGGER_SIZE) # legend fontsize
|
375 |
+
plt.rc("figure", titlesize=BIGGER_SIZE)
|
376 |
+
height = 1.6 * linewidth / 72.0
|
377 |
+
fig, ax = plt.subplots(
|
378 |
+
nrows=len(ys),
|
379 |
+
ncols=1,
|
380 |
+
sharex=True,
|
381 |
+
figsize=(linewidth / 72.0, height),
|
382 |
+
)
|
383 |
+
fig.tight_layout()
|
384 |
+
|
385 |
+
# Plot the spectrogram
|
386 |
+
|
387 |
+
for i, ysi in enumerate(ys):
|
388 |
+
spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length)
|
389 |
+
if i == 0:
|
390 |
+
cax = fig.add_axes(
|
391 |
+
[
|
392 |
+
ax[0].get_position().x1 + 0.01, # type: ignore
|
393 |
+
ax[-1].get_position().y0,
|
394 |
+
0.02,
|
395 |
+
ax[0].get_position().y1 - ax[-1].get_position().y0,
|
396 |
+
]
|
397 |
+
)
|
398 |
+
fig.colorbar(
|
399 |
+
mpl.cm.ScalarMappable(
|
400 |
+
norm=mpl.colors.Normalize(
|
401 |
+
np.min(spectrogram_db), np.max(spectrogram_db)
|
402 |
+
),
|
403 |
+
cmap="magma",
|
404 |
+
),
|
405 |
+
ax=ax,
|
406 |
+
orientation="vertical",
|
407 |
+
format="%+2.0f dB",
|
408 |
+
cax=cax,
|
409 |
+
)
|
410 |
+
librosa.display.specshow(
|
411 |
+
spectrogram_db,
|
412 |
+
sr=sr,
|
413 |
+
hop_length=hop_length,
|
414 |
+
x_axis="time",
|
415 |
+
y_axis="mel",
|
416 |
+
ax=ax[i],
|
417 |
+
)
|
418 |
+
ax[i].set(title=names[i])
|
419 |
+
ax[i].yaxis.set_label_text(None)
|
420 |
+
ax[i].label_outer()
|
421 |
+
fig.savefig(path, bbox_inches="tight")
|
422 |
+
plt.close()
|
audiocraft/models/__init__.py
CHANGED
@@ -12,6 +12,10 @@ from . import builders, loaders
|
|
12 |
from .encodec import (
|
13 |
CompressionModel, EncodecModel, DAC,
|
14 |
HFEncodecModel, HFEncodecCompressionModel)
|
15 |
-
from .musicgen import MusicGen
|
16 |
from .lm import LMModel
|
|
|
|
|
17 |
from .encodec import CompressionModel, EncodecModel
|
|
|
|
|
|
|
|
12 |
from .encodec import (
|
13 |
CompressionModel, EncodecModel, DAC,
|
14 |
HFEncodecModel, HFEncodecCompressionModel)
|
|
|
15 |
from .lm import LMModel
|
16 |
+
from .lm_magnet import MagnetLMModel
|
17 |
+
from .flow_matching import FlowMatchingModel
|
18 |
from .encodec import CompressionModel, EncodecModel
|
19 |
+
from .musicgen import MusicGen
|
20 |
+
from .magnet import MAGNeT
|
21 |
+
from .unet import DiffusionUnet
|
audiocraft/models/builders.py
CHANGED
@@ -11,51 +11,53 @@ from the Hydra config.
|
|
11 |
|
12 |
import typing as tp
|
13 |
|
14 |
-
import audiocraft
|
15 |
import omegaconf
|
16 |
import torch
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
from ..modules.codebooks_patterns import (
|
21 |
-
CodebooksPatternProvider,
|
22 |
-
DelayedPatternProvider,
|
23 |
-
MusicLMPattern,
|
24 |
-
ParallelPatternProvider,
|
25 |
-
UnrolledPatternProvider,
|
26 |
-
CoarseFirstPattern,
|
27 |
-
)
|
28 |
-
from ..modules.conditioners import (
|
29 |
-
BaseConditioner,
|
30 |
-
ChromaStemConditioner,
|
31 |
-
CLAPEmbeddingConditioner,
|
32 |
-
ConditionFuser,
|
33 |
-
ConditioningProvider,
|
34 |
-
LUTConditioner,
|
35 |
-
T5Conditioner,
|
36 |
-
)
|
37 |
-
from .unet import DiffusionUnet
|
38 |
from .. import quantization as qt
|
39 |
-
from ..
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
-
def get_quantizer(
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
48 |
kwargs = dict_from_config(getattr(cfg, quantizer))
|
49 |
-
if quantizer !=
|
50 |
-
kwargs[
|
51 |
return klass(**kwargs)
|
52 |
|
53 |
|
54 |
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
55 |
-
if encoder_name ==
|
56 |
-
kwargs = dict_from_config(getattr(cfg,
|
57 |
-
encoder_override_kwargs = kwargs.pop(
|
58 |
-
decoder_override_kwargs = kwargs.pop(
|
59 |
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
60 |
decoder_kwargs = {**kwargs, **decoder_override_kwargs}
|
61 |
encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
|
@@ -67,44 +69,98 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
|
67 |
|
68 |
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
|
69 |
"""Instantiate a compression model."""
|
70 |
-
if cfg.compression_model ==
|
71 |
-
kwargs = dict_from_config(getattr(cfg,
|
72 |
-
encoder_name = kwargs.pop(
|
73 |
-
quantizer_name = kwargs.pop(
|
74 |
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
|
75 |
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
76 |
-
frame_rate = kwargs[
|
77 |
-
renormalize = kwargs.pop(
|
78 |
# deprecated params
|
79 |
-
kwargs.pop(
|
80 |
-
return EncodecModel(
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
else:
|
83 |
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
84 |
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
|
87 |
"""Instantiate a transformer LM."""
|
88 |
-
if cfg.lm_model
|
89 |
-
kwargs = dict_from_config(getattr(cfg,
|
90 |
-
n_q = kwargs[
|
91 |
-
q_modeling = kwargs.pop(
|
92 |
-
codebooks_pattern_cfg = getattr(cfg,
|
93 |
-
attribute_dropout = dict_from_config(getattr(cfg,
|
94 |
-
cls_free_guidance = dict_from_config(getattr(cfg,
|
95 |
-
cfg_prob, cfg_coef =
|
|
|
|
|
|
|
96 |
fuser = get_condition_fuser(cfg)
|
97 |
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
98 |
-
if len(fuser.fuse2cond[
|
99 |
-
kwargs[
|
100 |
if codebooks_pattern_cfg.modeling is None:
|
101 |
-
assert
|
102 |
-
|
|
|
103 |
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
|
104 |
-
{
|
105 |
)
|
|
|
106 |
pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
|
107 |
-
|
|
|
108 |
pattern_provider=pattern_provider,
|
109 |
condition_provider=condition_provider,
|
110 |
fuser=fuser,
|
@@ -113,67 +169,84 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
|
|
113 |
attribute_dropout=attribute_dropout,
|
114 |
dtype=getattr(torch, cfg.dtype),
|
115 |
device=cfg.device,
|
116 |
-
**kwargs
|
117 |
).to(cfg.device)
|
118 |
else:
|
119 |
raise KeyError(f"Unexpected LM model {cfg.lm_model}")
|
120 |
|
121 |
|
122 |
-
def get_conditioner_provider(
|
|
|
|
|
123 |
"""Instantiate a conditioning model."""
|
124 |
device = cfg.device
|
125 |
duration = cfg.dataset.segment_duration
|
126 |
-
cfg = getattr(cfg,
|
127 |
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
128 |
conditioners: tp.Dict[str, BaseConditioner] = {}
|
129 |
-
condition_provider_args = dict_cfg.pop(
|
130 |
-
condition_provider_args.pop(
|
131 |
-
condition_provider_args.pop(
|
132 |
|
133 |
for cond, cond_cfg in dict_cfg.items():
|
134 |
-
model_type = cond_cfg[
|
135 |
model_args = cond_cfg[model_type]
|
136 |
-
if model_type ==
|
137 |
-
conditioners[str(cond)] = T5Conditioner(
|
138 |
-
|
139 |
-
|
140 |
-
elif model_type ==
|
|
|
|
|
|
|
|
|
141 |
conditioners[str(cond)] = ChromaStemConditioner(
|
142 |
-
output_dim=output_dim,
|
143 |
-
duration=duration,
|
144 |
-
device=device,
|
145 |
-
**model_args
|
146 |
)
|
147 |
-
elif model_type
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
conditioners[str(cond)] = CLAPEmbeddingConditioner(
|
|
|
|
|
|
|
|
|
149 |
output_dim=output_dim,
|
150 |
device=device,
|
151 |
**model_args
|
152 |
)
|
153 |
else:
|
154 |
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
155 |
-
conditioner = ConditioningProvider(
|
|
|
|
|
156 |
return conditioner
|
157 |
|
158 |
|
159 |
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
160 |
"""Instantiate a condition fuser object."""
|
161 |
-
fuser_cfg = getattr(cfg,
|
162 |
-
fuser_methods = [
|
163 |
-
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
|
164 |
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
165 |
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
166 |
return fuser
|
167 |
|
168 |
|
169 |
-
def get_codebooks_pattern_provider(
|
|
|
|
|
170 |
"""Instantiate a codebooks pattern provider object."""
|
171 |
pattern_providers = {
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
}
|
178 |
name = cfg.modeling
|
179 |
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
@@ -181,20 +254,23 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
|
|
181 |
return klass(n_q, **kwargs)
|
182 |
|
183 |
|
184 |
-
def get_debug_compression_model(device=
|
185 |
"""Instantiate a debug compression model to be used for unit tests."""
|
186 |
-
assert sample_rate in [
|
|
|
|
|
|
|
187 |
model_ratios = {
|
188 |
16000: [10, 8, 8], # 25 Hz at 16kHz
|
189 |
-
32000: [10, 8, 16] # 25 Hz at 32kHz
|
190 |
}
|
191 |
ratios: tp.List[int] = model_ratios[sample_rate]
|
192 |
frame_rate = 25
|
193 |
seanet_kwargs: dict = {
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
}
|
199 |
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
|
200 |
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
|
@@ -202,8 +278,13 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
|
|
202 |
init_x = torch.randn(8, 32, 128)
|
203 |
quantizer(init_x, 1) # initialize kmeans etc.
|
204 |
compression_model = EncodecModel(
|
205 |
-
encoder,
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
207 |
return compression_model.eval()
|
208 |
|
209 |
|
@@ -211,48 +292,60 @@ def get_diffusion_model(cfg: omegaconf.DictConfig):
|
|
211 |
# TODO Find a way to infer the channels from dset
|
212 |
channels = cfg.channels
|
213 |
num_steps = cfg.schedule.num_steps
|
214 |
-
return DiffusionUnet(
|
215 |
-
chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
216 |
|
217 |
|
218 |
def get_processor(cfg, sample_rate: int = 24000):
|
219 |
sample_processor = SampleProcessor()
|
220 |
if cfg.use:
|
221 |
kw = dict(cfg)
|
222 |
-
kw.pop(
|
223 |
-
kw.pop(
|
224 |
if cfg.name == "multi_band_processor":
|
225 |
sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
|
226 |
return sample_processor
|
227 |
|
228 |
|
229 |
-
def get_debug_lm_model(device=
|
230 |
"""Instantiate a debug LM to be used for unit tests."""
|
231 |
pattern = DelayedPatternProvider(n_q=4)
|
232 |
dim = 16
|
233 |
providers = {
|
234 |
-
|
|
|
|
|
235 |
}
|
236 |
condition_provider = ConditioningProvider(providers)
|
237 |
fuser = ConditionFuser(
|
238 |
-
{
|
239 |
-
|
240 |
lm = LMModel(
|
241 |
-
pattern,
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
return lm.to(device).eval()
|
245 |
|
246 |
|
247 |
def get_wrapped_compression_model(
|
248 |
-
|
249 |
-
|
250 |
-
if hasattr(cfg,
|
251 |
if cfg.interleave_stereo_codebooks.use:
|
252 |
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
|
253 |
-
kwargs.pop(
|
254 |
-
compression_model = InterleaveStereoCompressionModel(
|
255 |
-
|
|
|
|
|
256 |
if cfg.compression_model_n_q is not None:
|
257 |
compression_model.set_num_codebooks(cfg.compression_model_n_q)
|
258 |
return compression_model
|
|
|
11 |
|
12 |
import typing as tp
|
13 |
|
|
|
14 |
import omegaconf
|
15 |
import torch
|
16 |
|
17 |
+
import audiocraft
|
18 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
from .. import quantization as qt
|
20 |
+
from ..modules.codebooks_patterns import (CoarseFirstPattern,
|
21 |
+
CodebooksPatternProvider,
|
22 |
+
DelayedPatternProvider,
|
23 |
+
MusicLMPattern,
|
24 |
+
ParallelPatternProvider,
|
25 |
+
UnrolledPatternProvider)
|
26 |
+
from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner,
|
27 |
+
CLAPEmbeddingConditioner,
|
28 |
+
ConditionFuser, JascoCondConst,
|
29 |
+
ConditioningProvider, LUTConditioner,
|
30 |
+
T5Conditioner, StyleConditioner)
|
31 |
+
from ..modules.jasco_conditioners import (JascoConditioningProvider, ChordsEmbConditioner,
|
32 |
+
DrumsConditioner, MelodyConditioner)
|
33 |
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
|
34 |
+
from ..utils.utils import dict_from_config
|
35 |
+
from .encodec import (CompressionModel, EncodecModel,
|
36 |
+
InterleaveStereoCompressionModel)
|
37 |
+
from .lm import LMModel
|
38 |
+
from .lm_magnet import MagnetLMModel
|
39 |
+
from .flow_matching import FlowMatchingModel
|
40 |
+
from .unet import DiffusionUnet
|
41 |
+
|
42 |
|
43 |
|
44 |
+
def get_quantizer(
|
45 |
+
quantizer: str, cfg: omegaconf.DictConfig, dimension: int
|
46 |
+
) -> qt.BaseQuantizer:
|
47 |
+
klass = {"no_quant": qt.DummyQuantizer, "rvq": qt.ResidualVectorQuantizer}[
|
48 |
+
quantizer
|
49 |
+
]
|
50 |
kwargs = dict_from_config(getattr(cfg, quantizer))
|
51 |
+
if quantizer != "no_quant":
|
52 |
+
kwargs["dimension"] = dimension
|
53 |
return klass(**kwargs)
|
54 |
|
55 |
|
56 |
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
57 |
+
if encoder_name == "seanet":
|
58 |
+
kwargs = dict_from_config(getattr(cfg, "seanet"))
|
59 |
+
encoder_override_kwargs = kwargs.pop("encoder")
|
60 |
+
decoder_override_kwargs = kwargs.pop("decoder")
|
61 |
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
62 |
decoder_kwargs = {**kwargs, **decoder_override_kwargs}
|
63 |
encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
|
|
|
69 |
|
70 |
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
|
71 |
"""Instantiate a compression model."""
|
72 |
+
if cfg.compression_model == "encodec":
|
73 |
+
kwargs = dict_from_config(getattr(cfg, "encodec"))
|
74 |
+
encoder_name = kwargs.pop("autoencoder")
|
75 |
+
quantizer_name = kwargs.pop("quantizer")
|
76 |
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
|
77 |
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
78 |
+
frame_rate = kwargs["sample_rate"] // encoder.hop_length
|
79 |
+
renormalize = kwargs.pop("renormalize", False)
|
80 |
# deprecated params
|
81 |
+
kwargs.pop("renorm", None)
|
82 |
+
return EncodecModel(
|
83 |
+
encoder,
|
84 |
+
decoder,
|
85 |
+
quantizer,
|
86 |
+
frame_rate=frame_rate,
|
87 |
+
renormalize=renormalize,
|
88 |
+
**kwargs,
|
89 |
+
).to(cfg.device)
|
90 |
else:
|
91 |
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
92 |
|
93 |
|
94 |
+
def get_jasco_model(cfg: omegaconf.DictConfig,
|
95 |
+
compression_model: tp.Optional[CompressionModel] = None) -> FlowMatchingModel:
|
96 |
+
kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
|
97 |
+
attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
|
98 |
+
cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
|
99 |
+
cfg_prob = cls_free_guidance["training_dropout"]
|
100 |
+
cfg_coef = cls_free_guidance["inference_coef"]
|
101 |
+
fuser = get_condition_fuser(cfg)
|
102 |
+
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
103 |
+
if JascoCondConst.DRM.value in condition_provider.conditioners: # use self_wav for drums
|
104 |
+
assert compression_model is not None
|
105 |
+
|
106 |
+
# use compression model for drums conditioning
|
107 |
+
condition_provider.conditioners.self_wav.compression_model = compression_model
|
108 |
+
condition_provider.conditioners.self_wav.compression_model.requires_grad_(False)
|
109 |
+
|
110 |
+
# downcast to jasco conditioning provider
|
111 |
+
seq_len = cfg.compression_model_framerate * cfg.dataset.segment_duration
|
112 |
+
chords_card = cfg.conditioners.chords.chords_emb.card if JascoCondConst.CRD.value in cfg.conditioners else -1
|
113 |
+
condition_provider = JascoConditioningProvider(device=condition_provider.device,
|
114 |
+
conditioners=condition_provider.conditioners,
|
115 |
+
chords_card=chords_card,
|
116 |
+
sequence_length=seq_len)
|
117 |
+
|
118 |
+
if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
|
119 |
+
kwargs["cross_attention"] = True
|
120 |
+
|
121 |
+
kwargs.pop("n_q", None)
|
122 |
+
kwargs.pop("card", None)
|
123 |
+
|
124 |
+
return FlowMatchingModel(
|
125 |
+
condition_provider=condition_provider,
|
126 |
+
fuser=fuser,
|
127 |
+
cfg_dropout=cfg_prob,
|
128 |
+
cfg_coef=cfg_coef,
|
129 |
+
attribute_dropout=attribute_dropout,
|
130 |
+
dtype=getattr(torch, cfg.dtype),
|
131 |
+
device=cfg.device,
|
132 |
+
**kwargs,
|
133 |
+
).to(cfg.device)
|
134 |
+
|
135 |
+
|
136 |
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
|
137 |
"""Instantiate a transformer LM."""
|
138 |
+
if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]:
|
139 |
+
kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
|
140 |
+
n_q = kwargs["n_q"]
|
141 |
+
q_modeling = kwargs.pop("q_modeling", None)
|
142 |
+
codebooks_pattern_cfg = getattr(cfg, "codebooks_pattern")
|
143 |
+
attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
|
144 |
+
cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
|
145 |
+
cfg_prob, cfg_coef = (
|
146 |
+
cls_free_guidance["training_dropout"],
|
147 |
+
cls_free_guidance["inference_coef"],
|
148 |
+
)
|
149 |
fuser = get_condition_fuser(cfg)
|
150 |
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
151 |
+
if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
|
152 |
+
kwargs["cross_attention"] = True
|
153 |
if codebooks_pattern_cfg.modeling is None:
|
154 |
+
assert (
|
155 |
+
q_modeling is not None
|
156 |
+
), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
|
157 |
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
|
158 |
+
{"modeling": q_modeling, "delay": {"delays": list(range(n_q))}}
|
159 |
)
|
160 |
+
|
161 |
pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
|
162 |
+
lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel
|
163 |
+
return lm_class(
|
164 |
pattern_provider=pattern_provider,
|
165 |
condition_provider=condition_provider,
|
166 |
fuser=fuser,
|
|
|
169 |
attribute_dropout=attribute_dropout,
|
170 |
dtype=getattr(torch, cfg.dtype),
|
171 |
device=cfg.device,
|
172 |
+
**kwargs,
|
173 |
).to(cfg.device)
|
174 |
else:
|
175 |
raise KeyError(f"Unexpected LM model {cfg.lm_model}")
|
176 |
|
177 |
|
178 |
+
def get_conditioner_provider(
|
179 |
+
output_dim: int, cfg: omegaconf.DictConfig
|
180 |
+
) -> ConditioningProvider:
|
181 |
"""Instantiate a conditioning model."""
|
182 |
device = cfg.device
|
183 |
duration = cfg.dataset.segment_duration
|
184 |
+
cfg = getattr(cfg, "conditioners")
|
185 |
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
186 |
conditioners: tp.Dict[str, BaseConditioner] = {}
|
187 |
+
condition_provider_args = dict_cfg.pop("args", {})
|
188 |
+
condition_provider_args.pop("merge_text_conditions_p", None)
|
189 |
+
condition_provider_args.pop("drop_desc_p", None)
|
190 |
|
191 |
for cond, cond_cfg in dict_cfg.items():
|
192 |
+
model_type = cond_cfg["model"]
|
193 |
model_args = cond_cfg[model_type]
|
194 |
+
if model_type == "t5":
|
195 |
+
conditioners[str(cond)] = T5Conditioner(
|
196 |
+
output_dim=output_dim, device=device, **model_args
|
197 |
+
)
|
198 |
+
elif model_type == "lut":
|
199 |
+
conditioners[str(cond)] = LUTConditioner(
|
200 |
+
output_dim=output_dim, **model_args
|
201 |
+
)
|
202 |
+
elif model_type == "chroma_stem":
|
203 |
conditioners[str(cond)] = ChromaStemConditioner(
|
204 |
+
output_dim=output_dim, duration=duration, device=device, **model_args
|
|
|
|
|
|
|
205 |
)
|
206 |
+
elif model_type in {"chords_emb", "drum_latents", "melody"}:
|
207 |
+
conditioners_classes = {"chords_emb": ChordsEmbConditioner,
|
208 |
+
"drum_latents": DrumsConditioner,
|
209 |
+
"melody": MelodyConditioner}
|
210 |
+
conditioner_class = conditioners_classes[model_type]
|
211 |
+
conditioners[str(cond)] = conditioner_class(device=device, **model_args)
|
212 |
+
elif model_type == "clap":
|
213 |
conditioners[str(cond)] = CLAPEmbeddingConditioner(
|
214 |
+
output_dim=output_dim, device=device, **model_args
|
215 |
+
)
|
216 |
+
elif model_type == 'style':
|
217 |
+
conditioners[str(cond)] = StyleConditioner(
|
218 |
output_dim=output_dim,
|
219 |
device=device,
|
220 |
**model_args
|
221 |
)
|
222 |
else:
|
223 |
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
224 |
+
conditioner = ConditioningProvider(
|
225 |
+
conditioners, device=device, **condition_provider_args
|
226 |
+
)
|
227 |
return conditioner
|
228 |
|
229 |
|
230 |
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
231 |
"""Instantiate a condition fuser object."""
|
232 |
+
fuser_cfg = getattr(cfg, "fuser")
|
233 |
+
fuser_methods = ["sum", "cross", "prepend", "ignore", "input_interpolate"]
|
234 |
+
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods if k in fuser_cfg}
|
235 |
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
236 |
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
237 |
return fuser
|
238 |
|
239 |
|
240 |
+
def get_codebooks_pattern_provider(
|
241 |
+
n_q: int, cfg: omegaconf.DictConfig
|
242 |
+
) -> CodebooksPatternProvider:
|
243 |
"""Instantiate a codebooks pattern provider object."""
|
244 |
pattern_providers = {
|
245 |
+
"parallel": ParallelPatternProvider,
|
246 |
+
"delay": DelayedPatternProvider,
|
247 |
+
"unroll": UnrolledPatternProvider,
|
248 |
+
"coarse_first": CoarseFirstPattern,
|
249 |
+
"musiclm": MusicLMPattern,
|
250 |
}
|
251 |
name = cfg.modeling
|
252 |
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
|
|
254 |
return klass(n_q, **kwargs)
|
255 |
|
256 |
|
257 |
+
def get_debug_compression_model(device="cpu", sample_rate: int = 32000):
|
258 |
"""Instantiate a debug compression model to be used for unit tests."""
|
259 |
+
assert sample_rate in [
|
260 |
+
16000,
|
261 |
+
32000,
|
262 |
+
], "unsupported sample rate for debug compression model"
|
263 |
model_ratios = {
|
264 |
16000: [10, 8, 8], # 25 Hz at 16kHz
|
265 |
+
32000: [10, 8, 16], # 25 Hz at 32kHz
|
266 |
}
|
267 |
ratios: tp.List[int] = model_ratios[sample_rate]
|
268 |
frame_rate = 25
|
269 |
seanet_kwargs: dict = {
|
270 |
+
"n_filters": 4,
|
271 |
+
"n_residual_layers": 1,
|
272 |
+
"dimension": 32,
|
273 |
+
"ratios": ratios,
|
274 |
}
|
275 |
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
|
276 |
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
|
|
|
278 |
init_x = torch.randn(8, 32, 128)
|
279 |
quantizer(init_x, 1) # initialize kmeans etc.
|
280 |
compression_model = EncodecModel(
|
281 |
+
encoder,
|
282 |
+
decoder,
|
283 |
+
quantizer,
|
284 |
+
frame_rate=frame_rate,
|
285 |
+
sample_rate=sample_rate,
|
286 |
+
channels=1,
|
287 |
+
).to(device)
|
288 |
return compression_model.eval()
|
289 |
|
290 |
|
|
|
292 |
# TODO Find a way to infer the channels from dset
|
293 |
channels = cfg.channels
|
294 |
num_steps = cfg.schedule.num_steps
|
295 |
+
return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
|
|
296 |
|
297 |
|
298 |
def get_processor(cfg, sample_rate: int = 24000):
|
299 |
sample_processor = SampleProcessor()
|
300 |
if cfg.use:
|
301 |
kw = dict(cfg)
|
302 |
+
kw.pop("use")
|
303 |
+
kw.pop("name")
|
304 |
if cfg.name == "multi_band_processor":
|
305 |
sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
|
306 |
return sample_processor
|
307 |
|
308 |
|
309 |
+
def get_debug_lm_model(device="cpu"):
|
310 |
"""Instantiate a debug LM to be used for unit tests."""
|
311 |
pattern = DelayedPatternProvider(n_q=4)
|
312 |
dim = 16
|
313 |
providers = {
|
314 |
+
"description": LUTConditioner(
|
315 |
+
n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"
|
316 |
+
),
|
317 |
}
|
318 |
condition_provider = ConditioningProvider(providers)
|
319 |
fuser = ConditionFuser(
|
320 |
+
{"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []}
|
321 |
+
)
|
322 |
lm = LMModel(
|
323 |
+
pattern,
|
324 |
+
condition_provider,
|
325 |
+
fuser,
|
326 |
+
n_q=4,
|
327 |
+
card=400,
|
328 |
+
dim=dim,
|
329 |
+
num_heads=4,
|
330 |
+
custom=True,
|
331 |
+
num_layers=2,
|
332 |
+
cross_attention=True,
|
333 |
+
causal=True,
|
334 |
+
)
|
335 |
return lm.to(device).eval()
|
336 |
|
337 |
|
338 |
def get_wrapped_compression_model(
|
339 |
+
compression_model: CompressionModel, cfg: omegaconf.DictConfig
|
340 |
+
) -> CompressionModel:
|
341 |
+
if hasattr(cfg, "interleave_stereo_codebooks"):
|
342 |
if cfg.interleave_stereo_codebooks.use:
|
343 |
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
|
344 |
+
kwargs.pop("use")
|
345 |
+
compression_model = InterleaveStereoCompressionModel(
|
346 |
+
compression_model, **kwargs
|
347 |
+
)
|
348 |
+
if hasattr(cfg, "compression_model_n_q"):
|
349 |
if cfg.compression_model_n_q is not None:
|
350 |
compression_model.set_num_codebooks(cfg.compression_model_n_q)
|
351 |
return compression_model
|
audiocraft/models/flow_matching.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import typing as tp
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torchdiffeq import odeint # type: ignore
|
15 |
+
from ..modules.streaming import StreamingModule
|
16 |
+
from ..modules.transformer import create_norm_fn, StreamingTransformerLayer
|
17 |
+
from ..modules.unet_transformer import UnetTransformer
|
18 |
+
from ..modules.conditioners import (
|
19 |
+
ConditionFuser,
|
20 |
+
ClassifierFreeGuidanceDropout,
|
21 |
+
AttributeDropout,
|
22 |
+
ConditioningAttributes,
|
23 |
+
JascoCondConst
|
24 |
+
)
|
25 |
+
from ..modules.jasco_conditioners import JascoConditioningProvider
|
26 |
+
from ..modules.activations import get_activation_fn
|
27 |
+
|
28 |
+
from .lm import ConditionTensors, init_layer
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class FMOutput:
|
36 |
+
latents: torch.Tensor # [B, T, D]
|
37 |
+
mask: torch.Tensor # [B, T]
|
38 |
+
|
39 |
+
|
40 |
+
class CFGTerm:
|
41 |
+
"""
|
42 |
+
Base class for Multi Source Classifier-Free Guidance (CFG) terms. This class represents a term in the CFG process,
|
43 |
+
which is used to guide the generation process by adjusting the influence of different conditions.
|
44 |
+
Attributes:
|
45 |
+
conditions (dict): A dictionary of conditions that influence the generation process.
|
46 |
+
weight (float): The weight of the CFG term, determining its influence on the generation.
|
47 |
+
"""
|
48 |
+
def __init__(self, conditions, weight):
|
49 |
+
self.conditions = conditions
|
50 |
+
self.weight = weight
|
51 |
+
|
52 |
+
def drop_irrelevant_conds(self, conditions):
|
53 |
+
"""
|
54 |
+
Drops irrelevant conditions from the CFG term. This method should be implemented by subclasses.
|
55 |
+
Args:
|
56 |
+
conditions (dict): The conditions to be filtered.
|
57 |
+
Raises:
|
58 |
+
NotImplementedError: If the method is not implemented in a subclass.
|
59 |
+
"""
|
60 |
+
raise NotImplementedError("No base implementation for setting generation params.")
|
61 |
+
|
62 |
+
|
63 |
+
class AllCFGTerm(CFGTerm):
|
64 |
+
"""
|
65 |
+
A CFG term that retains all conditions. This class does not drop any condition.
|
66 |
+
"""
|
67 |
+
def __init__(self, conditions, weight):
|
68 |
+
super().__init__(conditions, weight)
|
69 |
+
self.drop_irrelevant_conds()
|
70 |
+
|
71 |
+
def drop_irrelevant_conds(self):
|
72 |
+
pass
|
73 |
+
|
74 |
+
|
75 |
+
class NullCFGTerm(CFGTerm):
|
76 |
+
"""
|
77 |
+
A CFG term that drops all conditions, effectively nullifying their influence.
|
78 |
+
"""
|
79 |
+
def __init__(self, conditions, weight):
|
80 |
+
super().__init__(conditions, weight)
|
81 |
+
self.drop_irrelevant_conds()
|
82 |
+
|
83 |
+
def drop_irrelevant_conds(self):
|
84 |
+
"""
|
85 |
+
Drops all conditions by applying a dropout with probability 1.0, effectively nullifying their influence.
|
86 |
+
"""
|
87 |
+
self.conditions = ClassifierFreeGuidanceDropout(p=1.0)(
|
88 |
+
samples=self.conditions,
|
89 |
+
cond_types=["wav", "text", "symbolic"])
|
90 |
+
|
91 |
+
|
92 |
+
class TextCFGTerm(CFGTerm):
|
93 |
+
"""
|
94 |
+
A CFG term that selectively drops conditions based on specified dropout probabilities for different types
|
95 |
+
of conditions, such as 'symbolic' and 'wav'.
|
96 |
+
"""
|
97 |
+
def __init__(self, conditions, weight, model_att_dropout):
|
98 |
+
"""
|
99 |
+
Initializes a TextCFGTerm with specified conditions, weight, and model attention dropout configuration.
|
100 |
+
Args:
|
101 |
+
conditions (dict): The conditions to be used in the CFG process.
|
102 |
+
weight (float): The weight of the CFG term.
|
103 |
+
model_att_dropout (object): The attribute dropouts used by the model.
|
104 |
+
"""
|
105 |
+
super().__init__(conditions, weight)
|
106 |
+
if 'symbolic' in model_att_dropout.p:
|
107 |
+
self.drop_symbolics = {k: 1.0 for k in model_att_dropout.p['symbolic'].keys()}
|
108 |
+
else:
|
109 |
+
self.drop_symbolics = {}
|
110 |
+
if 'wav' in model_att_dropout.p:
|
111 |
+
self.drop_wav = {k: 1.0 for k in model_att_dropout.p['wav'].keys()}
|
112 |
+
else:
|
113 |
+
self.drop_wav = {}
|
114 |
+
self.drop_irrelevant_conds()
|
115 |
+
|
116 |
+
def drop_irrelevant_conds(self):
|
117 |
+
self.conditions = AttributeDropout({'symbolic': self.drop_symbolics,
|
118 |
+
'wav': self.drop_wav})(self.conditions) # drop temporal conds
|
119 |
+
|
120 |
+
|
121 |
+
class FlowMatchingModel(StreamingModule):
|
122 |
+
"""
|
123 |
+
A flow matching model inherits from StreamingModule.
|
124 |
+
This model uses a transformer architecture to process and fuse conditions, applying learned embeddings and
|
125 |
+
transformations and predicts multi-source guided vector fields.
|
126 |
+
Attributes:
|
127 |
+
condition_provider (JascoConditioningProvider): Provider for conditioning attributes.
|
128 |
+
fuser (ConditionFuser): Fuser for combining multiple conditions.
|
129 |
+
dim (int): Dimensionality of the model's main features.
|
130 |
+
num_heads (int): Number of attention heads in the transformer.
|
131 |
+
flow_dim (int): Dimensionality of the flow features.
|
132 |
+
chords_dim (int): Dimensionality for chord embeddings, if used.
|
133 |
+
drums_dim (int): Dimensionality for drums embeddings, if used.
|
134 |
+
melody_dim (int): Dimensionality for melody embeddings, if used.
|
135 |
+
hidden_scale (int): Scaling factor for the dimensionality of the feedforward network in the transformer.
|
136 |
+
norm (str): Type of normalization to use ('layer_norm' or other supported types).
|
137 |
+
norm_first (bool): Whether to apply normalization before other operations in the transformer layers.
|
138 |
+
bias_proj (bool): Whether to include bias in the projection layers.
|
139 |
+
weight_init (Optional[str]): Method for initializing weights.
|
140 |
+
depthwise_init (Optional[str]): Method for initializing depthwise convolutional layers.
|
141 |
+
zero_bias_init (bool): Whether to initialize biases to zero.
|
142 |
+
cfg_dropout (float): Dropout rate for configuration settings.
|
143 |
+
cfg_coef (float): Coefficient for configuration influence.
|
144 |
+
attribute_dropout (Dict[str, Dict[str, float]]): Dropout rates for specific attributes.
|
145 |
+
time_embedding_dim (int): Dimensionality of time embeddings.
|
146 |
+
**kwargs: Additional keyword arguments for the transformer.
|
147 |
+
Methods:
|
148 |
+
__init__: Initializes the model with the specified attributes and configuration.
|
149 |
+
"""
|
150 |
+
def __init__(self, condition_provider: JascoConditioningProvider,
|
151 |
+
fuser: ConditionFuser,
|
152 |
+
dim: int = 128,
|
153 |
+
num_heads: int = 8,
|
154 |
+
flow_dim: int = 128,
|
155 |
+
chords_dim: int = 0,
|
156 |
+
drums_dim: int = 0,
|
157 |
+
melody_dim: int = 0,
|
158 |
+
hidden_scale: int = 4,
|
159 |
+
norm: str = 'layer_norm',
|
160 |
+
norm_first: bool = False,
|
161 |
+
bias_proj: bool = True,
|
162 |
+
weight_init: tp.Optional[str] = None,
|
163 |
+
depthwise_init: tp.Optional[str] = None,
|
164 |
+
zero_bias_init: bool = False,
|
165 |
+
cfg_dropout: float = 0,
|
166 |
+
cfg_coef: float = 1.0,
|
167 |
+
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
|
168 |
+
time_embedding_dim: int = 128,
|
169 |
+
**kwargs):
|
170 |
+
super().__init__()
|
171 |
+
self.cfg_coef = cfg_coef
|
172 |
+
|
173 |
+
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
174 |
+
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
175 |
+
self.condition_provider = condition_provider
|
176 |
+
self.fuser = fuser
|
177 |
+
self.dim = dim # transformer dim
|
178 |
+
self.flow_dim = flow_dim
|
179 |
+
self.chords_dim = chords_dim
|
180 |
+
self.emb = nn.Linear(flow_dim + chords_dim + drums_dim + melody_dim, dim, bias=False)
|
181 |
+
if 'activation' in kwargs:
|
182 |
+
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
183 |
+
|
184 |
+
self.transformer = UnetTransformer(
|
185 |
+
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
186 |
+
norm=norm, norm_first=norm_first,
|
187 |
+
layer_class=StreamingTransformerLayer,
|
188 |
+
**kwargs)
|
189 |
+
self.out_norm: tp.Optional[nn.Module] = None
|
190 |
+
if norm_first:
|
191 |
+
self.out_norm = create_norm_fn(norm, dim)
|
192 |
+
self.linear = nn.Linear(dim, flow_dim, bias=bias_proj)
|
193 |
+
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
194 |
+
self._fsdp: tp.Optional[nn.Module]
|
195 |
+
self.__dict__['_fsdp'] = None
|
196 |
+
|
197 |
+
# init time parameter embedding
|
198 |
+
self.d_temb1 = time_embedding_dim
|
199 |
+
self.d_temb2 = 4 * time_embedding_dim
|
200 |
+
self.temb = nn.Module()
|
201 |
+
self.temb.dense = nn.ModuleList([
|
202 |
+
torch.nn.Linear(self.d_temb1,
|
203 |
+
self.d_temb2),
|
204 |
+
torch.nn.Linear(self.d_temb2,
|
205 |
+
self.d_temb2),
|
206 |
+
])
|
207 |
+
self.temb_proj = nn.Linear(self.d_temb2, dim)
|
208 |
+
|
209 |
+
def _get_timestep_embedding(self, timesteps, embedding_dim):
|
210 |
+
"""
|
211 |
+
#######################################################################################################
|
212 |
+
TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
|
213 |
+
#######################################################################################################
|
214 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
215 |
+
From Fairseq.
|
216 |
+
Build sinusoidal embeddings.
|
217 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
218 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
219 |
+
"""
|
220 |
+
assert len(timesteps.shape) == 1
|
221 |
+
|
222 |
+
half_dim = embedding_dim // 2
|
223 |
+
emb = math.log(10000) / (half_dim - 1)
|
224 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
225 |
+
emb = emb.to(device=timesteps.device)
|
226 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
227 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
228 |
+
if embedding_dim % 2 == 1: # zero pad
|
229 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
230 |
+
return emb
|
231 |
+
|
232 |
+
def _embed_time_parameter(self, t: torch.Tensor):
|
233 |
+
"""
|
234 |
+
#######################################################################################################
|
235 |
+
TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
|
236 |
+
#######################################################################################################
|
237 |
+
"""
|
238 |
+
temb = self._get_timestep_embedding(t.flatten(), self.d_temb1)
|
239 |
+
temb = self.temb.dense[0](temb)
|
240 |
+
temb = temb * torch.sigmoid(temb) # swish activation
|
241 |
+
temb = self.temb.dense[1](temb)
|
242 |
+
return temb
|
243 |
+
|
244 |
+
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
245 |
+
"""Initialization of the transformer module weights.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
249 |
+
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
250 |
+
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
251 |
+
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
252 |
+
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
253 |
+
"""
|
254 |
+
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
255 |
+
assert depthwise_init is None or weight_init is not None, \
|
256 |
+
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
257 |
+
assert not zero_bias_init or weight_init is not None, \
|
258 |
+
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
259 |
+
|
260 |
+
if weight_init is None:
|
261 |
+
return
|
262 |
+
|
263 |
+
init_layer(self.emb, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
264 |
+
|
265 |
+
for layer_idx, tr_layer in enumerate(self.transformer.layers):
|
266 |
+
depth = None
|
267 |
+
if depthwise_init == 'current':
|
268 |
+
depth = layer_idx + 1
|
269 |
+
elif depthwise_init == 'global':
|
270 |
+
depth = len(self.transformer.layers)
|
271 |
+
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
|
272 |
+
tr_layer.apply(init_fn)
|
273 |
+
|
274 |
+
init_layer(self.linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
275 |
+
|
276 |
+
def _align_seq_length(self,
|
277 |
+
cond: torch.Tensor,
|
278 |
+
seq_len: int = 500):
|
279 |
+
# trim if needed
|
280 |
+
cond = cond[:, :seq_len, :]
|
281 |
+
|
282 |
+
# pad if needed
|
283 |
+
B, T, C = cond.shape
|
284 |
+
if T < seq_len:
|
285 |
+
cond = torch.cat((cond, torch.zeros((B, seq_len - T, C), dtype=cond.dtype, device=cond.device)), dim=1)
|
286 |
+
|
287 |
+
return cond
|
288 |
+
|
289 |
+
def forward(self,
|
290 |
+
latents: torch.Tensor,
|
291 |
+
t: torch.Tensor,
|
292 |
+
conditions: tp.List[ConditioningAttributes],
|
293 |
+
condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
|
294 |
+
"""Apply flow matching forward pass on latents and conditions.
|
295 |
+
Given a tensor of noisy latents of shape [B, T, D] with D the flow dim and T the sequence steps,
|
296 |
+
and a time parameter tensor t, return the vector field with shape [B, T, D].
|
297 |
+
|
298 |
+
Args:
|
299 |
+
latents (torch.Tensor): noisy latents.
|
300 |
+
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
301 |
+
the given codes. Note that when evaluating multiple time with the same conditioning
|
302 |
+
you should pre-compute those and pass them as `condition_tensors`.
|
303 |
+
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
304 |
+
tensors, see `conditions`.
|
305 |
+
Returns:
|
306 |
+
torch.Tensor: estimated vector field v_theta.
|
307 |
+
"""
|
308 |
+
assert condition_tensors is not None, "FlowMatchingModel require pre-calculation of condition tensors"
|
309 |
+
assert not conditions, "Shouldn't pass unprocessed conditions to FlowMatchingModel."
|
310 |
+
|
311 |
+
B, T, D = latents.shape
|
312 |
+
x = latents
|
313 |
+
|
314 |
+
# concat temporal conditions on the feature dimension
|
315 |
+
temporal_conds = JascoCondConst.ALL.value
|
316 |
+
for cond in temporal_conds:
|
317 |
+
if cond not in condition_tensors:
|
318 |
+
continue
|
319 |
+
c = self._align_seq_length(condition_tensors[cond][0], seq_len=T)
|
320 |
+
x = torch.concat((x, c), dim=-1)
|
321 |
+
|
322 |
+
# project to transformer dimension
|
323 |
+
input_ = self.emb(x)
|
324 |
+
|
325 |
+
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
326 |
+
|
327 |
+
# embed time parameter
|
328 |
+
t_embs = self._embed_time_parameter(t)
|
329 |
+
|
330 |
+
# add it to cross_attention_input
|
331 |
+
cross_attention_input = cross_attention_input + self.temb_proj(t_embs[:, None, :])
|
332 |
+
|
333 |
+
out = self.transformer(input_, cross_attention_src=cross_attention_input)
|
334 |
+
|
335 |
+
if self.out_norm:
|
336 |
+
out = self.out_norm(out)
|
337 |
+
v_theta = self.linear(out) # [B, T, D]
|
338 |
+
|
339 |
+
# remove the prefix from the model outputs
|
340 |
+
if len(self.fuser.fuse2cond['prepend']) > 0:
|
341 |
+
v_theta = v_theta[:, :, -T:]
|
342 |
+
|
343 |
+
return v_theta # [B, T, D]
|
344 |
+
|
345 |
+
def _multi_source_cfg_preprocess(self,
|
346 |
+
conditions: tp.List[ConditioningAttributes],
|
347 |
+
cfg_coef_all: float,
|
348 |
+
cfg_coef_txt: float,
|
349 |
+
min_weight: float = 1e-6):
|
350 |
+
"""
|
351 |
+
Preprocesses the CFG terms for multi-source conditional generation.
|
352 |
+
Args:
|
353 |
+
conditions (list): A list of conditions to be applied.
|
354 |
+
cfg_coef_all (float): The coefficient for all conditions.
|
355 |
+
cfg_coef_txt (float): The coefficient for text conditions.
|
356 |
+
min_weight (float): The minimal absolute weight for calculating a CFG term.
|
357 |
+
Returns:
|
358 |
+
tuple: A tuple containing condition_tensors and cfg_terms.
|
359 |
+
condition_tensors is a dictionary or ConditionTensors object with tokenized conditions.
|
360 |
+
cfg_terms is a list of CFGTerm objects with weights adjusted based on the coefficients.
|
361 |
+
"""
|
362 |
+
condition_tensors: tp.Optional[ConditionTensors]
|
363 |
+
cfg_terms = []
|
364 |
+
if conditions:
|
365 |
+
# conditional terms
|
366 |
+
cfg_terms = [AllCFGTerm(conditions=conditions, weight=cfg_coef_all),
|
367 |
+
TextCFGTerm(conditions=conditions, weight=cfg_coef_txt,
|
368 |
+
model_att_dropout=self.att_dropout)]
|
369 |
+
|
370 |
+
# add null term
|
371 |
+
cfg_terms.append(NullCFGTerm(conditions=conditions, weight=1 - sum([ct.weight for ct in cfg_terms])))
|
372 |
+
|
373 |
+
# remove terms with negligible weight
|
374 |
+
for ct in cfg_terms:
|
375 |
+
if abs(ct.weight) < min_weight:
|
376 |
+
cfg_terms.remove(ct)
|
377 |
+
|
378 |
+
conds: tp.List[ConditioningAttributes] = sum([ct.conditions for ct in cfg_terms], [])
|
379 |
+
tokenized = self.condition_provider.tokenize(conds)
|
380 |
+
condition_tensors = self.condition_provider(tokenized)
|
381 |
+
else:
|
382 |
+
condition_tensors = {}
|
383 |
+
|
384 |
+
return condition_tensors, cfg_terms
|
385 |
+
|
386 |
+
def estimated_vector_field(self, z, t, condition_tensors=None, cfg_terms=[]):
|
387 |
+
"""
|
388 |
+
Estimates the vector field for the given latent variables and time parameter,
|
389 |
+
conditioned on the provided conditions.
|
390 |
+
Args:
|
391 |
+
z (Tensor): The latent variables.
|
392 |
+
t (float): The time variable.
|
393 |
+
condition_tensors (ConditionTensors, optional): The condition tensors. Defaults to None.
|
394 |
+
cfg_terms (list, optional): The list of CFG terms. Defaults to an empty list.
|
395 |
+
Returns:
|
396 |
+
Tensor: The estimated vector field.
|
397 |
+
"""
|
398 |
+
if len(cfg_terms) > 1:
|
399 |
+
z = z.repeat(len(cfg_terms), 1, 1) # duplicate noisy latents for multi-source CFG
|
400 |
+
v_thetas = self(latents=z, t=t, conditions=[], condition_tensors=condition_tensors)
|
401 |
+
return self._multi_source_cfg_postprocess(v_thetas, cfg_terms)
|
402 |
+
|
403 |
+
def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms):
|
404 |
+
"""
|
405 |
+
Postprocesses the vector fields generated for each CFG term to combine them into a single vector field.
|
406 |
+
Multi source guidance occurs here.
|
407 |
+
Args:
|
408 |
+
v_thetas (Tensor): The vector fields for each CFG term.
|
409 |
+
cfg_terms (list): The CFG terms used.
|
410 |
+
Returns:
|
411 |
+
Tensor: The combined vector field.
|
412 |
+
"""
|
413 |
+
if len(cfg_terms) <= 1:
|
414 |
+
return v_thetas
|
415 |
+
v_theta_per_term = v_thetas.chunk(len(cfg_terms))
|
416 |
+
return sum([ct.weight * term_vf for ct, term_vf in zip(cfg_terms, v_theta_per_term)])
|
417 |
+
|
418 |
+
@torch.no_grad()
|
419 |
+
def generate(self,
|
420 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
421 |
+
conditions: tp.List[ConditioningAttributes] = [],
|
422 |
+
num_samples: tp.Optional[int] = None,
|
423 |
+
max_gen_len: int = 256,
|
424 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
425 |
+
cfg_coef_all: float = 3.0,
|
426 |
+
cfg_coef_txt: float = 1.0,
|
427 |
+
euler: bool = False,
|
428 |
+
euler_steps: int = 100,
|
429 |
+
ode_rtol: float = 1e-5,
|
430 |
+
ode_atol: float = 1e-5,
|
431 |
+
) -> torch.Tensor:
|
432 |
+
"""
|
433 |
+
Generate audio latents given a prompt or unconditionally. This method supports both Euler integration
|
434 |
+
and adaptive ODE solving to generate sequences based on the specified conditions and configuration coefficients.
|
435 |
+
|
436 |
+
Args:
|
437 |
+
prompt (torch.Tensor, optional): Initial prompt to condition the generation. defaults to None
|
438 |
+
conditions (List[ConditioningAttributes]): List of conditioning attributes - text, symbolic or audio.
|
439 |
+
num_samples (int, optional): Number of samples to generate.
|
440 |
+
If None, it is inferred from the number of conditions.
|
441 |
+
max_gen_len (int): Maximum length of the generated sequence.
|
442 |
+
callback (Callable[[int, int], None], optional): Callback function to monitor the generation process.
|
443 |
+
cfg_coef_all (float): Coefficient for the fully conditional CFG term.
|
444 |
+
cfg_coef_txt (float): Coefficient for text CFG term.
|
445 |
+
euler (bool): If True, use Euler integration, otherwise use adaptive ODE solver.
|
446 |
+
euler_steps (int): Number of Euler steps to perform if Euler integration is used.
|
447 |
+
ode_rtol (float): ODE solver rtol threshold.
|
448 |
+
ode_atol (float): ODE solver atol threshold.
|
449 |
+
|
450 |
+
Returns:
|
451 |
+
torch.Tensor: Generated latents, shaped as (num_samples, max_gen_len, feature_dim).
|
452 |
+
"""
|
453 |
+
|
454 |
+
assert not self.training, "generation shouldn't be used in training mode."
|
455 |
+
first_param = next(iter(self.parameters()))
|
456 |
+
device = first_param.device
|
457 |
+
|
458 |
+
# Checking all input shapes are consistent.
|
459 |
+
possible_num_samples = []
|
460 |
+
if num_samples is not None:
|
461 |
+
possible_num_samples.append(num_samples)
|
462 |
+
elif prompt is not None:
|
463 |
+
possible_num_samples.append(prompt.shape[0])
|
464 |
+
elif conditions:
|
465 |
+
possible_num_samples.append(len(conditions))
|
466 |
+
else:
|
467 |
+
possible_num_samples.append(1)
|
468 |
+
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
469 |
+
num_samples = possible_num_samples[0]
|
470 |
+
|
471 |
+
condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(conditions, cfg_coef_all, cfg_coef_txt)
|
472 |
+
|
473 |
+
# flow matching inference
|
474 |
+
B, T, D = num_samples, max_gen_len, self.flow_dim
|
475 |
+
|
476 |
+
z_0 = torch.randn((B, T, D), device=device)
|
477 |
+
|
478 |
+
if euler:
|
479 |
+
# vanilla Euler intergration
|
480 |
+
dt = (1 / euler_steps)
|
481 |
+
z = z_0
|
482 |
+
t = torch.zeros((1, ), device=device)
|
483 |
+
for _ in range(euler_steps):
|
484 |
+
v_theta = self.estimated_vector_field(z, t,
|
485 |
+
condition_tensors=condition_tensors,
|
486 |
+
cfg_terms=cfg_terms)
|
487 |
+
z = z + dt * v_theta
|
488 |
+
t = t + dt
|
489 |
+
z_1 = z
|
490 |
+
else:
|
491 |
+
# solve with dynamic ode integrator (dopri5)
|
492 |
+
t = torch.tensor([0, 1.0 - 1e-5], device=device)
|
493 |
+
num_evals = 0
|
494 |
+
|
495 |
+
# define ode vector field function
|
496 |
+
def inner_ode_func(t, z):
|
497 |
+
nonlocal num_evals
|
498 |
+
num_evals += 1
|
499 |
+
if callback is not None:
|
500 |
+
ESTIMATED_ODE_SOLVER_STEPS = 300
|
501 |
+
callback(num_evals, ESTIMATED_ODE_SOLVER_STEPS)
|
502 |
+
return self.estimated_vector_field(z, t,
|
503 |
+
condition_tensors=condition_tensors,
|
504 |
+
cfg_terms=cfg_terms)
|
505 |
+
|
506 |
+
ode_opts: dict = {"options": {}}
|
507 |
+
z = odeint(
|
508 |
+
inner_ode_func,
|
509 |
+
z_0,
|
510 |
+
t,
|
511 |
+
**{"atol": ode_atol, "rtol": ode_rtol, **ode_opts},
|
512 |
+
)
|
513 |
+
logger.info("Generated in %d steps", num_evals)
|
514 |
+
z_1 = z[-1]
|
515 |
+
|
516 |
+
return z_1
|
audiocraft/models/genmodel.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Base implementation for audio generative models. This base implementation
|
9 |
+
combines all the required components to run inference with pretrained audio
|
10 |
+
generative models. It can be easily inherited by downstream model classes to
|
11 |
+
provide easy access to the generation API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
from abc import ABC, abstractmethod
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import omegaconf
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from .encodec import CompressionModel
|
21 |
+
from .lm import LMModel
|
22 |
+
from .builders import get_wrapped_compression_model
|
23 |
+
from ..data.audio_utils import convert_audio
|
24 |
+
from ..modules.conditioners import ConditioningAttributes
|
25 |
+
from ..utils.autocast import TorchAutocast
|
26 |
+
|
27 |
+
|
28 |
+
class BaseGenModel(ABC):
|
29 |
+
"""Base generative model with convenient generation API.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
name (str): name of the model.
|
33 |
+
compression_model (CompressionModel): Compression model
|
34 |
+
used to map audio to invertible discrete representations.
|
35 |
+
lm (LMModel): Language model over discrete representations.
|
36 |
+
max_duration (float, optional): maximum duration the model can produce,
|
37 |
+
otherwise, inferred from the training params.
|
38 |
+
"""
|
39 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
40 |
+
max_duration: tp.Optional[float] = None):
|
41 |
+
self.name = name
|
42 |
+
self.compression_model = compression_model
|
43 |
+
self.lm = lm
|
44 |
+
self.cfg: tp.Optional[omegaconf.DictConfig] = None
|
45 |
+
# Just to be safe, let's put everything in eval mode.
|
46 |
+
self.compression_model.eval()
|
47 |
+
self.lm.eval()
|
48 |
+
|
49 |
+
if hasattr(lm, 'cfg'):
|
50 |
+
cfg = lm.cfg
|
51 |
+
assert isinstance(cfg, omegaconf.DictConfig)
|
52 |
+
self.cfg = cfg
|
53 |
+
|
54 |
+
if self.cfg is not None:
|
55 |
+
self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
|
56 |
+
|
57 |
+
if max_duration is None:
|
58 |
+
if self.cfg is not None:
|
59 |
+
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
60 |
+
else:
|
61 |
+
raise ValueError("You must provide max_duration when building directly your GenModel")
|
62 |
+
assert max_duration is not None
|
63 |
+
|
64 |
+
self.max_duration: float = max_duration
|
65 |
+
self.duration = self.max_duration
|
66 |
+
|
67 |
+
# self.extend_stride is the length of audio extension when generating samples longer
|
68 |
+
# than self.max_duration. NOTE: the derived class must set self.extend_stride to a
|
69 |
+
# positive float value when generating with self.duration > self.max_duration.
|
70 |
+
self.extend_stride: tp.Optional[float] = None
|
71 |
+
self.device = next(iter(lm.parameters())).device
|
72 |
+
self.generation_params: dict = {}
|
73 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
74 |
+
if self.device.type == 'cpu':
|
75 |
+
self.autocast = TorchAutocast(enabled=False)
|
76 |
+
else:
|
77 |
+
self.autocast = TorchAutocast(
|
78 |
+
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
79 |
+
|
80 |
+
@property
|
81 |
+
def frame_rate(self) -> float:
|
82 |
+
"""Roughly the number of AR steps per seconds."""
|
83 |
+
return self.compression_model.frame_rate
|
84 |
+
|
85 |
+
@property
|
86 |
+
def sample_rate(self) -> int:
|
87 |
+
"""Sample rate of the generated audio."""
|
88 |
+
return self.compression_model.sample_rate
|
89 |
+
|
90 |
+
@property
|
91 |
+
def audio_channels(self) -> int:
|
92 |
+
"""Audio channels of the generated audio."""
|
93 |
+
return self.compression_model.channels
|
94 |
+
|
95 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
96 |
+
"""Override the default progress callback."""
|
97 |
+
self._progress_callback = progress_callback
|
98 |
+
|
99 |
+
@abstractmethod
|
100 |
+
def set_generation_params(self, *args, **kwargs):
|
101 |
+
"""Set the generation parameters."""
|
102 |
+
raise NotImplementedError("No base implementation for setting generation params.")
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
@abstractmethod
|
106 |
+
def get_pretrained(name: str, device=None):
|
107 |
+
raise NotImplementedError("No base implementation for getting pretrained model")
|
108 |
+
|
109 |
+
@torch.no_grad()
|
110 |
+
def _prepare_tokens_and_attributes(
|
111 |
+
self,
|
112 |
+
descriptions: tp.Sequence[tp.Optional[str]],
|
113 |
+
prompt: tp.Optional[torch.Tensor],
|
114 |
+
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
|
115 |
+
"""Prepare model inputs.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
119 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
120 |
+
"""
|
121 |
+
attributes = [
|
122 |
+
ConditioningAttributes(text={'description': description})
|
123 |
+
for description in descriptions]
|
124 |
+
|
125 |
+
if prompt is not None:
|
126 |
+
if descriptions is not None:
|
127 |
+
assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
|
128 |
+
prompt = prompt.to(self.device)
|
129 |
+
prompt_tokens, scale = self.compression_model.encode(prompt)
|
130 |
+
assert scale is None
|
131 |
+
else:
|
132 |
+
prompt_tokens = None
|
133 |
+
return attributes, prompt_tokens
|
134 |
+
|
135 |
+
def generate_unconditional(self, num_samples: int, progress: bool = False,
|
136 |
+
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
137 |
+
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
138 |
+
"""Generate samples in an unconditional manner.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
num_samples (int): Number of samples to be generated.
|
142 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
143 |
+
"""
|
144 |
+
descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
|
145 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
146 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
147 |
+
if return_tokens:
|
148 |
+
return self.generate_audio(tokens), tokens
|
149 |
+
return self.generate_audio(tokens)
|
150 |
+
|
151 |
+
def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
|
152 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
153 |
+
"""Generate samples conditioned on text.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
157 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
158 |
+
"""
|
159 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
160 |
+
assert prompt_tokens is None
|
161 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
162 |
+
if return_tokens:
|
163 |
+
return self.generate_audio(tokens), tokens
|
164 |
+
return self.generate_audio(tokens)
|
165 |
+
|
166 |
+
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
167 |
+
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
168 |
+
progress: bool = False, return_tokens: bool = False) \
|
169 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
170 |
+
"""Generate samples conditioned on audio prompts and an optional text description.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
174 |
+
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
175 |
+
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
176 |
+
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
177 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
178 |
+
"""
|
179 |
+
if prompt.dim() == 2:
|
180 |
+
prompt = prompt[None]
|
181 |
+
if prompt.dim() != 3:
|
182 |
+
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
183 |
+
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
184 |
+
if descriptions is None:
|
185 |
+
descriptions = [None] * len(prompt)
|
186 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
187 |
+
assert prompt_tokens is not None
|
188 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
189 |
+
if return_tokens:
|
190 |
+
return self.generate_audio(tokens), tokens
|
191 |
+
return self.generate_audio(tokens)
|
192 |
+
|
193 |
+
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
194 |
+
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
195 |
+
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
attributes (list of ConditioningAttributes): Conditions used for generation (here text).
|
199 |
+
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
200 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
201 |
+
Returns:
|
202 |
+
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
203 |
+
"""
|
204 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
205 |
+
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
206 |
+
current_gen_offset: int = 0
|
207 |
+
|
208 |
+
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
209 |
+
generated_tokens += current_gen_offset
|
210 |
+
if self._progress_callback is not None:
|
211 |
+
# Note that total_gen_len might be quite wrong depending on the
|
212 |
+
# codebook pattern used, but with delay it is almost accurate.
|
213 |
+
self._progress_callback(generated_tokens, tokens_to_generate)
|
214 |
+
else:
|
215 |
+
print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
|
216 |
+
|
217 |
+
if prompt_tokens is not None:
|
218 |
+
assert max_prompt_len >= prompt_tokens.shape[-1], \
|
219 |
+
"Prompt is longer than audio to generate"
|
220 |
+
|
221 |
+
callback = None
|
222 |
+
if progress:
|
223 |
+
callback = _progress_callback
|
224 |
+
|
225 |
+
if self.duration <= self.max_duration:
|
226 |
+
# generate by sampling from LM, simple case.
|
227 |
+
with self.autocast:
|
228 |
+
gen_tokens = self.lm.generate(
|
229 |
+
prompt_tokens, attributes,
|
230 |
+
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
231 |
+
|
232 |
+
else:
|
233 |
+
assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
|
234 |
+
assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
235 |
+
all_tokens = []
|
236 |
+
if prompt_tokens is None:
|
237 |
+
prompt_length = 0
|
238 |
+
else:
|
239 |
+
all_tokens.append(prompt_tokens)
|
240 |
+
prompt_length = prompt_tokens.shape[-1]
|
241 |
+
|
242 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
243 |
+
while current_gen_offset + prompt_length < total_gen_len:
|
244 |
+
time_offset = current_gen_offset / self.frame_rate
|
245 |
+
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
246 |
+
max_gen_len = int(chunk_duration * self.frame_rate)
|
247 |
+
with self.autocast:
|
248 |
+
gen_tokens = self.lm.generate(
|
249 |
+
prompt_tokens, attributes,
|
250 |
+
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
251 |
+
if prompt_tokens is None:
|
252 |
+
all_tokens.append(gen_tokens)
|
253 |
+
else:
|
254 |
+
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
255 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
256 |
+
prompt_length = prompt_tokens.shape[-1]
|
257 |
+
current_gen_offset += stride_tokens
|
258 |
+
|
259 |
+
gen_tokens = torch.cat(all_tokens, dim=-1)
|
260 |
+
return gen_tokens
|
261 |
+
|
262 |
+
def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
|
263 |
+
"""Generate Audio from tokens."""
|
264 |
+
assert gen_tokens.dim() == 3
|
265 |
+
with torch.no_grad():
|
266 |
+
gen_audio = self.compression_model.decode(gen_tokens, None)
|
267 |
+
return gen_audio
|
audiocraft/models/lm.py
CHANGED
@@ -23,6 +23,7 @@ from ..modules.conditioners import (
|
|
23 |
ConditioningProvider,
|
24 |
ConditioningAttributes,
|
25 |
ConditionType,
|
|
|
26 |
)
|
27 |
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
28 |
from ..modules.activations import get_activation_fn
|
@@ -219,7 +220,8 @@ class LMModel(StreamingModule):
|
|
219 |
|
220 |
def forward(self, sequence: torch.Tensor,
|
221 |
conditions: tp.List[ConditioningAttributes],
|
222 |
-
condition_tensors: tp.Optional[ConditionTensors] = None
|
|
|
223 |
"""Apply language model on sequence and conditions.
|
224 |
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
225 |
S the sequence steps, return the logits with shape [B, card, K, S].
|
@@ -231,6 +233,9 @@ class LMModel(StreamingModule):
|
|
231 |
you should pre-compute those and pass them as `condition_tensors`.
|
232 |
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
233 |
tensors, see `conditions`.
|
|
|
|
|
|
|
234 |
Returns:
|
235 |
torch.Tensor: Logits.
|
236 |
"""
|
@@ -250,7 +255,8 @@ class LMModel(StreamingModule):
|
|
250 |
|
251 |
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
252 |
|
253 |
-
out = self.transformer(input_, cross_attention_src=cross_attention_input
|
|
|
254 |
if self.out_norm:
|
255 |
out = self.out_norm(out)
|
256 |
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
|
@@ -264,7 +270,9 @@ class LMModel(StreamingModule):
|
|
264 |
def compute_predictions(
|
265 |
self, codes: torch.Tensor,
|
266 |
conditions: tp.List[ConditioningAttributes],
|
267 |
-
condition_tensors: tp.Optional[ConditionTensors] = None
|
|
|
|
|
268 |
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
269 |
forward using the specified codes interleaving pattern.
|
270 |
|
@@ -276,6 +284,11 @@ class LMModel(StreamingModule):
|
|
276 |
you should pre-compute those and pass them as `condition_tensors`.
|
277 |
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
278 |
tensors, see `conditions`.
|
|
|
|
|
|
|
|
|
|
|
279 |
Returns:
|
280 |
LMOutput: Language model outputs
|
281 |
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
@@ -290,17 +303,18 @@ class LMModel(StreamingModule):
|
|
290 |
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
291 |
pattern = self.pattern_provider.get_pattern(T)
|
292 |
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
293 |
-
codes, self.special_token_id, keep_only_valid_steps=
|
294 |
)
|
|
|
295 |
# apply model on pattern sequence
|
296 |
model = self if self._fsdp is None else self._fsdp
|
297 |
-
logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
|
298 |
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
299 |
# and provide the corresponding mask over invalid positions of tokens
|
300 |
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
301 |
# note: we use nans as special token to make it obvious if we feed unexpected logits
|
302 |
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
303 |
-
logits, float('nan'), keep_only_valid_steps=
|
304 |
)
|
305 |
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
306 |
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
@@ -315,6 +329,7 @@ class LMModel(StreamingModule):
|
|
315 |
top_k: int = 0,
|
316 |
top_p: float = 0.0,
|
317 |
cfg_coef: tp.Optional[float] = None,
|
|
|
318 |
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
|
319 |
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
320 |
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
@@ -330,6 +345,13 @@ class LMModel(StreamingModule):
|
|
330 |
top_k (int): K for "top-k" sampling.
|
331 |
top_p (float): P for "top-p" sampling.
|
332 |
cfg_coef (float, optional): classifier free guidance coefficient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
Returns:
|
334 |
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
335 |
"""
|
@@ -337,7 +359,23 @@ class LMModel(StreamingModule):
|
|
337 |
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
338 |
model = self if self._fsdp is None else self._fsdp
|
339 |
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
340 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
|
342 |
condition_tensors, null_condition_tensors = cfg_conditions
|
343 |
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
|
@@ -390,23 +428,30 @@ class LMModel(StreamingModule):
|
|
390 |
top_k: int = 250,
|
391 |
top_p: float = 0.0,
|
392 |
cfg_coef: tp.Optional[float] = None,
|
|
|
393 |
two_step_cfg: tp.Optional[bool] = None,
|
394 |
remove_prompts: bool = False,
|
395 |
check: bool = False,
|
396 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
|
|
397 |
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
398 |
-
be
|
399 |
|
400 |
Args:
|
401 |
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
402 |
-
|
403 |
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
404 |
max_gen_len (int): Maximum generation length.
|
405 |
use_sampling (bool): Whether to use a sampling strategy or not.
|
406 |
temp (float): Sampling temperature.
|
407 |
top_k (int): K for "top-k" sampling.
|
408 |
top_p (float): P for "top-p" sampling.
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
410 |
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
411 |
remove_prompts (bool): Whether to remove prompts from generation or not.
|
412 |
check (bool): Whether to apply further checks on generated sequence.
|
@@ -441,18 +486,27 @@ class LMModel(StreamingModule):
|
|
441 |
# the padding structure is exactly the same between train and test.
|
442 |
# With a batch size of 1, this can be slower though.
|
443 |
cfg_conditions: CFGConditions
|
444 |
-
|
445 |
-
if
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
self.condition_provider(self.condition_provider.tokenize(null_conditions)),
|
451 |
-
)
|
452 |
-
else:
|
453 |
-
conditions = conditions + null_conditions
|
454 |
tokenized = self.condition_provider.tokenize(conditions)
|
455 |
cfg_conditions = self.condition_provider(tokenized)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
else:
|
457 |
cfg_conditions = {}
|
458 |
|
@@ -463,8 +517,8 @@ class LMModel(StreamingModule):
|
|
463 |
B, K, T = prompt.shape
|
464 |
start_offset = T
|
465 |
print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
|
466 |
-
assert start_offset
|
467 |
-
|
468 |
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
469 |
# this token is used as default value for codes that are not generated yet
|
470 |
unknown_token = -1
|
@@ -496,7 +550,7 @@ class LMModel(StreamingModule):
|
|
496 |
# sample next token from the model, next token shape is [B, K, 1]
|
497 |
next_token = self._sample_next_token(
|
498 |
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
|
499 |
-
cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
|
500 |
# ensure the tokens that should be masked are properly set to special_token_id
|
501 |
# as the model never output special_token_id
|
502 |
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
|
|
23 |
ConditioningProvider,
|
24 |
ConditioningAttributes,
|
25 |
ConditionType,
|
26 |
+
_drop_description_condition
|
27 |
)
|
28 |
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
29 |
from ..modules.activations import get_activation_fn
|
|
|
220 |
|
221 |
def forward(self, sequence: torch.Tensor,
|
222 |
conditions: tp.List[ConditioningAttributes],
|
223 |
+
condition_tensors: tp.Optional[ConditionTensors] = None,
|
224 |
+
stage: int = -1) -> torch.Tensor:
|
225 |
"""Apply language model on sequence and conditions.
|
226 |
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
227 |
S the sequence steps, return the logits with shape [B, card, K, S].
|
|
|
233 |
you should pre-compute those and pass them as `condition_tensors`.
|
234 |
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
235 |
tensors, see `conditions`.
|
236 |
+
stage (int): The codebook level that is being predicted. Relevant for MAGNeT
|
237 |
+
in which prediction is done in a codebook-by-codebook manner.
|
238 |
+
Takes values in range(n_q), and ignored by default.
|
239 |
Returns:
|
240 |
torch.Tensor: Logits.
|
241 |
"""
|
|
|
255 |
|
256 |
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
257 |
|
258 |
+
out = self.transformer(input_, cross_attention_src=cross_attention_input,
|
259 |
+
src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) # type: ignore
|
260 |
if self.out_norm:
|
261 |
out = self.out_norm(out)
|
262 |
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
|
|
|
270 |
def compute_predictions(
|
271 |
self, codes: torch.Tensor,
|
272 |
conditions: tp.List[ConditioningAttributes],
|
273 |
+
condition_tensors: tp.Optional[ConditionTensors] = None,
|
274 |
+
stage: int = -1,
|
275 |
+
keep_only_valid_steps: bool = True) -> LMOutput:
|
276 |
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
277 |
forward using the specified codes interleaving pattern.
|
278 |
|
|
|
284 |
you should pre-compute those and pass them as `condition_tensors`.
|
285 |
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
286 |
tensors, see `conditions`.
|
287 |
+
stage (int): The codebook level that is being predicted. Relevant for MAGNeT
|
288 |
+
in which prediction is done in a codebook-by-codebook manner.
|
289 |
+
Takes values in range(n_q), and ignored by default.
|
290 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
291 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
292 |
Returns:
|
293 |
LMOutput: Language model outputs
|
294 |
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
|
|
303 |
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
304 |
pattern = self.pattern_provider.get_pattern(T)
|
305 |
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
306 |
+
codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
|
307 |
)
|
308 |
+
|
309 |
# apply model on pattern sequence
|
310 |
model = self if self._fsdp is None else self._fsdp
|
311 |
+
logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card]
|
312 |
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
313 |
# and provide the corresponding mask over invalid positions of tokens
|
314 |
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
315 |
# note: we use nans as special token to make it obvious if we feed unexpected logits
|
316 |
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
317 |
+
logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
|
318 |
)
|
319 |
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
320 |
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
|
|
329 |
top_k: int = 0,
|
330 |
top_p: float = 0.0,
|
331 |
cfg_coef: tp.Optional[float] = None,
|
332 |
+
cfg_coef_beta: tp.Optional[float] = None,
|
333 |
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
|
334 |
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
335 |
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
|
|
345 |
top_k (int): K for "top-k" sampling.
|
346 |
top_p (float): P for "top-p" sampling.
|
347 |
cfg_coef (float, optional): classifier free guidance coefficient
|
348 |
+
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
|
349 |
+
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
|
350 |
+
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
|
351 |
+
push the text condition more than the style condition in the case where both text and style
|
352 |
+
conditions are being used.
|
353 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
354 |
+
|
355 |
Returns:
|
356 |
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
357 |
"""
|
|
|
359 |
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
360 |
model = self if self._fsdp is None else self._fsdp
|
361 |
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
362 |
+
if cfg_coef_beta is not None:
|
363 |
+
assert isinstance(cfg_conditions, dict)
|
364 |
+
condition_tensors = cfg_conditions
|
365 |
+
if condition_tensors:
|
366 |
+
# Preparing for CFG, predicting conditional text and style, conditional style
|
367 |
+
# and unconditional
|
368 |
+
sequence = torch.cat([sequence, sequence, sequence], dim=0)
|
369 |
+
all_logits = model(
|
370 |
+
sequence,
|
371 |
+
conditions=[], condition_tensors=condition_tensors)
|
372 |
+
if condition_tensors:
|
373 |
+
cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
374 |
+
logits = uncond_logits + cfg_coef * (
|
375 |
+
wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
|
376 |
+
)
|
377 |
+
|
378 |
+
elif two_step_cfg and cfg_conditions != {}:
|
379 |
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
|
380 |
condition_tensors, null_condition_tensors = cfg_conditions
|
381 |
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
|
|
|
428 |
top_k: int = 250,
|
429 |
top_p: float = 0.0,
|
430 |
cfg_coef: tp.Optional[float] = None,
|
431 |
+
cfg_coef_beta: tp.Optional[float] = None,
|
432 |
two_step_cfg: tp.Optional[bool] = None,
|
433 |
remove_prompts: bool = False,
|
434 |
check: bool = False,
|
435 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
436 |
+
) -> torch.Tensor:
|
437 |
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
438 |
+
be performed in a greedy fashion or using sampling with top K and top P strategies.
|
439 |
|
440 |
Args:
|
441 |
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
442 |
+
conditions (list of ConditioningAttributes, optional): List of conditions.
|
443 |
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
444 |
max_gen_len (int): Maximum generation length.
|
445 |
use_sampling (bool): Whether to use a sampling strategy or not.
|
446 |
temp (float): Sampling temperature.
|
447 |
top_k (int): K for "top-k" sampling.
|
448 |
top_p (float): P for "top-p" sampling.
|
449 |
+
cfg_coef (float, optional): Classifier-free guidance coefficient.
|
450 |
+
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
|
451 |
+
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
|
452 |
+
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
|
453 |
+
push the text condition more than the style condition in the case where both text and style
|
454 |
+
conditions are being used.
|
455 |
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
456 |
remove_prompts (bool): Whether to remove prompts from generation or not.
|
457 |
check (bool): Whether to apply further checks on generated sequence.
|
|
|
486 |
# the padding structure is exactly the same between train and test.
|
487 |
# With a batch size of 1, this can be slower though.
|
488 |
cfg_conditions: CFGConditions
|
489 |
+
cfg_conditions = {}
|
490 |
+
if cfg_coef_beta is not None:
|
491 |
+
if conditions:
|
492 |
+
wav_conditions = _drop_description_condition(conditions)
|
493 |
+
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
494 |
+
conditions = conditions + wav_conditions + null_conditions
|
|
|
|
|
|
|
|
|
495 |
tokenized = self.condition_provider.tokenize(conditions)
|
496 |
cfg_conditions = self.condition_provider(tokenized)
|
497 |
+
elif conditions:
|
498 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
499 |
+
if conditions:
|
500 |
+
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
501 |
+
if two_step_cfg:
|
502 |
+
cfg_conditions = (
|
503 |
+
self.condition_provider(self.condition_provider.tokenize(conditions)),
|
504 |
+
self.condition_provider(self.condition_provider.tokenize(null_conditions)),
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
conditions = conditions + null_conditions
|
508 |
+
tokenized = self.condition_provider.tokenize(conditions)
|
509 |
+
cfg_conditions = self.condition_provider(tokenized)
|
510 |
else:
|
511 |
cfg_conditions = {}
|
512 |
|
|
|
517 |
B, K, T = prompt.shape
|
518 |
start_offset = T
|
519 |
print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
|
520 |
+
assert start_offset < max_gen_len
|
521 |
+
|
522 |
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
523 |
# this token is used as default value for codes that are not generated yet
|
524 |
unknown_token = -1
|
|
|
550 |
# sample next token from the model, next token shape is [B, K, 1]
|
551 |
next_token = self._sample_next_token(
|
552 |
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
|
553 |
+
cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg)
|
554 |
# ensure the tokens that should be masked are properly set to special_token_id
|
555 |
# as the model never output special_token_id
|
556 |
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
audiocraft/models/lm_magnet.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import typing as tp
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from ..utils import utils
|
14 |
+
from ..modules.conditioners import (
|
15 |
+
ClassifierFreeGuidanceDropout,
|
16 |
+
ConditioningAttributes,
|
17 |
+
ConditionType,
|
18 |
+
)
|
19 |
+
from .lm import LMModel
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
ConditionTensors = tp.Dict[str, ConditionType]
|
23 |
+
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
24 |
+
|
25 |
+
|
26 |
+
class MagnetLMModel(LMModel):
|
27 |
+
"""Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).
|
28 |
+
Args:
|
29 |
+
subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
|
30 |
+
When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
|
31 |
+
compression_model_framerate (int): frame rate of the audio tokenizer.
|
32 |
+
segment_duration (int): Sample length in seconds.
|
33 |
+
span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
|
34 |
+
for both training and inference. Defaults to 3.
|
35 |
+
**kwargs: Additional parameters for the LMModel.
|
36 |
+
"""
|
37 |
+
def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
|
38 |
+
segment_duration: int = 10, span_len: int = 3, **kwargs):
|
39 |
+
super().__init__(**kwargs)
|
40 |
+
self.causal = kwargs['causal']
|
41 |
+
self.subcodes_context = subcodes_context
|
42 |
+
self.span_len = span_len
|
43 |
+
self._build_attn_masks(compression_model_framerate=compression_model_framerate,
|
44 |
+
segment_duration=segment_duration,
|
45 |
+
num_heads=kwargs['num_heads'],
|
46 |
+
device=kwargs['device'], dtype=kwargs['dtype'])
|
47 |
+
|
48 |
+
def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
49 |
+
"""Creates a restricted attention mask (local attention map) where the context
|
50 |
+
is determined by self.subcodes_context.
|
51 |
+
Args:
|
52 |
+
seq_len (int): token sequence length.
|
53 |
+
device (torch.device): device of the output tensor.
|
54 |
+
dtype (torch.dtype): data type of the output tensor.
|
55 |
+
Returns:
|
56 |
+
torch.Tensor: The restricted attention mask.
|
57 |
+
"""
|
58 |
+
# Return a context restricted non-causal att mask
|
59 |
+
queries_pos = torch.arange(seq_len, device=device).view(-1, 1)
|
60 |
+
keys_pos = torch.arange(seq_len, device=device).view(1, -1)
|
61 |
+
|
62 |
+
delta = queries_pos - keys_pos
|
63 |
+
valid = torch.abs(delta) <= self.subcodes_context
|
64 |
+
return torch.where(
|
65 |
+
valid,
|
66 |
+
torch.zeros([], device=device, dtype=dtype),
|
67 |
+
torch.full([], float('-inf'), device=device, dtype=dtype))
|
68 |
+
|
69 |
+
def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
|
70 |
+
device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
|
71 |
+
"""Creates a restricted attention mask given the stage (codebook index).
|
72 |
+
Args:
|
73 |
+
stage (int): The codebook index. Takes values in [0, n_q].
|
74 |
+
seq_len (int): Token sequence length.
|
75 |
+
num_heads (int): Num transformer attention heads.
|
76 |
+
device (torch.device): device of the output tensor.
|
77 |
+
dtype (torch.dtype): data type of the output tensor.
|
78 |
+
Returns:
|
79 |
+
torch.Tensor: Either a restricted attention mask or None if stage attention is unrestricted.
|
80 |
+
"""
|
81 |
+
sa_mask = None
|
82 |
+
|
83 |
+
if stage > 0 and self.subcodes_context > -1:
|
84 |
+
# parallel - non-causal - with restricted subcodes context
|
85 |
+
sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
|
86 |
+
|
87 |
+
if sa_mask is not None:
|
88 |
+
# Repeat for each attention head
|
89 |
+
sa_mask = sa_mask.repeat((1, num_heads, 1, 1))
|
90 |
+
|
91 |
+
# align8 to enable memory efficient attention
|
92 |
+
MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
|
93 |
+
seq_len_aligned = \
|
94 |
+
int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR
|
95 |
+
|
96 |
+
sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
|
97 |
+
sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
|
98 |
+
sa_mask = sa_mask_aligned
|
99 |
+
|
100 |
+
return sa_mask
|
101 |
+
|
102 |
+
def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
|
103 |
+
device: torch.device, dtype: torch.dtype):
|
104 |
+
"""Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
|
105 |
+
either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
|
106 |
+
Args:
|
107 |
+
compression_model_framerate (int): The frame rate of the tokenizer.
|
108 |
+
segment_duration (int): Sample length in seconds.
|
109 |
+
num_heads (int): Num transformer attention heads.
|
110 |
+
device (torch.device): device of the output tensor.
|
111 |
+
dtype (torch.dtype): data type of the output tensor.
|
112 |
+
"""
|
113 |
+
seq_len = compression_model_framerate * segment_duration
|
114 |
+
self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
|
115 |
+
device, dtype) for stage in range(self.n_q)]
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def generate(self,
|
119 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
120 |
+
conditions: tp.List[ConditioningAttributes] = [],
|
121 |
+
num_samples: tp.Optional[int] = None,
|
122 |
+
max_gen_len: int = 256,
|
123 |
+
use_sampling: bool = True,
|
124 |
+
temp: float = 1.0,
|
125 |
+
top_k: int = 250,
|
126 |
+
top_p: float = 0.0,
|
127 |
+
cfg_coef: tp.Optional[float] = None,
|
128 |
+
cfg_coef_beta: tp.Optional[float] = None,
|
129 |
+
two_step_cfg: tp.Optional[bool] = None,
|
130 |
+
remove_prompts: bool = False,
|
131 |
+
check: bool = False,
|
132 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
133 |
+
**kwargs) -> torch.Tensor:
|
134 |
+
|
135 |
+
assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
|
136 |
+
assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
|
137 |
+
assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
|
138 |
+
assert check is False, "MAGNeT currently doesn't support the check arg."
|
139 |
+
assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg."
|
140 |
+
# Call the MAGNeT-specific generation method
|
141 |
+
return self._generate_magnet(prompt=prompt,
|
142 |
+
conditions=conditions,
|
143 |
+
num_samples=num_samples,
|
144 |
+
max_gen_len=max_gen_len,
|
145 |
+
use_sampling=use_sampling,
|
146 |
+
temp=temp,
|
147 |
+
top_k=top_k,
|
148 |
+
top_p=top_p,
|
149 |
+
callback=callback, **kwargs)
|
150 |
+
|
151 |
+
@torch.no_grad()
|
152 |
+
def _generate_magnet(self,
|
153 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
154 |
+
conditions: tp.List[ConditioningAttributes] = [],
|
155 |
+
num_samples: tp.Optional[int] = None,
|
156 |
+
max_gen_len: int = 256,
|
157 |
+
use_sampling: bool = True,
|
158 |
+
temp: float = 3.0,
|
159 |
+
top_k: int = 0,
|
160 |
+
top_p: float = 0.9,
|
161 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
162 |
+
max_cfg_coef: float = 10.0,
|
163 |
+
min_cfg_coef: float = 1.0,
|
164 |
+
decoding_steps: tp.List[int] = [20, 10, 10, 10],
|
165 |
+
anneal_temp: bool = True,
|
166 |
+
span_scoring='max',
|
167 |
+
span_arrangement='nonoverlap') -> torch.Tensor:
|
168 |
+
"""Generate audio tokens given textual conditions, and optionally given audio prompts,
|
169 |
+
by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.
|
170 |
+
Args:
|
171 |
+
prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
|
172 |
+
conditions (list of ConditioningAttributes): List of conditions.
|
173 |
+
num_samples (int): Number of samples to generate when no prompt and no conditions are given.
|
174 |
+
max_gen_len (int): Maximum generation length.
|
175 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
176 |
+
temp (float): Initial sampling temperature.
|
177 |
+
top_k (int): k for "top-k" sampling.
|
178 |
+
top_p (float): p for "top-p" sampling.
|
179 |
+
callback (Callback): Callback function to report generation progress.
|
180 |
+
max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
|
181 |
+
min_clsfg_coef (float): Final coefficient used for classifier free guidance.
|
182 |
+
decoding_steps (list of n_q ints): The number of iterative decoding steps,
|
183 |
+
for each of the n_q RVQ codebooks.
|
184 |
+
anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
|
185 |
+
span_scoring (str): Use the maximum probability of each span ('max')
|
186 |
+
or the product of probabilities ('prod').
|
187 |
+
span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
|
188 |
+
in the masking scheme.
|
189 |
+
Returns:
|
190 |
+
torch.Tensor: Generated tokens.
|
191 |
+
"""
|
192 |
+
assert not self.training, "generation shouldn't be used in training mode."
|
193 |
+
first_param = next(iter(self.parameters()))
|
194 |
+
device = first_param.device
|
195 |
+
|
196 |
+
# Checking all input shapes are consistent.
|
197 |
+
possible_num_samples = []
|
198 |
+
if num_samples is not None:
|
199 |
+
possible_num_samples.append(num_samples)
|
200 |
+
elif prompt is not None:
|
201 |
+
possible_num_samples.append(prompt.shape[0])
|
202 |
+
elif conditions:
|
203 |
+
possible_num_samples.append(len(conditions))
|
204 |
+
else:
|
205 |
+
possible_num_samples.append(1)
|
206 |
+
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
207 |
+
num_samples = possible_num_samples[0]
|
208 |
+
|
209 |
+
# below we create set of conditions: one conditional and one unconditional
|
210 |
+
# to do that we merge the regular condition together with the null condition
|
211 |
+
# we then do 1 forward pass instead of 2.
|
212 |
+
cfg_conditions: tp.Optional[ConditionTensors]
|
213 |
+
if conditions:
|
214 |
+
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
215 |
+
conditions = conditions + null_conditions
|
216 |
+
tokenized = self.condition_provider.tokenize(conditions)
|
217 |
+
cfg_conditions = self.condition_provider(tokenized)
|
218 |
+
else:
|
219 |
+
cfg_conditions = {}
|
220 |
+
|
221 |
+
if prompt is None:
|
222 |
+
assert num_samples > 0
|
223 |
+
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
|
224 |
+
|
225 |
+
B, K, prompt_length = prompt.shape
|
226 |
+
start_offset = prompt_length
|
227 |
+
assert start_offset < max_gen_len
|
228 |
+
|
229 |
+
mask_id = self.special_token_id
|
230 |
+
|
231 |
+
# we generate codes with a fixed sequence length
|
232 |
+
shape = (B, K, max_gen_len)
|
233 |
+
|
234 |
+
gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
|
235 |
+
# filling the gen_codes with the prompt if needed
|
236 |
+
gen_codes[..., :start_offset] = prompt
|
237 |
+
# create the gen_sequence with proper interleaving from the pattern: [B, K, S]
|
238 |
+
gen_sequence = gen_codes
|
239 |
+
|
240 |
+
curr_step = 0
|
241 |
+
for stage, n_steps in zip(range(self.n_q), decoding_steps):
|
242 |
+
gen_sequence, curr_step = self._generate_stage(gen_sequence,
|
243 |
+
cfg_conditions,
|
244 |
+
stage=stage,
|
245 |
+
device=device,
|
246 |
+
prompt_length=prompt_length,
|
247 |
+
prompt=prompt,
|
248 |
+
temp=temp,
|
249 |
+
max_cfg_coef=max_cfg_coef,
|
250 |
+
min_cfg_coef=min_cfg_coef,
|
251 |
+
top_k=top_k,
|
252 |
+
top_p=top_p,
|
253 |
+
timesteps=n_steps,
|
254 |
+
anneal_temp=anneal_temp,
|
255 |
+
span_scoring=span_scoring,
|
256 |
+
use_sampling=use_sampling,
|
257 |
+
span_arrangement=span_arrangement,
|
258 |
+
curr_step=curr_step,
|
259 |
+
total_steps=sum(decoding_steps),
|
260 |
+
callback=callback)
|
261 |
+
|
262 |
+
return gen_sequence
|
263 |
+
|
264 |
+
@torch.no_grad()
|
265 |
+
def _generate_stage(self,
|
266 |
+
gen_sequence: torch.Tensor,
|
267 |
+
condition_tensors: tp.Optional[ConditionTensors],
|
268 |
+
stage: int,
|
269 |
+
device: torch.device,
|
270 |
+
prompt_length: int = 0,
|
271 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
272 |
+
use_sampling: bool = True,
|
273 |
+
temp: float = 3.0,
|
274 |
+
max_cfg_coef: float = 10.0,
|
275 |
+
min_cfg_coef: float = 1.0,
|
276 |
+
top_k: int = 0,
|
277 |
+
top_p: float = 0.0,
|
278 |
+
timesteps: int = 10,
|
279 |
+
anneal_temp: bool = True,
|
280 |
+
span_scoring: str = 'max',
|
281 |
+
span_arrangement: str = 'nonoverlap',
|
282 |
+
curr_step: int = 0,
|
283 |
+
total_steps: int = 0,
|
284 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
|
285 |
+
"""Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
|
286 |
+
and the textual conditions.
|
287 |
+
Args:
|
288 |
+
gen_sequence (torch.Tensor): Previously generated tokens.
|
289 |
+
condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors.
|
290 |
+
stage (int): RVQ level to generate.
|
291 |
+
device (torch.device): device of the output tensor.
|
292 |
+
prompt_length (int): Temporal length of the audio prompt.
|
293 |
+
prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
|
294 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
295 |
+
temp (float): Initial sampling temperature.
|
296 |
+
max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
|
297 |
+
min_clsfg_coef (float): Final coefficient used for classifier free guidance.
|
298 |
+
top_k (int): k for "top-k" sampling.
|
299 |
+
top_p (float): p for "top-p" sampling.
|
300 |
+
timesteps (int): Number of iterative decoding steps.
|
301 |
+
anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
|
302 |
+
span_scoring (str): Use the maximum probability of each span ('max')
|
303 |
+
or the product of probabilities ('prod').
|
304 |
+
span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
|
305 |
+
in the masking scheme.
|
306 |
+
curr_step (int): Global iterative decoding step counter.
|
307 |
+
total_steps (int): Total decoding steps.
|
308 |
+
callback (Callback): Callback function to report generation progress.
|
309 |
+
Returns:
|
310 |
+
tuple(torch.Tensor, int): Generated tokens and the current decoding step counter.
|
311 |
+
"""
|
312 |
+
B, K, T = gen_sequence.shape
|
313 |
+
shape = (B, 1, T) # generating a single codebook per stage
|
314 |
+
|
315 |
+
mask_id = self.special_token_id
|
316 |
+
stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)
|
317 |
+
|
318 |
+
assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1'
|
319 |
+
chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
|
320 |
+
|
321 |
+
DONT_REMASK_ME_SCORE = -1e4
|
322 |
+
|
323 |
+
model = self if self._fsdp is None else self._fsdp
|
324 |
+
|
325 |
+
if chunk_masking:
|
326 |
+
# span-wise scores
|
327 |
+
n_chunks = T // self.span_len
|
328 |
+
if T % self.span_len != 0:
|
329 |
+
# trim sequence ending to achieve a multiple of span_len
|
330 |
+
T = self.span_len * n_chunks
|
331 |
+
gen_sequence = gen_sequence[..., :T]
|
332 |
+
stage_gen_seq = stage_gen_seq[..., :T]
|
333 |
+
|
334 |
+
chunked_shape = (B, 1, n_chunks)
|
335 |
+
n_prompt_chunks = prompt_length // self.span_len
|
336 |
+
scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
|
337 |
+
scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
|
338 |
+
num_chunks_to_gen = n_chunks - n_prompt_chunks
|
339 |
+
else:
|
340 |
+
# token-wise scores
|
341 |
+
scores = torch.zeros(shape, dtype=torch.float32, device=device)
|
342 |
+
scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
|
343 |
+
gen_T = T - prompt_length
|
344 |
+
|
345 |
+
# run MAGNeT iterative decoding for "timesteps" iterations
|
346 |
+
for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
|
347 |
+
|
348 |
+
mask_p = torch.cos(timestep * math.pi * 0.5)
|
349 |
+
|
350 |
+
if chunk_masking:
|
351 |
+
num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
|
352 |
+
else:
|
353 |
+
num_masked = max(int((mask_p * gen_T).item()), 1)
|
354 |
+
|
355 |
+
# masking
|
356 |
+
run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1
|
357 |
+
if run_lps_masking:
|
358 |
+
# masking of the k least probable overlapping (stride 1) spans
|
359 |
+
mask = torch.concat((
|
360 |
+
[self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
|
361 |
+
for i in range(B)]), dim=0)
|
362 |
+
stage_gen_seq[mask] = mask_id
|
363 |
+
else:
|
364 |
+
# masking of the k least probable non-overlapping spans
|
365 |
+
masked = scores.topk(num_masked, dim=-1).indices
|
366 |
+
if chunk_masking:
|
367 |
+
chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
|
368 |
+
chunks_mask = chunks_mask.scatter(2, masked, True)
|
369 |
+
mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
|
370 |
+
stage_gen_seq[mask] = mask_id
|
371 |
+
else:
|
372 |
+
stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)
|
373 |
+
|
374 |
+
if prompt is not None:
|
375 |
+
stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)
|
376 |
+
|
377 |
+
gen_sequence[:, [stage], :] = stage_gen_seq
|
378 |
+
if condition_tensors:
|
379 |
+
# duplicate input for classifier free guidance
|
380 |
+
sequence = torch.cat([gen_sequence, gen_sequence], dim=0)
|
381 |
+
|
382 |
+
all_logits = model(sequence, [], condition_tensors, stage=stage)
|
383 |
+
|
384 |
+
if condition_tensors:
|
385 |
+
# classifier free guidance with annealing
|
386 |
+
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
387 |
+
clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
|
388 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
|
389 |
+
else:
|
390 |
+
logits = all_logits
|
391 |
+
|
392 |
+
# temperature annealing - linear
|
393 |
+
t = temp * (steps_left / timesteps) if anneal_temp else temp
|
394 |
+
|
395 |
+
# sampling
|
396 |
+
logits = logits[:, stage, :, :].unsqueeze(1)
|
397 |
+
probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
|
398 |
+
if use_sampling:
|
399 |
+
if top_p > 0.0:
|
400 |
+
sampled_tokens = utils.sample_top_p(probs, p=top_p)
|
401 |
+
elif top_k > 0:
|
402 |
+
sampled_tokens = utils.sample_top_k(probs, k=top_k)
|
403 |
+
else:
|
404 |
+
sampled_tokens = utils.multinomial(probs, num_samples=1)
|
405 |
+
else:
|
406 |
+
sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)
|
407 |
+
|
408 |
+
# place mask_id token in each of the masked positions
|
409 |
+
mask = stage_gen_seq == mask_id
|
410 |
+
stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
|
411 |
+
gen_sequence[:, [stage], :] = stage_gen_seq
|
412 |
+
|
413 |
+
# get probs of sampled tokens
|
414 |
+
sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]
|
415 |
+
|
416 |
+
# span scoring
|
417 |
+
if chunk_masking:
|
418 |
+
if span_scoring == 'max':
|
419 |
+
# max in linear space
|
420 |
+
scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
|
421 |
+
elif span_scoring == 'prod':
|
422 |
+
# prod in log space
|
423 |
+
scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
|
424 |
+
else:
|
425 |
+
raise NotImplementedError
|
426 |
+
else:
|
427 |
+
# prod in log space for lps masking (stride1)
|
428 |
+
scores = -torch.log(sampled_probs)
|
429 |
+
|
430 |
+
# Fix unmasked tokens by placing inf probs (-inf scores)
|
431 |
+
if chunk_masking:
|
432 |
+
scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
|
433 |
+
else:
|
434 |
+
scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)
|
435 |
+
|
436 |
+
if callback is not None:
|
437 |
+
curr_step += 1
|
438 |
+
callback(curr_step, total_steps)
|
439 |
+
|
440 |
+
return gen_sequence, curr_step
|
441 |
+
|
442 |
+
def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
|
443 |
+
"""Build a [1x1xT] boolean mask consists of overlapping spans of True values, where
|
444 |
+
span_starts defines the initial index of each span, and the span length is
|
445 |
+
defined by self.span_len.
|
446 |
+
Args:
|
447 |
+
span_starts (torch.Tensor): Boolean mask determines the temporal location of each span start.
|
448 |
+
T (int): Sequence length.
|
449 |
+
device (torch.device): device of the output tensor.
|
450 |
+
Returns:
|
451 |
+
torch.Tensor: Spans mask of shape [1x1xT]
|
452 |
+
"""
|
453 |
+
mask = torch.full((1, 1, T), False, device=device)
|
454 |
+
mask[:, :, span_starts] = True
|
455 |
+
shifted_mask = mask.clone()
|
456 |
+
for _ in range(self.span_len - 1):
|
457 |
+
shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1)
|
458 |
+
mask = torch.logical_or(mask, shifted_mask)
|
459 |
+
return mask
|
460 |
+
|
461 |
+
def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
|
462 |
+
"""Construct a [1x1xT] boolean mask, consists of the u least probable spans,
|
463 |
+
where the token probability is determined by -scores, and the total
|
464 |
+
number of masked tokens is as closest as possible to num_masked_trg.
|
465 |
+
Find u using binary search.
|
466 |
+
Args:
|
467 |
+
scores (torch.Tensor): Per token score [-log(prob)]
|
468 |
+
num_masked_trg: int: The desired amount of tokens to be masked.
|
469 |
+
Returns:
|
470 |
+
torch.Tensor: Spans mask of shape [1x1xT]
|
471 |
+
"""
|
472 |
+
T = scores.shape[-1]
|
473 |
+
device = scores.device
|
474 |
+
scores_unfolded = scores.unfold(2, self.span_len, 1)
|
475 |
+
# Span score is the product of probs (sum in log space)
|
476 |
+
span_scores = scores_unfolded.sum(dim=-1)
|
477 |
+
spans_by_scores = torch.argsort(span_scores[0, 0], descending=True)
|
478 |
+
|
479 |
+
num_masked_trg = max(num_masked_trg, self.span_len)
|
480 |
+
|
481 |
+
# Binary search for u - the number least probable overlapping masked spans s.t.
|
482 |
+
# the total masking rate is the closest to num_masked_trg / T.
|
483 |
+
min_u = num_masked_trg // self.span_len
|
484 |
+
max_u = num_masked_trg - self.span_len + 1
|
485 |
+
mid = round(0.5 * (min_u + max_u))
|
486 |
+
|
487 |
+
if mid == min_u or mid == max_u:
|
488 |
+
return self._construct_spans_mask(spans_by_scores[:mid], T, device)
|
489 |
+
|
490 |
+
while mid > min_u and mid < max_u:
|
491 |
+
mask = self._construct_spans_mask(spans_by_scores[:mid], T, device)
|
492 |
+
n_masked = mask.sum()
|
493 |
+
if n_masked > num_masked_trg:
|
494 |
+
max_u = mid
|
495 |
+
mid = round(0.5 * (min_u + max_u))
|
496 |
+
else:
|
497 |
+
min_u = mid
|
498 |
+
mid = round(0.5 * (min_u + max_u))
|
499 |
+
|
500 |
+
return mask
|
audiocraft/models/loaders.py
CHANGED
@@ -28,6 +28,7 @@ from omegaconf import OmegaConf, DictConfig
|
|
28 |
import torch
|
29 |
|
30 |
import audiocraft
|
|
|
31 |
from . import builders
|
32 |
from .encodec import CompressionModel
|
33 |
|
@@ -47,6 +48,7 @@ HF_MODEL_CHECKPOINTS_MAP = {
|
|
47 |
"stereo-large": "facebook/musicgen-stereo-large",
|
48 |
"stereo-melody": "facebook/musicgen-stereo-melody",
|
49 |
"stereo-melody-large": "facebook/musicgen-stereo-melody-large",
|
|
|
50 |
}
|
51 |
|
52 |
|
@@ -156,7 +158,7 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
|
|
156 |
# Handle newer model formats that might not have xp.cfg
|
157 |
if 'xp.cfg' not in pkg:
|
158 |
if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
|
159 |
-
'stereo-small', 'stereo-large', 'stereo-melody-large']:
|
160 |
print(f"Using fallback configuration for {file_or_url_or_id}")
|
161 |
# Create a default configuration based on the model type
|
162 |
# This is where you'd need to add model-specific configurations
|
@@ -212,6 +214,52 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
|
|
212 |
return model
|
213 |
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
|
216 |
filename: tp.Optional[str] = None,
|
217 |
cache_dir: tp.Optional[str] = None):
|
|
|
28 |
import torch
|
29 |
|
30 |
import audiocraft
|
31 |
+
|
32 |
from . import builders
|
33 |
from .encodec import CompressionModel
|
34 |
|
|
|
48 |
"stereo-large": "facebook/musicgen-stereo-large",
|
49 |
"stereo-melody": "facebook/musicgen-stereo-melody",
|
50 |
"stereo-melody-large": "facebook/musicgen-stereo-melody-large",
|
51 |
+
"style": "facebook/musicgen-style",
|
52 |
}
|
53 |
|
54 |
|
|
|
158 |
# Handle newer model formats that might not have xp.cfg
|
159 |
if 'xp.cfg' not in pkg:
|
160 |
if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
|
161 |
+
'stereo-small', 'stereo-large', 'stereo-melody-large','style']:
|
162 |
print(f"Using fallback configuration for {file_or_url_or_id}")
|
163 |
# Create a default configuration based on the model type
|
164 |
# This is where you'd need to add model-specific configurations
|
|
|
214 |
return model
|
215 |
|
216 |
|
217 |
+
def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
|
218 |
+
device='cpu', cache_dir: tp.Optional[str] = None):
|
219 |
+
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
220 |
+
cfg = OmegaConf.create(pkg['xp.cfg'])
|
221 |
+
cfg.device = str(device)
|
222 |
+
if cfg.device == 'cpu':
|
223 |
+
cfg.dtype = 'float32'
|
224 |
+
else:
|
225 |
+
cfg.dtype = 'float16'
|
226 |
+
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
227 |
+
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
228 |
+
|
229 |
+
cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
|
230 |
+
cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
|
231 |
+
cfg.transformer_lm.span_len = cfg.masking.span_len
|
232 |
+
|
233 |
+
# MAGNeT models v1 support only xformers backend.
|
234 |
+
from audiocraft.modules.transformer import set_efficient_attention_backend
|
235 |
+
|
236 |
+
if cfg.transformer_lm.memory_efficient:
|
237 |
+
set_efficient_attention_backend("xformers")
|
238 |
+
|
239 |
+
model = builders.get_lm_model(cfg)
|
240 |
+
model.load_state_dict(pkg['best_state'])
|
241 |
+
model.eval()
|
242 |
+
model.cfg = cfg
|
243 |
+
return model
|
244 |
+
|
245 |
+
|
246 |
+
def load_jasco_model(file_or_url_or_id: tp.Union[Path, str],
|
247 |
+
compression_model: CompressionModel,
|
248 |
+
device='cpu', cache_dir: tp.Optional[str] = None):
|
249 |
+
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
250 |
+
cfg = OmegaConf.create(pkg['xp.cfg'])
|
251 |
+
cfg.device = str(device)
|
252 |
+
if cfg.device == 'cpu':
|
253 |
+
cfg.dtype = 'float32'
|
254 |
+
else:
|
255 |
+
cfg.dtype = 'float16'
|
256 |
+
model = builders.get_jasco_model(cfg, compression_model)
|
257 |
+
model.load_state_dict(pkg['best_state'])
|
258 |
+
model.eval()
|
259 |
+
model.cfg = cfg
|
260 |
+
return model
|
261 |
+
|
262 |
+
|
263 |
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
|
264 |
filename: tp.Optional[str] = None,
|
265 |
cache_dir: tp.Optional[str] = None):
|
audiocraft/models/magnet.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using MAGNeT. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
import typing as tp
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from .genmodel import BaseGenModel
|
15 |
+
from .loaders import load_compression_model, load_lm_model_magnet
|
16 |
+
|
17 |
+
|
18 |
+
class MAGNeT(BaseGenModel):
|
19 |
+
"""MAGNeT main model with convenient generation API.
|
20 |
+
Args:
|
21 |
+
See MusicGen class.
|
22 |
+
"""
|
23 |
+
def __init__(self, **kwargs):
|
24 |
+
super().__init__(**kwargs)
|
25 |
+
# MAGNeT operates over a fixed sequence length defined in it's config.
|
26 |
+
self.duration = self.lm.cfg.dataset.segment_duration
|
27 |
+
self.set_generation_params()
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
|
31 |
+
"""Return pretrained model, we provide six models:
|
32 |
+
- facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
|
33 |
+
# see: https://huggingface.co/facebook/magnet-small-10secs
|
34 |
+
- facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
|
35 |
+
# see: https://huggingface.co/facebook/magnet-medium-10secs
|
36 |
+
- facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
|
37 |
+
# see: https://huggingface.co/facebook/magnet-small-30secs
|
38 |
+
- facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
|
39 |
+
# see: https://huggingface.co/facebook/magnet-medium-30secs
|
40 |
+
- facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
|
41 |
+
# see: https://huggingface.co/facebook/audio-magnet-small
|
42 |
+
- facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
|
43 |
+
# see: https://huggingface.co/facebook/audio-magnet-medium
|
44 |
+
"""
|
45 |
+
if device is None:
|
46 |
+
if torch.cuda.device_count():
|
47 |
+
device = 'cuda'
|
48 |
+
else:
|
49 |
+
device = 'cpu'
|
50 |
+
|
51 |
+
compression_model = load_compression_model(name, device=device)
|
52 |
+
lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)
|
53 |
+
|
54 |
+
if 'self_wav' in lm.condition_provider.conditioners:
|
55 |
+
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
56 |
+
|
57 |
+
kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
|
58 |
+
return MAGNeT(**kwargs)
|
59 |
+
|
60 |
+
def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
|
61 |
+
top_p: float = 0.9, temperature: float = 3.0,
|
62 |
+
max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
|
63 |
+
decoding_steps: tp.List[int] = [20, 10, 10, 10],
|
64 |
+
span_arrangement: str = 'nonoverlap'):
|
65 |
+
"""Set the generation parameters for MAGNeT.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
69 |
+
top_k (int, optional): top_k used for sampling. Defaults to 0.
|
70 |
+
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
|
71 |
+
temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
|
72 |
+
max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
|
73 |
+
min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
|
74 |
+
decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
|
75 |
+
for each of the n_q RVQ codebooks.
|
76 |
+
span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
|
77 |
+
or overlapping spans ('stride1') in the masking scheme.
|
78 |
+
"""
|
79 |
+
self.generation_params = {
|
80 |
+
'use_sampling': use_sampling,
|
81 |
+
'temp': temperature,
|
82 |
+
'top_k': top_k,
|
83 |
+
'top_p': top_p,
|
84 |
+
'max_cfg_coef': max_cfg_coef,
|
85 |
+
'min_cfg_coef': min_cfg_coef,
|
86 |
+
'decoding_steps': [int(s) for s in decoding_steps],
|
87 |
+
'span_arrangement': span_arrangement
|
88 |
+
}
|
audiocraft/models/musicgen.py
CHANGED
@@ -18,11 +18,12 @@ import torch
|
|
18 |
import gradio as gr
|
19 |
|
20 |
from .encodec import CompressionModel
|
|
|
21 |
from .lm import LMModel
|
22 |
from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
|
23 |
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
24 |
from ..data.audio_utils import convert_audio
|
25 |
-
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
26 |
from ..utils.autocast import TorchAutocast
|
27 |
|
28 |
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
@@ -108,6 +109,7 @@ class MusicGen:
|
|
108 |
- stereo-melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-stereo-melody
|
109 |
- stereo-large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-large
|
110 |
- stereo-melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-stereo-melody-large
|
|
|
111 |
"""
|
112 |
|
113 |
if device is None:
|
@@ -120,7 +122,7 @@ class MusicGen:
|
|
120 |
# used only for unit tests
|
121 |
compression_model = get_debug_compression_model(device)
|
122 |
lm = get_debug_lm_model(device)
|
123 |
-
return MusicGen(name, compression_model, lm)
|
124 |
|
125 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
126 |
if not os.path.isfile(name) and not os.path.isdir(name):
|
@@ -143,6 +145,7 @@ class MusicGen:
|
|
143 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
144 |
top_p: float = 0.0, temperature: float = 1.0,
|
145 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
|
|
146 |
two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
|
147 |
"""Set the generation parameters for MusicGen.
|
148 |
|
@@ -153,6 +156,10 @@ class MusicGen:
|
|
153 |
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
154 |
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
155 |
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
|
|
|
|
|
|
|
|
156 |
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
157 |
instead of batching together the two. This has some impact on how things
|
158 |
are padded but seems to have little impact in practice.
|
@@ -172,8 +179,30 @@ class MusicGen:
|
|
172 |
'top_p': top_p,
|
173 |
'cfg_coef': cfg_coef,
|
174 |
'two_step_cfg': two_step_cfg,
|
|
|
175 |
}
|
176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
|
178 |
"""Override the default progress callback."""
|
179 |
self._progress_callback = progress_callback
|
@@ -399,8 +428,8 @@ class MusicGen:
|
|
399 |
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
400 |
|
401 |
Args:
|
402 |
-
attributes (
|
403 |
-
prompt_tokens (
|
404 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
405 |
Returns:
|
406 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
|
|
18 |
import gradio as gr
|
19 |
|
20 |
from .encodec import CompressionModel
|
21 |
+
from .genmodel import BaseGenModel
|
22 |
from .lm import LMModel
|
23 |
from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
|
24 |
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
25 |
from ..data.audio_utils import convert_audio
|
26 |
+
from ..modules.conditioners import ConditioningAttributes, WavCondition, StyleConditioner
|
27 |
from ..utils.autocast import TorchAutocast
|
28 |
|
29 |
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
|
|
109 |
- stereo-melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-stereo-melody
|
110 |
- stereo-large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-large
|
111 |
- stereo-melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-stereo-melody-large
|
112 |
+
- musicgen-style (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-style
|
113 |
"""
|
114 |
|
115 |
if device is None:
|
|
|
122 |
# used only for unit tests
|
123 |
compression_model = get_debug_compression_model(device)
|
124 |
lm = get_debug_lm_model(device)
|
125 |
+
return MusicGen(name, compression_model, lm, max_duration=30)
|
126 |
|
127 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
128 |
if not os.path.isfile(name) and not os.path.isdir(name):
|
|
|
145 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
146 |
top_p: float = 0.0, temperature: float = 1.0,
|
147 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
148 |
+
cfg_coef_beta: tp.Optional[float] = None,
|
149 |
two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
|
150 |
"""Set the generation parameters for MusicGen.
|
151 |
|
|
|
156 |
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
157 |
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
158 |
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
159 |
+
cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance.
|
160 |
+
Should be only used for MusicGen melody if we want to push the text condition more than
|
161 |
+
the audio conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
|
162 |
+
double CFG.
|
163 |
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
164 |
instead of batching together the two. This has some impact on how things
|
165 |
are padded but seems to have little impact in practice.
|
|
|
179 |
'top_p': top_p,
|
180 |
'cfg_coef': cfg_coef,
|
181 |
'two_step_cfg': two_step_cfg,
|
182 |
+
'cfg_coef_beta': cfg_coef_beta,
|
183 |
}
|
184 |
|
185 |
+
def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0,
|
186 |
+
ds_factor: tp.Optional[int] = None,
|
187 |
+
encodec_n_q: tp.Optional[int] = None) -> None:
|
188 |
+
"""Set the parameters of the style conditioner
|
189 |
+
Args:
|
190 |
+
eval_q (int): the number of residual quantization streams used to quantize the style condition
|
191 |
+
the smaller it is, the narrower is the information bottleneck
|
192 |
+
excerpt_length (float): the excerpt length in seconds that is extracted from the audio
|
193 |
+
conditioning
|
194 |
+
ds_factor: (int): the downsampling factor used to downsample the style tokens before
|
195 |
+
using them as a prefix
|
196 |
+
encodec_n_q: (int, optional): if encodec is used as a feature extractor, sets the number
|
197 |
+
of streams that is used to extract features
|
198 |
+
"""
|
199 |
+
assert isinstance(self.lm.condition_provider.conditioners.self_wav, StyleConditioner), \
|
200 |
+
"Only use this function if you model is MusicGen-Style"
|
201 |
+
self.lm.condition_provider.conditioners.self_wav.set_params(eval_q=eval_q,
|
202 |
+
excerpt_length=excerpt_length,
|
203 |
+
ds_factor=ds_factor,
|
204 |
+
encodec_n_q=encodec_n_q)
|
205 |
+
|
206 |
def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
|
207 |
"""Override the default progress callback."""
|
208 |
self._progress_callback = progress_callback
|
|
|
428 |
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
429 |
|
430 |
Args:
|
431 |
+
attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
|
432 |
+
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
433 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
434 |
Returns:
|
435 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
audiocraft/modules/codebooks_patterns.py
CHANGED
@@ -30,7 +30,7 @@ class Pattern:
|
|
30 |
|
31 |
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
32 |
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
33 |
-
to the interleaved sequence of shape [B, K, S] applying the pattern, with
|
34 |
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
35 |
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
36 |
is returned along with a mask indicating valid tokens.
|
@@ -49,7 +49,6 @@ class Pattern:
|
|
49 |
|
50 |
def __post_init__(self):
|
51 |
assert len(self.layout) > 0
|
52 |
-
assert self.layout[0] == []
|
53 |
self._validate_layout()
|
54 |
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
55 |
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
@@ -93,6 +92,9 @@ class Pattern:
|
|
93 |
valid_step = len(self.layout) - self.max_delay
|
94 |
return self.layout[:valid_step]
|
95 |
|
|
|
|
|
|
|
96 |
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
97 |
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
98 |
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
@@ -202,7 +204,7 @@ class Pattern:
|
|
202 |
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
203 |
|
204 |
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
205 |
-
if is_model_output:
|
206 |
ref_layout = ref_layout[1:]
|
207 |
|
208 |
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
@@ -335,7 +337,8 @@ class DelayedPatternProvider(CodebooksPatternProvider):
|
|
335 |
assert sorted(self.delays) == self.delays
|
336 |
|
337 |
def get_pattern(self, timesteps: int) -> Pattern:
|
338 |
-
|
|
|
339 |
max_delay = max(self.delays)
|
340 |
if self.empty_initial:
|
341 |
out += [[] for _ in range(self.empty_initial)]
|
@@ -360,9 +363,10 @@ class ParallelPatternProvider(DelayedPatternProvider):
|
|
360 |
|
361 |
Args:
|
362 |
n_q (int): Number of codebooks.
|
|
|
363 |
"""
|
364 |
-
def __init__(self, n_q: int):
|
365 |
-
super().__init__(n_q, [0] * n_q)
|
366 |
|
367 |
|
368 |
class UnrolledPatternProvider(CodebooksPatternProvider):
|
|
|
30 |
|
31 |
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
32 |
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
33 |
+
to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
|
34 |
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
35 |
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
36 |
is returned along with a mask indicating valid tokens.
|
|
|
49 |
|
50 |
def __post_init__(self):
|
51 |
assert len(self.layout) > 0
|
|
|
52 |
self._validate_layout()
|
53 |
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
54 |
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
|
|
92 |
valid_step = len(self.layout) - self.max_delay
|
93 |
return self.layout[:valid_step]
|
94 |
|
95 |
+
def starts_with_special_token(self):
|
96 |
+
return self.layout[0] == []
|
97 |
+
|
98 |
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
99 |
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
100 |
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
|
|
204 |
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
205 |
|
206 |
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
207 |
+
if is_model_output and self.starts_with_special_token():
|
208 |
ref_layout = ref_layout[1:]
|
209 |
|
210 |
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
|
|
337 |
assert sorted(self.delays) == self.delays
|
338 |
|
339 |
def get_pattern(self, timesteps: int) -> Pattern:
|
340 |
+
omit_special_token = self.empty_initial < 0
|
341 |
+
out: PatternLayout = [] if omit_special_token else [[]]
|
342 |
max_delay = max(self.delays)
|
343 |
if self.empty_initial:
|
344 |
out += [[] for _ in range(self.empty_initial)]
|
|
|
363 |
|
364 |
Args:
|
365 |
n_q (int): Number of codebooks.
|
366 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
367 |
"""
|
368 |
+
def __init__(self, n_q: int, empty_initial: int = 0):
|
369 |
+
super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
|
370 |
|
371 |
|
372 |
class UnrolledPatternProvider(CodebooksPatternProvider):
|
audiocraft/modules/conditioners.py
CHANGED
@@ -15,8 +15,8 @@ import random
|
|
15 |
import re
|
16 |
import typing as tp
|
17 |
import warnings
|
18 |
-
|
19 |
import einops
|
|
|
20 |
from num2words import num2words
|
21 |
import spacy
|
22 |
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
|
@@ -24,10 +24,10 @@ import torch
|
|
24 |
from torch import nn
|
25 |
import torch.nn.functional as F
|
26 |
from torch.nn.utils.rnn import pad_sequence
|
27 |
-
|
28 |
from .chroma import ChromaExtractor
|
29 |
from .streaming import StreamingModule
|
30 |
-
from .transformer import create_sin_embedding
|
31 |
from ..data.audio import audio_read
|
32 |
from ..data.audio_dataset import SegmentInfo
|
33 |
from ..data.audio_utils import convert_audio
|
@@ -43,6 +43,15 @@ TextCondition = tp.Optional[str] # a text condition can be a string or None (if
|
|
43 |
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
44 |
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
class WavCondition(tp.NamedTuple):
|
47 |
wav: torch.Tensor
|
48 |
length: torch.Tensor
|
@@ -60,11 +69,17 @@ class JointEmbedCondition(tp.NamedTuple):
|
|
60 |
seek_time: tp.List[tp.Optional[float]] = []
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
63 |
@dataclass
|
64 |
class ConditioningAttributes:
|
65 |
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
66 |
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
67 |
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
|
|
68 |
|
69 |
def __getitem__(self, item):
|
70 |
return getattr(self, item)
|
@@ -81,19 +96,25 @@ class ConditioningAttributes:
|
|
81 |
def joint_embed_attributes(self):
|
82 |
return self.joint_embed.keys()
|
83 |
|
|
|
|
|
|
|
|
|
84 |
@property
|
85 |
def attributes(self):
|
86 |
return {
|
87 |
"text": self.text_attributes,
|
88 |
"wav": self.wav_attributes,
|
89 |
"joint_embed": self.joint_embed_attributes,
|
|
|
90 |
}
|
91 |
|
92 |
def to_flat_dict(self):
|
93 |
return {
|
94 |
**{f"text.{k}": v for k, v in self.text.items()},
|
95 |
**{f"wav.{k}": v for k, v in self.wav.items()},
|
96 |
-
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
|
|
|
97 |
}
|
98 |
|
99 |
@classmethod
|
@@ -177,6 +198,44 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
|
|
177 |
)
|
178 |
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
class Tokenizer:
|
181 |
"""Base tokenizer implementation
|
182 |
(in case we want to introduce more advances tokenizers in the future).
|
@@ -297,7 +356,8 @@ class BaseConditioner(nn.Module):
|
|
297 |
super().__init__()
|
298 |
self.dim = dim
|
299 |
self.output_dim = output_dim
|
300 |
-
self.
|
|
|
301 |
|
302 |
def tokenize(self, *args, **kwargs) -> tp.Any:
|
303 |
"""Should be any part of the processing that will lead to a synchronization
|
@@ -495,8 +555,9 @@ class WaveformConditioner(BaseConditioner):
|
|
495 |
wav, lengths, *_ = x
|
496 |
with torch.no_grad():
|
497 |
embeds = self._get_wav_embedding(x)
|
498 |
-
|
499 |
-
|
|
|
500 |
|
501 |
if lengths is not None and self._use_masking:
|
502 |
lengths = lengths / self._downsampling_factor()
|
@@ -607,7 +668,7 @@ class ChromaStemConditioner(WaveformConditioner):
|
|
607 |
with self.autocast:
|
608 |
wav = convert_audio(
|
609 |
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
610 |
-
stems = apply_model(self.demucs, wav, device=self.device)
|
611 |
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
612 |
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
613 |
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
@@ -698,6 +759,250 @@ class ChromaStemConditioner(WaveformConditioner):
|
|
698 |
return x
|
699 |
|
700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
class JointEmbeddingConditioner(BaseConditioner):
|
702 |
"""Joint embedding conditioning supporting both audio or text conditioning.
|
703 |
|
@@ -996,13 +1301,48 @@ class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
|
996 |
return embed, empty_idx
|
997 |
|
998 |
|
999 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1000 |
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
1001 |
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
1002 |
If the condition is of any other type, set its value to None.
|
1003 |
Works in-place.
|
1004 |
"""
|
1005 |
-
if condition_type not in ['text', 'wav', 'joint_embed']:
|
1006 |
raise ValueError(
|
1007 |
"dropout_condition got an unexpected condition type!"
|
1008 |
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
@@ -1021,6 +1361,8 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi
|
|
1021 |
elif condition_type == 'joint_embed':
|
1022 |
embed = sample.joint_embed[condition]
|
1023 |
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
|
|
|
|
1024 |
else:
|
1025 |
sample.text[condition] = None
|
1026 |
|
@@ -1071,7 +1413,7 @@ class AttributeDropout(DropoutModule):
|
|
1071 |
return samples
|
1072 |
|
1073 |
samples = deepcopy(samples)
|
1074 |
-
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
1075 |
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
1076 |
if torch.rand(1, generator=self.rng).item() < p:
|
1077 |
for sample in samples:
|
@@ -1094,7 +1436,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
1094 |
super().__init__(seed=seed)
|
1095 |
self.p = p
|
1096 |
|
1097 |
-
def forward(self, samples: tp.List[ConditioningAttributes]
|
|
|
|
|
1098 |
"""
|
1099 |
Args:
|
1100 |
samples (list[ConditioningAttributes]): List of conditions.
|
@@ -1111,10 +1455,11 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
1111 |
|
1112 |
# nullify conditions of all attributes
|
1113 |
samples = deepcopy(samples)
|
1114 |
-
for condition_type in
|
1115 |
for sample in samples:
|
1116 |
for condition in sample.attributes[condition_type]:
|
1117 |
-
dropout_condition(sample, condition_type, condition
|
|
|
1118 |
return samples
|
1119 |
|
1120 |
def __repr__(self):
|
@@ -1339,7 +1684,7 @@ class ConditionFuser(StreamingModule):
|
|
1339 |
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
1340 |
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
1341 |
"""
|
1342 |
-
FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
|
1343 |
|
1344 |
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
1345 |
cross_attention_pos_emb_scale: float = 1.0):
|
@@ -1399,6 +1744,8 @@ class ConditionFuser(StreamingModule):
|
|
1399 |
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
1400 |
else:
|
1401 |
cross_attention_output = cond
|
|
|
|
|
1402 |
else:
|
1403 |
raise ValueError(f"unknown op ({op})")
|
1404 |
|
|
|
15 |
import re
|
16 |
import typing as tp
|
17 |
import warnings
|
|
|
18 |
import einops
|
19 |
+
import flashy
|
20 |
from num2words import num2words
|
21 |
import spacy
|
22 |
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
|
|
|
24 |
from torch import nn
|
25 |
import torch.nn.functional as F
|
26 |
from torch.nn.utils.rnn import pad_sequence
|
27 |
+
from enum import Enum
|
28 |
from .chroma import ChromaExtractor
|
29 |
from .streaming import StreamingModule
|
30 |
+
from .transformer import create_sin_embedding, StreamingTransformer
|
31 |
from ..data.audio import audio_read
|
32 |
from ..data.audio_dataset import SegmentInfo
|
33 |
from ..data.audio_utils import convert_audio
|
|
|
43 |
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
44 |
|
45 |
|
46 |
+
class JascoCondConst(Enum):
|
47 |
+
DRM = 'self_wav'
|
48 |
+
CRD = 'chords'
|
49 |
+
MLD = 'melody'
|
50 |
+
SYM = {'chords', 'melody'}
|
51 |
+
LAT = {'self_wav'}
|
52 |
+
ALL = ['chords', 'self_wav', 'melody'] # order matters
|
53 |
+
|
54 |
+
|
55 |
class WavCondition(tp.NamedTuple):
|
56 |
wav: torch.Tensor
|
57 |
length: torch.Tensor
|
|
|
69 |
seek_time: tp.List[tp.Optional[float]] = []
|
70 |
|
71 |
|
72 |
+
class SymbolicCondition(tp.NamedTuple):
|
73 |
+
frame_chords: tp.Optional[torch.Tensor] = None
|
74 |
+
melody: tp.Optional[torch.Tensor] = None
|
75 |
+
|
76 |
+
|
77 |
@dataclass
|
78 |
class ConditioningAttributes:
|
79 |
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
80 |
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
81 |
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
82 |
+
symbolic: tp.Dict[str, SymbolicCondition] = field(default_factory=dict)
|
83 |
|
84 |
def __getitem__(self, item):
|
85 |
return getattr(self, item)
|
|
|
96 |
def joint_embed_attributes(self):
|
97 |
return self.joint_embed.keys()
|
98 |
|
99 |
+
@property
|
100 |
+
def symbolic_attributes(self):
|
101 |
+
return self.symbolic.keys()
|
102 |
+
|
103 |
@property
|
104 |
def attributes(self):
|
105 |
return {
|
106 |
"text": self.text_attributes,
|
107 |
"wav": self.wav_attributes,
|
108 |
"joint_embed": self.joint_embed_attributes,
|
109 |
+
"symbolic": self.symbolic_attributes,
|
110 |
}
|
111 |
|
112 |
def to_flat_dict(self):
|
113 |
return {
|
114 |
**{f"text.{k}": v for k, v in self.text.items()},
|
115 |
**{f"wav.{k}": v for k, v in self.wav.items()},
|
116 |
+
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()},
|
117 |
+
**{f"symbolic.{k}": v for k, v in self.symbolic.items()}
|
118 |
}
|
119 |
|
120 |
@classmethod
|
|
|
198 |
)
|
199 |
|
200 |
|
201 |
+
def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 194) -> SymbolicCondition:
|
202 |
+
"""Nullify the symbolic condition by setting all frame chords to a specified null chord index.
|
203 |
+
Args:
|
204 |
+
sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
|
205 |
+
null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
|
206 |
+
Returns:
|
207 |
+
SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
|
208 |
+
"""
|
209 |
+
return SymbolicCondition(frame_chords=torch.ones_like(sym_cond.frame_chords) * null_chord_idx) # type: ignore
|
210 |
+
|
211 |
+
|
212 |
+
def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition:
|
213 |
+
"""Nullify the symbolic condition by replacing the melody matrix with zeros matrix.
|
214 |
+
Args:
|
215 |
+
sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
|
216 |
+
null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
|
217 |
+
Returns:
|
218 |
+
SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
|
219 |
+
"""
|
220 |
+
return SymbolicCondition(melody=torch.zeros_like(sym_cond.melody)) # type: ignore
|
221 |
+
|
222 |
+
|
223 |
+
def _drop_description_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
224 |
+
"""Drop the text condition but keep the wav conditon on a list of ConditioningAttributes.
|
225 |
+
This is useful to calculate l_style in the double classifier free guidance formula.
|
226 |
+
See paragraph 4.3 in https://arxiv.org/pdf/2407.12563
|
227 |
+
|
228 |
+
Args:
|
229 |
+
conditions (tp.List[ConditioningAttributes]): List of conditions.
|
230 |
+
"""
|
231 |
+
# We assert that description and self_wav are in the conditions
|
232 |
+
for condition in conditions:
|
233 |
+
assert 'description' in condition.text.keys()
|
234 |
+
assert 'self_wav' in condition.wav.keys()
|
235 |
+
return AttributeDropout(p={'text': {'description': 1.0},
|
236 |
+
'wav': {'self_wav': 0.0}})(conditions)
|
237 |
+
|
238 |
+
|
239 |
class Tokenizer:
|
240 |
"""Base tokenizer implementation
|
241 |
(in case we want to introduce more advances tokenizers in the future).
|
|
|
356 |
super().__init__()
|
357 |
self.dim = dim
|
358 |
self.output_dim = output_dim
|
359 |
+
if self.output_dim > -1: # omit projection when output_dim <= 0
|
360 |
+
self.output_proj = nn.Linear(dim, output_dim)
|
361 |
|
362 |
def tokenize(self, *args, **kwargs) -> tp.Any:
|
363 |
"""Should be any part of the processing that will lead to a synchronization
|
|
|
555 |
wav, lengths, *_ = x
|
556 |
with torch.no_grad():
|
557 |
embeds = self._get_wav_embedding(x)
|
558 |
+
if hasattr(self, 'output_proj'):
|
559 |
+
embeds = embeds.to(self.output_proj.weight)
|
560 |
+
embeds = self.output_proj(embeds)
|
561 |
|
562 |
if lengths is not None and self._use_masking:
|
563 |
lengths = lengths / self._downsampling_factor()
|
|
|
668 |
with self.autocast:
|
669 |
wav = convert_audio(
|
670 |
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
671 |
+
stems = apply_model(self.demucs, wav, device=self.device) # type: ignore
|
672 |
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
673 |
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
674 |
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
|
|
759 |
return x
|
760 |
|
761 |
|
762 |
+
class FeatureExtractor(WaveformConditioner):
|
763 |
+
"""
|
764 |
+
Feature Extractor used for the style conditioner of the paper AUDIO CONDITIONING
|
765 |
+
FOR MUSIC GENERATION VIA DISCRETE BOTTLENECK FEATURES.
|
766 |
+
|
767 |
+
Given a waveform, we extract an excerpt of defined length randomly subsampled.
|
768 |
+
Then, we feed this excerpt to a feature extractor.
|
769 |
+
|
770 |
+
Args:
|
771 |
+
model_name (str): 'encodec' or 'mert'.
|
772 |
+
sample_rate (str): sample rate of the input audio. (32000)
|
773 |
+
encodec_checkpoint (str): if encodec is used as a feature extractor, checkpoint
|
774 |
+
of the model. ('//pretrained/facebook/encodec_32khz' is the default)
|
775 |
+
encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
|
776 |
+
quantization streams used in it.
|
777 |
+
length (float): length in seconds of the random subsampled excerpt that is used
|
778 |
+
for conditioning.
|
779 |
+
dim (int): The internal representation dimension.
|
780 |
+
output_dim (int): Output dimension for the conditioner.
|
781 |
+
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
782 |
+
compute_mask (bool): whether to mask the tokens corresponding to the subsampled
|
783 |
+
excerpt in the computation of the music language model cross-entropy loss.
|
784 |
+
use_middle_of_segment (bool): if True, always take the middle of the input
|
785 |
+
instead of a random subsampled excerpt.
|
786 |
+
ds_rate_compression (int): downsampling parameter of the compression model used
|
787 |
+
for the music language model. (640 for encodec_32khz)
|
788 |
+
num_codebooks_lm (int): the number of codebooks used by the music language model.
|
789 |
+
"""
|
790 |
+
def __init__(
|
791 |
+
self, model_name: str,
|
792 |
+
sample_rate: int, encodec_checkpoint: str, encodec_n_q: int, length: float,
|
793 |
+
dim: int, output_dim: int, device: tp.Union[torch.device, str],
|
794 |
+
compute_mask: bool = True,
|
795 |
+
use_middle_of_segment: bool = False, ds_rate_compression: int = 640,
|
796 |
+
num_codebooks_lm: int = 4
|
797 |
+
):
|
798 |
+
assert model_name in ['encodec', 'mert']
|
799 |
+
if model_name == 'encodec':
|
800 |
+
from ..solvers.compression import CompressionSolver
|
801 |
+
feat_extractor = CompressionSolver.model_from_checkpoint(encodec_checkpoint, device)
|
802 |
+
elif model_name == 'mert':
|
803 |
+
from transformers import AutoModel
|
804 |
+
feat_extractor = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
|
805 |
+
super().__init__(
|
806 |
+
dim=dim,
|
807 |
+
output_dim=output_dim,
|
808 |
+
device=device
|
809 |
+
)
|
810 |
+
self.sample_rate = sample_rate
|
811 |
+
self.compute_mask = compute_mask
|
812 |
+
self.feat_extractor: nn.Module
|
813 |
+
self.embed: tp.Union[nn.ModuleList, nn.Linear]
|
814 |
+
if model_name == 'encodec':
|
815 |
+
self.__dict__["feat_extractor"] = feat_extractor.to(device)
|
816 |
+
self.encodec_n_q = encodec_n_q
|
817 |
+
self.embed = nn.ModuleList([nn.Embedding(feat_extractor.cardinality, dim) for _ in range(encodec_n_q)])
|
818 |
+
if model_name == 'mert':
|
819 |
+
self.__dict__["feat_extractor"] = feat_extractor.eval().to(device)
|
820 |
+
self.embed = nn.Linear(768, dim) # hardcoded
|
821 |
+
self.length_subwav = int(length * sample_rate)
|
822 |
+
self.ds_rate_compression = ds_rate_compression
|
823 |
+
self.model_name = model_name
|
824 |
+
self.use_middle_of_segment = use_middle_of_segment
|
825 |
+
self.num_codebooks_lm = num_codebooks_lm
|
826 |
+
|
827 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
828 |
+
if x.wav.shape[-1] == 1:
|
829 |
+
self.temp_mask = None
|
830 |
+
return torch.zeros(x.wav.shape[0], 1, self.dim, device=self.device)
|
831 |
+
else:
|
832 |
+
with torch.no_grad():
|
833 |
+
if self.use_middle_of_segment:
|
834 |
+
start = int((x.wav.shape[-1] - self.length_subwav) / 2)
|
835 |
+
wav = x.wav[:, :, start:start+self.length_subwav]
|
836 |
+
else:
|
837 |
+
start = random.randint(0, x.wav.shape[-1] - self.length_subwav)
|
838 |
+
wav = x.wav[:, :, start:start+self.length_subwav]
|
839 |
+
if self.compute_mask:
|
840 |
+
self.temp_mask = self._get_mask_wav(x, start)
|
841 |
+
if self.model_name == 'encodec':
|
842 |
+
tokens = self.feat_extractor.encode(wav)[0] # type: ignore
|
843 |
+
elif self.model_name == 'mert':
|
844 |
+
wav = convert_audio(wav, from_rate=x.sample_rate[0], to_rate=24000, to_channels=1)
|
845 |
+
embeds = self.feat_extractor(wav.squeeze(-2)).last_hidden_state
|
846 |
+
if self.model_name == 'encodec':
|
847 |
+
tokens = tokens[:, :self.encodec_n_q]
|
848 |
+
embeds = sum([self.embed[k](tokens[:, k]) for k in range(self.encodec_n_q)]) # type: ignore
|
849 |
+
else:
|
850 |
+
embeds = self.embed(embeds)
|
851 |
+
|
852 |
+
return embeds # [B, T, dim]
|
853 |
+
|
854 |
+
def _downsampling_factor(self):
|
855 |
+
if self.model_name == 'encodec':
|
856 |
+
return self.sample_rate / self.feat_extractor.frame_rate
|
857 |
+
elif self.model_name == 'mert':
|
858 |
+
return self.sample_rate / 75
|
859 |
+
|
860 |
+
def _get_mask_wav(self, x: WavCondition, start: int) -> tp.Union[torch.Tensor, None]:
|
861 |
+
if x.wav.shape[-1] == 1:
|
862 |
+
return None
|
863 |
+
total_length = int(x.wav.shape[-1] / self.ds_rate_compression)
|
864 |
+
mask_length = int(self.length_subwav / self.ds_rate_compression)
|
865 |
+
start = int(start / self.ds_rate_compression)
|
866 |
+
mask = torch.ones(x.wav.shape[0], self.num_codebooks_lm,
|
867 |
+
total_length, device=self.device, dtype=torch.bool)
|
868 |
+
mask[:, :, start:start+mask_length] = 0
|
869 |
+
return mask
|
870 |
+
|
871 |
+
|
872 |
+
class StyleConditioner(FeatureExtractor):
|
873 |
+
"""Conditioner from the paper AUDIO CONDITIONING FOR MUSIC GENERATION VIA
|
874 |
+
DISCRETE BOTTLENECK FEATURES.
|
875 |
+
Given an audio input, it is passed through a Feature Extractor and a
|
876 |
+
transformer encoder. Then it is quantized through RVQ.
|
877 |
+
|
878 |
+
Args:
|
879 |
+
transformer_scale (str): size of the transformer. See in the __init__ to have more infos.
|
880 |
+
ds_factor (int): the downsampling factor applied to the representation after quantization.
|
881 |
+
encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
|
882 |
+
quantization streams used in it.
|
883 |
+
n_q_out (int): the number of quantization streams used for the RVQ. If increased, there
|
884 |
+
is more information passing as a conditioning.
|
885 |
+
eval_q (int): the number of quantization streams used for the RVQ at evaluation time.
|
886 |
+
q_dropout (bool): if True, at training time, a random number of stream is sampled
|
887 |
+
at each step in the interval [1, n_q_out].
|
888 |
+
bins (int): the codebook size used for each quantization stream.
|
889 |
+
varying_lengths (List[float]): list of the min and max duration in seconds for the
|
890 |
+
randomly subsampled excerpt at training time. For each step a length is sampled
|
891 |
+
in this interval.
|
892 |
+
batch_norm (bool): use of batch normalization after the transformer. Stabilizes the
|
893 |
+
training.
|
894 |
+
rvq_threshold_ema_dead_code (float): threshold for dropping dead codes in the
|
895 |
+
RVQ.
|
896 |
+
"""
|
897 |
+
def __init__(self, transformer_scale: str = 'default', ds_factor: int = 15, encodec_n_q: int = 4,
|
898 |
+
n_q_out: int = 6, eval_q: int = 3, q_dropout: bool = True, bins: int = 1024,
|
899 |
+
varying_lengths: tp.List[float] = [1.5, 4.5],
|
900 |
+
batch_norm: bool = True, rvq_threshold_ema_dead_code: float = 0.1,
|
901 |
+
**kwargs):
|
902 |
+
tr_args: tp.Dict[str, tp.Any]
|
903 |
+
if transformer_scale == 'xsmall':
|
904 |
+
tr_args = {'d_model': 256, 'num_heads': 8, 'num_layers': 4}
|
905 |
+
elif transformer_scale == 'large':
|
906 |
+
tr_args = {'d_model': 1024, 'num_heads': 16, 'num_layers': 24}
|
907 |
+
elif transformer_scale == 'default':
|
908 |
+
tr_args = {'d_model': 512, 'num_heads': 8, 'num_layers': 8}
|
909 |
+
elif transformer_scale == 'none':
|
910 |
+
tr_args = {'d_model': 512}
|
911 |
+
tr_args.update({
|
912 |
+
'memory_efficient': True, 'activation': 'gelu',
|
913 |
+
'norm_first': True, 'causal': False, 'layer_scale': None,
|
914 |
+
'bias_ff': False, 'bias_attn': False,
|
915 |
+
})
|
916 |
+
dim = tr_args['d_model']
|
917 |
+
super().__init__(dim=dim, encodec_n_q=encodec_n_q, **kwargs)
|
918 |
+
|
919 |
+
self.ds_factor = ds_factor
|
920 |
+
if transformer_scale == 'none':
|
921 |
+
self.transformer = None
|
922 |
+
else:
|
923 |
+
self.transformer = StreamingTransformer(dim_feedforward=int(4 * dim), **tr_args)
|
924 |
+
self.n_q_out = n_q_out
|
925 |
+
self.eval_q = eval_q
|
926 |
+
self.rvq = None
|
927 |
+
if n_q_out > 0:
|
928 |
+
self.rvq = ResidualVectorQuantizer(dim, n_q=n_q_out, q_dropout=q_dropout, bins=bins,
|
929 |
+
threshold_ema_dead_code=rvq_threshold_ema_dead_code)
|
930 |
+
self.autocast = TorchAutocast(enabled=self.device != 'cpu', device_type=self.device, dtype=torch.float32)
|
931 |
+
self.varying_lengths = varying_lengths
|
932 |
+
self.batch_norm = None
|
933 |
+
if batch_norm:
|
934 |
+
self.batch_norm = nn.BatchNorm1d(dim, affine=False)
|
935 |
+
self.mask = None
|
936 |
+
|
937 |
+
def _get_wav_embedding(self, wav: WavCondition) -> torch.Tensor:
|
938 |
+
with self.autocast:
|
939 |
+
# Sample the length of the excerpts
|
940 |
+
if self.varying_lengths and self.training:
|
941 |
+
assert len(self.varying_lengths) == 2
|
942 |
+
length = random.uniform(self.varying_lengths[0], self.varying_lengths[1])
|
943 |
+
self.length_subwav = int(length * self.sample_rate)
|
944 |
+
z1 = super()._get_wav_embedding(wav)
|
945 |
+
if self.compute_mask:
|
946 |
+
self.mask = self.temp_mask # type: ignore
|
947 |
+
self.temp_mask = None
|
948 |
+
|
949 |
+
if self.transformer is not None:
|
950 |
+
out1 = self.transformer(z1)
|
951 |
+
else:
|
952 |
+
out1 = z1
|
953 |
+
if self.batch_norm:
|
954 |
+
out1 = self.batch_norm(out1.transpose(1, 2)).transpose(1, 2)
|
955 |
+
# Apply quantization
|
956 |
+
if self.rvq:
|
957 |
+
if self.training:
|
958 |
+
self.rvq.set_num_codebooks(self.n_q_out)
|
959 |
+
else:
|
960 |
+
self.rvq.set_num_codebooks(self.eval_q)
|
961 |
+
out1 = self.rvq(out1.transpose(1, 2), frame_rate=1.)
|
962 |
+
if self.training:
|
963 |
+
flashy.distrib.average_tensors(self.rvq.buffers())
|
964 |
+
out1 = out1.x.transpose(1, 2)
|
965 |
+
# Apply fix downsample
|
966 |
+
out1 = out1[:, ::self.ds_factor]
|
967 |
+
|
968 |
+
return out1
|
969 |
+
|
970 |
+
def set_params(self, eval_q: int = 3,
|
971 |
+
excerpt_length: float = 3.0,
|
972 |
+
ds_factor: tp.Optional[int] = None, encodec_n_q: tp.Optional[int] = None):
|
973 |
+
"""Modify the parameters of the SSL or introduce new parameters to add noise to
|
974 |
+
the conditioning or to downsample it
|
975 |
+
|
976 |
+
Args:
|
977 |
+
eval_q (int): number of codebooks used when evaluating the model
|
978 |
+
excerpt_length (float): the length of the excerpts used to condition the model
|
979 |
+
"""
|
980 |
+
self.eval_q = eval_q
|
981 |
+
self.length_subwav = int(excerpt_length * self.sample_rate)
|
982 |
+
if ds_factor is not None:
|
983 |
+
self.ds_factor = ds_factor
|
984 |
+
if encodec_n_q is not None:
|
985 |
+
self.encodec_n_q = encodec_n_q
|
986 |
+
|
987 |
+
def _downsampling_factor(self):
|
988 |
+
df = super()._downsampling_factor()
|
989 |
+
return df * self.ds_factor
|
990 |
+
|
991 |
+
def forward(self, x: WavCondition) -> ConditionType:
|
992 |
+
wav, lengths, *_ = x
|
993 |
+
|
994 |
+
embeds = self._get_wav_embedding(x)
|
995 |
+
embeds = embeds.to(self.output_proj.weight)
|
996 |
+
embeds = self.output_proj(embeds)
|
997 |
+
|
998 |
+
lengths = lengths / self._downsampling_factor()
|
999 |
+
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
1000 |
+
|
1001 |
+
embeds = (embeds * mask.unsqueeze(2).to(self.device))
|
1002 |
+
|
1003 |
+
return embeds, mask
|
1004 |
+
|
1005 |
+
|
1006 |
class JointEmbeddingConditioner(BaseConditioner):
|
1007 |
"""Joint embedding conditioning supporting both audio or text conditioning.
|
1008 |
|
|
|
1301 |
return embed, empty_idx
|
1302 |
|
1303 |
|
1304 |
+
def dropout_symbolic_conditions(sample: ConditioningAttributes,
|
1305 |
+
condition: str, null_chord_idx: int = 194) -> ConditioningAttributes:
|
1306 |
+
"""
|
1307 |
+
Applies dropout to symbolic conditions within the sample based on the specified condition by setting the condition
|
1308 |
+
value to a null index.
|
1309 |
+
Args:
|
1310 |
+
sample (ConditioningAttributes): The sample containing symbolic attributes to potentially dropout.
|
1311 |
+
condition (str): The specific condition within the symbolic attributes to apply dropout.
|
1312 |
+
null_chord_idx (int, optional): The index used to represent a null chord. Defaults to 194.
|
1313 |
+
Returns:
|
1314 |
+
ConditioningAttributes: The modified sample with dropout applied to the specified condition.
|
1315 |
+
Raises:
|
1316 |
+
ValueError: If the specified condition is not present in the sample's symbolic attributes.
|
1317 |
+
"""
|
1318 |
+
if sample.symbolic == {} or sample.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1: # type: ignore
|
1319 |
+
# nothing to drop
|
1320 |
+
return sample
|
1321 |
+
|
1322 |
+
if condition not in getattr(sample, 'symbolic'):
|
1323 |
+
raise ValueError(
|
1324 |
+
"dropout_symbolic_condition received an unexpected condition!"
|
1325 |
+
f" expected {sample.symbolic.keys()}"
|
1326 |
+
f" but got '{condition}'!"
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
if condition == JascoCondConst.CRD.value:
|
1330 |
+
sample.symbolic[condition] = nullify_chords(sample.symbolic[condition], null_chord_idx=null_chord_idx)
|
1331 |
+
elif condition == JascoCondConst.MLD.value:
|
1332 |
+
sample.symbolic[condition] = nullify_melody(sample.symbolic[condition])
|
1333 |
+
|
1334 |
+
return sample
|
1335 |
+
|
1336 |
+
|
1337 |
+
def dropout_condition(sample: ConditioningAttributes,
|
1338 |
+
condition_type: str, condition: str,
|
1339 |
+
**kwargs) -> ConditioningAttributes:
|
1340 |
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
1341 |
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
1342 |
If the condition is of any other type, set its value to None.
|
1343 |
Works in-place.
|
1344 |
"""
|
1345 |
+
if condition_type not in ['text', 'wav', 'joint_embed', 'symbolic']:
|
1346 |
raise ValueError(
|
1347 |
"dropout_condition got an unexpected condition type!"
|
1348 |
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
|
|
1361 |
elif condition_type == 'joint_embed':
|
1362 |
embed = sample.joint_embed[condition]
|
1363 |
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
1364 |
+
elif condition_type == 'symbolic':
|
1365 |
+
sample = dropout_symbolic_conditions(sample=sample, condition=condition, **kwargs)
|
1366 |
else:
|
1367 |
sample.text[condition] = None
|
1368 |
|
|
|
1413 |
return samples
|
1414 |
|
1415 |
samples = deepcopy(samples)
|
1416 |
+
for condition_type, ps in self.p.items(): # for condition types [text, wav, symbolic]
|
1417 |
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
1418 |
if torch.rand(1, generator=self.rng).item() < p:
|
1419 |
for sample in samples:
|
|
|
1436 |
super().__init__(seed=seed)
|
1437 |
self.p = p
|
1438 |
|
1439 |
+
def forward(self, samples: tp.List[ConditioningAttributes],
|
1440 |
+
cond_types: tp.List[str] = ["wav", "text"],
|
1441 |
+
**kwargs) -> tp.List[ConditioningAttributes]:
|
1442 |
"""
|
1443 |
Args:
|
1444 |
samples (list[ConditioningAttributes]): List of conditions.
|
|
|
1455 |
|
1456 |
# nullify conditions of all attributes
|
1457 |
samples = deepcopy(samples)
|
1458 |
+
for condition_type in cond_types:
|
1459 |
for sample in samples:
|
1460 |
for condition in sample.attributes[condition_type]:
|
1461 |
+
dropout_condition(sample, condition_type, condition,
|
1462 |
+
**kwargs)
|
1463 |
return samples
|
1464 |
|
1465 |
def __repr__(self):
|
|
|
1684 |
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
1685 |
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
1686 |
"""
|
1687 |
+
FUSING_METHODS = ["sum", "prepend", "cross", "ignore", "input_interpolate"]
|
1688 |
|
1689 |
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
1690 |
cross_attention_pos_emb_scale: float = 1.0):
|
|
|
1744 |
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
1745 |
else:
|
1746 |
cross_attention_output = cond
|
1747 |
+
elif op == 'ignore':
|
1748 |
+
continue
|
1749 |
else:
|
1750 |
raise ValueError(f"unknown op ({op})")
|
1751 |
|
audiocraft/modules/jasco_conditioners.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import typing as tp
|
3 |
+
from itertools import chain
|
4 |
+
from pathlib import Path
|
5 |
+
from torch import nn
|
6 |
+
from .conditioners import (ConditioningAttributes, BaseConditioner, ConditionType,
|
7 |
+
ConditioningProvider, JascoCondConst,
|
8 |
+
WaveformConditioner, WavCondition, SymbolicCondition)
|
9 |
+
from ..data.audio import audio_read
|
10 |
+
from ..data.audio_utils import convert_audio
|
11 |
+
from ..utils.autocast import TorchAutocast
|
12 |
+
from ..utils.cache import EmbeddingCache
|
13 |
+
|
14 |
+
|
15 |
+
class MelodyConditioner(BaseConditioner):
|
16 |
+
"""
|
17 |
+
A conditioner that handles melody conditioning from pre-computed salience matrix.
|
18 |
+
Attributes:
|
19 |
+
card (int): The cardinality of the melody matrix.
|
20 |
+
out_dim (int): The dimensionality of the output projection.
|
21 |
+
device (Union[torch.device, str]): The device on which the embeddings are stored.
|
22 |
+
"""
|
23 |
+
def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
24 |
+
super().__init__(dim=card, output_dim=out_dim)
|
25 |
+
self.device = device
|
26 |
+
|
27 |
+
def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
|
28 |
+
return SymbolicCondition(melody=x.melody.to(self.device)) # type: ignore
|
29 |
+
|
30 |
+
def forward(self, x: SymbolicCondition) -> ConditionType:
|
31 |
+
embeds = self.output_proj(x.melody.permute(0, 2, 1)) # type: ignore
|
32 |
+
mask = torch.ones_like(embeds[..., 0])
|
33 |
+
return embeds, mask
|
34 |
+
|
35 |
+
|
36 |
+
class ChordsEmbConditioner(BaseConditioner):
|
37 |
+
"""
|
38 |
+
A conditioner that embeds chord symbols into a continuous vector space.
|
39 |
+
Attributes:
|
40 |
+
card (int): The cardinality of the chord vocabulary.
|
41 |
+
out_dim (int): The dimensionality of the output embeddings.
|
42 |
+
device (Union[torch.device, str]): The device on which the embeddings are stored.
|
43 |
+
"""
|
44 |
+
def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
45 |
+
vocab_size = card + 1 # card + 1 - for null chord used during dropout
|
46 |
+
super().__init__(dim=vocab_size, output_dim=-1) # out_dim=-1 to avoid another projection
|
47 |
+
self.emb = nn.Embedding(vocab_size, out_dim, device=device)
|
48 |
+
self.device = device
|
49 |
+
|
50 |
+
def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
|
51 |
+
return SymbolicCondition(frame_chords=x.frame_chords.to(self.device)) # type: ignore
|
52 |
+
|
53 |
+
def forward(self, x: SymbolicCondition) -> ConditionType:
|
54 |
+
embeds = self.emb(x.frame_chords)
|
55 |
+
mask = torch.ones_like(embeds[..., 0])
|
56 |
+
return embeds, mask
|
57 |
+
|
58 |
+
|
59 |
+
class DrumsConditioner(WaveformConditioner):
|
60 |
+
def __init__(self, out_dim: int, sample_rate: int, blurring_factor: int = 3,
|
61 |
+
cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
62 |
+
compression_model_latent_dim: int = 128,
|
63 |
+
compression_model_framerate: float = 50,
|
64 |
+
segment_duration: float = 10.0,
|
65 |
+
device: tp.Union[torch.device, str] = 'cpu',
|
66 |
+
**kwargs):
|
67 |
+
"""Drum condition conditioner
|
68 |
+
|
69 |
+
Args:
|
70 |
+
out_dim (int): _description_
|
71 |
+
sample_rate (int): _description_
|
72 |
+
blurring_factor (int, optional): _description_. Defaults to 3.
|
73 |
+
cache_path (tp.Optional[tp.Union[str, Path]], optional): path to precomputed cache. Defaults to None.
|
74 |
+
compression_model_latent_dim (int, optional): latent dimensino. Defaults to 128.
|
75 |
+
compression_model_framerate (float, optional): frame rate of the representation model. Defaults to 50.
|
76 |
+
segment_duration (float, optional): duration in sec for each audio segment. Defaults to 10.0.
|
77 |
+
device (tp.Union[torch.device, str], optional): device. Defaults to 'cpu'.
|
78 |
+
"""
|
79 |
+
from demucs import pretrained
|
80 |
+
self.sample_rate = sample_rate
|
81 |
+
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
82 |
+
stem_sources: list = self.demucs.sources # type: ignore
|
83 |
+
self.stem_idx = stem_sources.index('drums')
|
84 |
+
self.compression_model = None
|
85 |
+
self.latent_dim = compression_model_latent_dim
|
86 |
+
super().__init__(dim=self.latent_dim, output_dim=out_dim, device=device)
|
87 |
+
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
88 |
+
self._use_masking = False
|
89 |
+
self.blurring_factor = blurring_factor
|
90 |
+
self.seq_len = int(segment_duration * compression_model_framerate)
|
91 |
+
self.cache = None # If you wish to train with EmbeddingCache, call self.create_embedding_cache(cache_path)
|
92 |
+
|
93 |
+
def create_embedding_cache(self, cache_path):
|
94 |
+
if cache_path is not None:
|
95 |
+
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
96 |
+
compute_embed_fn=self._calc_coarse_drum_codes_for_cache,
|
97 |
+
extract_embed_fn=self._load_drum_codes_chunk)
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
101 |
+
"""Get parts of the wav that holds the drums, extracting the main stems from the wav."""
|
102 |
+
from demucs.apply import apply_model
|
103 |
+
from demucs.audio import convert_audio
|
104 |
+
with self.autocast:
|
105 |
+
wav = convert_audio(
|
106 |
+
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
107 |
+
stems = apply_model(self.demucs, wav, device=self.device)
|
108 |
+
drum_stem = stems[:, self.stem_idx] # extract relevant stems for drums conditioning
|
109 |
+
return convert_audio(drum_stem, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
110 |
+
|
111 |
+
def _temporal_blur(self, z: torch.Tensor):
|
112 |
+
# z: (B, T, C)
|
113 |
+
B, T, C = z.shape
|
114 |
+
if T % self.blurring_factor != 0:
|
115 |
+
# pad with reflect for T % self.temporal_blurring on the right in dim=1
|
116 |
+
pad_val = self.blurring_factor - T % self.blurring_factor
|
117 |
+
z = torch.nn.functional.pad(z, (0, 0, 0, pad_val), mode='reflect')
|
118 |
+
z = z.reshape(B, -1, self.blurring_factor, C).sum(dim=2) / self.blurring_factor
|
119 |
+
z = z.unsqueeze(2).repeat(1, 1, self.blurring_factor, 1).reshape(B, -1, C)
|
120 |
+
z = z[:, :T]
|
121 |
+
assert z.shape == (B, T, C)
|
122 |
+
return z
|
123 |
+
|
124 |
+
@torch.no_grad()
|
125 |
+
def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
126 |
+
assert self.compression_model is not None
|
127 |
+
|
128 |
+
# stem separation of drums
|
129 |
+
drums = self._get_drums_stem(wav, sample_rate)
|
130 |
+
|
131 |
+
# continuous encoding with compression model
|
132 |
+
latents = self.compression_model.model.encoder(drums)
|
133 |
+
|
134 |
+
# quantization to coarsest codebook
|
135 |
+
coarsest_quantizer = self.compression_model.model.quantizer.layers[0]
|
136 |
+
drums = coarsest_quantizer.encode(latents).to(torch.int16)
|
137 |
+
return drums
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path],
|
141 |
+
x: WavCondition, idx: int,
|
142 |
+
max_duration_to_process: float = 600) -> torch.Tensor:
|
143 |
+
"""Extract blurred drum latents from the whole audio waveform at the given path."""
|
144 |
+
wav, sr = audio_read(path)
|
145 |
+
wav = wav[None].to(self.device)
|
146 |
+
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
147 |
+
|
148 |
+
max_frames_to_process = int(max_duration_to_process * self.sample_rate)
|
149 |
+
if wav.shape[-1] > max_frames_to_process:
|
150 |
+
# process very long tracks in chunks
|
151 |
+
start = 0
|
152 |
+
codes = []
|
153 |
+
while start < wav.shape[-1] - 1:
|
154 |
+
wav_chunk = wav[..., start: start + max_frames_to_process]
|
155 |
+
codes.append(self._extract_coarse_drum_codes(wav_chunk, self.sample_rate)[0])
|
156 |
+
start += max_frames_to_process
|
157 |
+
return torch.cat(codes)
|
158 |
+
|
159 |
+
return self._extract_coarse_drum_codes(wav, self.sample_rate)[0]
|
160 |
+
|
161 |
+
def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
162 |
+
"""Extract a chunk of coarse drum codes from the full coarse drum codes derived from the full waveform."""
|
163 |
+
wav_length = x.wav.shape[-1]
|
164 |
+
seek_time = x.seek_time[idx]
|
165 |
+
assert seek_time is not None, (
|
166 |
+
"WavCondition seek_time is required "
|
167 |
+
"when extracting chunks from pre-computed drum codes.")
|
168 |
+
assert self.compression_model is not None
|
169 |
+
frame_rate = self.compression_model.frame_rate
|
170 |
+
target_length = int(frame_rate * wav_length / self.sample_rate)
|
171 |
+
target_length = max(target_length, self.seq_len)
|
172 |
+
index = int(frame_rate * seek_time)
|
173 |
+
out = full_coarse_drum_codes[index: index + target_length]
|
174 |
+
# pad
|
175 |
+
out = torch.cat((out, torch.zeros(target_length - out.shape[0], dtype=out.dtype, device=out.device)))
|
176 |
+
return out.to(self.device)
|
177 |
+
|
178 |
+
@torch.no_grad()
|
179 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
180 |
+
bs = x.wav.shape[0]
|
181 |
+
if x.wav.shape[-1] <= 1:
|
182 |
+
# null condition
|
183 |
+
return torch.zeros((bs, self.seq_len, self.latent_dim), device=x.wav.device, dtype=x.wav.dtype)
|
184 |
+
|
185 |
+
# extract coarse drum codes
|
186 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
187 |
+
no_nullified_cond = x.wav.shape[-1] > 1
|
188 |
+
if self.cache is not None and no_undefined_paths and no_nullified_cond:
|
189 |
+
paths = [Path(p) for p in x.path if p is not None]
|
190 |
+
codes = self.cache.get_embed_from_cache(paths, x)
|
191 |
+
else:
|
192 |
+
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
193 |
+
codes = self._extract_coarse_drum_codes(x.wav, x.sample_rate[0])
|
194 |
+
|
195 |
+
assert self.compression_model is not None
|
196 |
+
# decode back to the continuous representation of compression model
|
197 |
+
codes = codes.unsqueeze(1).permute(1, 0, 2) # (B, T) -> (1, B, T)
|
198 |
+
codes = codes.to(torch.int64)
|
199 |
+
latents = self.compression_model.model.quantizer.decode(codes)
|
200 |
+
|
201 |
+
latents = latents.permute(0, 2, 1) # [B, C, T] -> [B, T, C]
|
202 |
+
|
203 |
+
# temporal blurring
|
204 |
+
return self._temporal_blur(latents)
|
205 |
+
|
206 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
207 |
+
"""Apply WavConditioner tokenization and populate cache if needed."""
|
208 |
+
x = super().tokenize(x)
|
209 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
210 |
+
if self.cache is not None and no_undefined_paths:
|
211 |
+
paths = [Path(p) for p in x.path if p is not None]
|
212 |
+
self.cache.populate_embed_cache(paths, x)
|
213 |
+
return x
|
214 |
+
|
215 |
+
|
216 |
+
class JascoConditioningProvider(ConditioningProvider):
|
217 |
+
"""
|
218 |
+
A cond-provider that manages and tokenizes various types of conditioning attributes for Jasco models.
|
219 |
+
Attributes:
|
220 |
+
chords_card (int): The cardinality of the chord vocabulary.
|
221 |
+
sequence_length (int): The length of the sequence for padding purposes.
|
222 |
+
melody_dim (int): The dimensionality of the melody matrix.
|
223 |
+
"""
|
224 |
+
def __init__(self, *args,
|
225 |
+
chords_card: int = 194,
|
226 |
+
sequence_length: int = 500,
|
227 |
+
melody_dim: int = 53, **kwargs):
|
228 |
+
self.null_chord = chords_card
|
229 |
+
self.sequence_len = sequence_length
|
230 |
+
self.melody_dim = melody_dim
|
231 |
+
super().__init__(*args, **kwargs)
|
232 |
+
|
233 |
+
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
234 |
+
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
|
235 |
+
This should be called before starting any real GPU work to avoid synchronization points.
|
236 |
+
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
240 |
+
text and wav conditions.
|
241 |
+
"""
|
242 |
+
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
243 |
+
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
244 |
+
f" but types were {set([type(x) for x in inputs])}"
|
245 |
+
)
|
246 |
+
|
247 |
+
output = {}
|
248 |
+
text = self._collate_text(inputs)
|
249 |
+
wavs = self._collate_wavs(inputs)
|
250 |
+
|
251 |
+
symbolic = self._collate_symbolic(inputs, self.conditioners.keys())
|
252 |
+
|
253 |
+
assert set(text.keys() | wavs.keys() | symbolic.keys()).issubset(set(self.conditioners.keys())), (
|
254 |
+
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
255 |
+
f"got {text.keys(), wavs.keys(), symbolic.keys()}"
|
256 |
+
)
|
257 |
+
|
258 |
+
for attribute, batch in chain(text.items(), wavs.items(), symbolic.items()):
|
259 |
+
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
260 |
+
return output
|
261 |
+
|
262 |
+
def _collate_symbolic(self, samples: tp.List[ConditioningAttributes],
|
263 |
+
conditioner_keys: tp.Set) -> tp.Dict[str, SymbolicCondition]:
|
264 |
+
output = {}
|
265 |
+
|
266 |
+
# collate if symbolic cond exists
|
267 |
+
if any(x in conditioner_keys for x in JascoCondConst.SYM.value):
|
268 |
+
|
269 |
+
for s in samples:
|
270 |
+
# hydrate with null chord if chords not exist - for inference support
|
271 |
+
if (s.symbolic == {} or
|
272 |
+
s.symbolic[JascoCondConst.CRD.value].frame_chords is None or
|
273 |
+
s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1): # type: ignore
|
274 |
+
# no chords conditioning - fill with null chord token
|
275 |
+
s.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
|
276 |
+
frame_chords=torch.ones(self.sequence_len, dtype=torch.int32) * self.null_chord)
|
277 |
+
|
278 |
+
if (s.symbolic == {} or
|
279 |
+
s.symbolic[JascoCondConst.MLD.value].melody is None or
|
280 |
+
s.symbolic[JascoCondConst.MLD.value].melody.shape[-1] <= 1): # type: ignore
|
281 |
+
# no chords conditioning - fill with null chord token
|
282 |
+
s.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(
|
283 |
+
melody=torch.zeros((self.melody_dim, self.sequence_len)))
|
284 |
+
|
285 |
+
if JascoCondConst.CRD.value in conditioner_keys:
|
286 |
+
# pad to max
|
287 |
+
max_seq_len = max(
|
288 |
+
[s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] for s in samples]) # type: ignore
|
289 |
+
padded_chords = [
|
290 |
+
torch.cat((x.symbolic[JascoCondConst.CRD.value].frame_chords, # type: ignore
|
291 |
+
torch.ones(max_seq_len -
|
292 |
+
x.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1], # type: ignore
|
293 |
+
dtype=torch.int32) * self.null_chord))
|
294 |
+
for x in samples
|
295 |
+
]
|
296 |
+
output[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.stack(padded_chords))
|
297 |
+
if JascoCondConst.MLD.value in conditioner_keys:
|
298 |
+
melodies = torch.stack([x.symbolic[JascoCondConst.MLD.value].melody for x in samples]) # type: ignore
|
299 |
+
output[JascoCondConst.MLD.value] = SymbolicCondition(melody=melodies)
|
300 |
+
return output
|
audiocraft/modules/transformer.py
CHANGED
@@ -315,7 +315,6 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
315 |
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
316 |
key_padding_mask=None, need_weights=False, attn_mask=None,
|
317 |
average_attn_weights=True, is_causal=False):
|
318 |
-
assert attn_mask is None
|
319 |
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
|
320 |
"use the causal args in the constructor.")
|
321 |
|
@@ -329,7 +328,10 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
329 |
assert self.causal or self.cross_attention, \
|
330 |
"Streaming only available for causal or cross attention"
|
331 |
|
|
|
|
|
332 |
if self.causal:
|
|
|
333 |
# At the moment we specialize only for the self-attention case.
|
334 |
assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
335 |
assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
@@ -398,6 +400,14 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
398 |
if self.attention_as_float32:
|
399 |
q, k, v = [x.float() for x in [q, k, v]]
|
400 |
if self.memory_efficient:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
p = self.dropout if self.training else 0
|
402 |
if _efficient_attention_backend == 'torch':
|
403 |
x = torch.nn.functional.scaled_dot_product_attention(
|
|
|
315 |
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
316 |
key_padding_mask=None, need_weights=False, attn_mask=None,
|
317 |
average_attn_weights=True, is_causal=False):
|
|
|
318 |
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
|
319 |
"use the causal args in the constructor.")
|
320 |
|
|
|
328 |
assert self.causal or self.cross_attention, \
|
329 |
"Streaming only available for causal or cross attention"
|
330 |
|
331 |
+
custom_attn_mask = attn_mask is not None
|
332 |
+
|
333 |
if self.causal:
|
334 |
+
assert attn_mask is None
|
335 |
# At the moment we specialize only for the self-attention case.
|
336 |
assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
337 |
assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
|
|
400 |
if self.attention_as_float32:
|
401 |
q, k, v = [x.float() for x in [q, k, v]]
|
402 |
if self.memory_efficient:
|
403 |
+
if custom_attn_mask:
|
404 |
+
# When using a custom attn mask:
|
405 |
+
# Move to query's device, repeat for each sample, remove align8 padding
|
406 |
+
seq_len = query.shape[1]
|
407 |
+
attn_mask = attn_mask.to(q.dtype)
|
408 |
+
attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
|
409 |
+
attn_mask = attn_mask[..., :seq_len, :seq_len]
|
410 |
+
|
411 |
p = self.dropout if self.training else 0
|
412 |
if _efficient_attention_backend == 'torch':
|
413 |
x = torch.nn.functional.scaled_dot_product_attention(
|
audiocraft/modules/unet_transformer.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import typing as tp
|
3 |
+
from .transformer import StreamingTransformer, create_sin_embedding
|
4 |
+
|
5 |
+
|
6 |
+
class UnetTransformer(StreamingTransformer):
|
7 |
+
"""U-net Transformer for processing sequences with optional skip connections.
|
8 |
+
This transformer architecture incorporates U-net style skip connections
|
9 |
+
between layers, which can be optionally enabled. It inherits from a
|
10 |
+
StreamingTransformer.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
d_model (int): Dimension of the model, typically the number of expected features in the input.
|
14 |
+
num_layers (int): Total number of layers in the transformer.
|
15 |
+
skip_connections (bool, optional): Flag to determine whether skip connections should be used.
|
16 |
+
Defaults to False.
|
17 |
+
layer_dropout_p (float, Optional): if given, defined bernoulli prob. to drop a skip connection (in training).
|
18 |
+
**kwargs: Additional keyword arguments inherited from `nn.StreamingTransformer`.
|
19 |
+
"""
|
20 |
+
def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
|
21 |
+
layer_dropout_p: tp.Optional[float] = None, **kwargs):
|
22 |
+
super().__init__(d_model=d_model,
|
23 |
+
num_layers=num_layers,
|
24 |
+
**kwargs)
|
25 |
+
self.skip_connect = skip_connections
|
26 |
+
if self.skip_connect:
|
27 |
+
self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
|
28 |
+
for _ in range(num_layers // 2)])
|
29 |
+
self.num_layers = num_layers
|
30 |
+
self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0
|
31 |
+
|
32 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
33 |
+
B, T, C = x.shape
|
34 |
+
|
35 |
+
if 'offsets' in self._streaming_state:
|
36 |
+
offsets = self._streaming_state['offsets']
|
37 |
+
else:
|
38 |
+
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
|
39 |
+
|
40 |
+
if self.positional_embedding in ['sin', 'sin_rope']:
|
41 |
+
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
42 |
+
positions = positions + offsets.view(-1, 1, 1)
|
43 |
+
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
44 |
+
x = x + self.positional_scale * pos_emb
|
45 |
+
|
46 |
+
skip_connections: tp.List[torch.Tensor] = []
|
47 |
+
|
48 |
+
for i, layer in enumerate(self.layers):
|
49 |
+
if self.skip_connect and i >= self.num_layers // 2:
|
50 |
+
|
51 |
+
# in the second half of the layers, add residual connection
|
52 |
+
# and linearly project the concatenated features back to d_model
|
53 |
+
x = torch.cat([x, skip_connections.pop()], dim=-1)
|
54 |
+
x = self.skip_projections[i % len(self.skip_projections)](x)
|
55 |
+
|
56 |
+
x = self._apply_layer(layer, x, *args, **kwargs)
|
57 |
+
|
58 |
+
if self.skip_connect and i < self.num_layers // 2:
|
59 |
+
if self.training and torch.rand(1,) < self.layer_drop_p: # drop skip
|
60 |
+
skip_connections.append(torch.zeros_like(x))
|
61 |
+
else:
|
62 |
+
skip_connections.append(x)
|
63 |
+
|
64 |
+
if self._is_streaming:
|
65 |
+
self._streaming_state['offsets'] = offsets + T
|
66 |
+
|
67 |
+
return x
|
audiocraft/utils/extend.py
CHANGED
@@ -51,7 +51,7 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
|
|
51 |
print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
|
52 |
return segments
|
53 |
|
54 |
-
def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, progress= gr.Progress(track_tqdm=True)):
|
55 |
# generate audio segments
|
56 |
melody_segments = separate_audio_segments(melody, segment_duration, 0)
|
57 |
|
@@ -96,7 +96,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
|
|
96 |
pbar.update(1)
|
97 |
print(f"melody_segments: {len(melody_segments)} fixed")
|
98 |
|
99 |
-
# Iterate over the segments to create list of
|
100 |
for segment_idx in range(total_segments):
|
101 |
if INTERRUPTING:
|
102 |
return [], duration
|
@@ -119,6 +119,10 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
|
|
119 |
verse = verse[None]
|
120 |
verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
|
121 |
|
|
|
|
|
|
|
|
|
122 |
# Append the segment to the melodys list
|
123 |
melodys.append(verse)
|
124 |
pbar.update(1)
|
@@ -139,10 +143,17 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
|
|
139 |
top_p=MODEL.generation_params["top_p"],
|
140 |
temperature=MODEL.generation_params["temp"],
|
141 |
cfg_coef=MODEL.generation_params["cfg_coef"],
|
|
|
142 |
duration=segment_duration,
|
143 |
two_step_cfg=False,
|
144 |
-
rep_penalty=0.5
|
145 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
# Generate a new prompt segment. This will be applied to all segments for consistency
|
147 |
print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
|
148 |
prompt_segment = MODEL.generate_with_all(
|
@@ -168,10 +179,18 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
|
|
168 |
top_p=MODEL.generation_params["top_p"],
|
169 |
temperature=MODEL.generation_params["temp"],
|
170 |
cfg_coef=MODEL.generation_params["cfg_coef"],
|
|
|
171 |
duration=mod_duration,
|
172 |
two_step_cfg=False,
|
173 |
-
rep_penalty=0.5
|
174 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
try:
|
176 |
# get last chunk
|
177 |
verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
|
|
|
51 |
print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
|
52 |
return segments
|
53 |
|
54 |
+
def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, excerpt_duration:float=3.5, progress= gr.Progress(track_tqdm=True)):
|
55 |
# generate audio segments
|
56 |
melody_segments = separate_audio_segments(melody, segment_duration, 0)
|
57 |
|
|
|
96 |
pbar.update(1)
|
97 |
print(f"melody_segments: {len(melody_segments)} fixed")
|
98 |
|
99 |
+
# Iterate over the segments to create list of Melody tensors
|
100 |
for segment_idx in range(total_segments):
|
101 |
if INTERRUPTING:
|
102 |
return [], duration
|
|
|
119 |
verse = verse[None]
|
120 |
verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
|
121 |
|
122 |
+
# Reduce the length of verse to sr * excerpt_duration
|
123 |
+
if ("style" in MODEL.name):
|
124 |
+
verse = verse[:, :, :int(sr * excerpt_duration)]
|
125 |
+
|
126 |
# Append the segment to the melodys list
|
127 |
melodys.append(verse)
|
128 |
pbar.update(1)
|
|
|
143 |
top_p=MODEL.generation_params["top_p"],
|
144 |
temperature=MODEL.generation_params["temp"],
|
145 |
cfg_coef=MODEL.generation_params["cfg_coef"],
|
146 |
+
cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
|
147 |
duration=segment_duration,
|
148 |
two_step_cfg=False,
|
149 |
+
rep_penalty=0.5,
|
150 |
)
|
151 |
+
if ("style" in MODEL.name):
|
152 |
+
MODEL.set_style_conditioner_params(
|
153 |
+
eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
|
154 |
+
excerpt_length=excerpt_duration, # the length in seconds that is taken by the model in the provided excerpt, can be between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
155 |
+
)
|
156 |
+
|
157 |
# Generate a new prompt segment. This will be applied to all segments for consistency
|
158 |
print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
|
159 |
prompt_segment = MODEL.generate_with_all(
|
|
|
179 |
top_p=MODEL.generation_params["top_p"],
|
180 |
temperature=MODEL.generation_params["temp"],
|
181 |
cfg_coef=MODEL.generation_params["cfg_coef"],
|
182 |
+
cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
|
183 |
duration=mod_duration,
|
184 |
two_step_cfg=False,
|
185 |
+
rep_penalty=0.5,
|
186 |
)
|
187 |
+
|
188 |
+
if ("style" in MODEL.name):
|
189 |
+
MODEL.set_style_conditioner_params(
|
190 |
+
eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
|
191 |
+
excerpt_length=min(excerpt_duration, mod_duration), # the length in seconds that is taken by the model in the provided excerpt, can be between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
192 |
+
)
|
193 |
+
|
194 |
try:
|
195 |
# get last chunk
|
196 |
verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
|
audiocraft/utils/utils.py
CHANGED
@@ -298,3 +298,31 @@ def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
|
|
298 |
pkg = load_state_dict(path)
|
299 |
pkg.pop('text_branch.embeddings.position_ids', None)
|
300 |
clap_model.model.load_state_dict(pkg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
pkg = load_state_dict(path)
|
299 |
pkg.pop('text_branch.embeddings.position_ids', None)
|
300 |
clap_model.model.load_state_dict(pkg)
|
301 |
+
|
302 |
+
def construct_frame_chords(
|
303 |
+
min_timestamp: int,
|
304 |
+
chord_changes: tp.List[tp.Tuple[float, str]],
|
305 |
+
mapping_dict: tp.Dict,
|
306 |
+
prev_chord: str,
|
307 |
+
frame_rate: float,
|
308 |
+
segment_duration: float,
|
309 |
+
) -> tp.List[str]:
|
310 |
+
""" Translate symbolic chords [(start_time, tuples),...] into a frame-level int sequence"""
|
311 |
+
|
312 |
+
frames = [
|
313 |
+
frame / frame_rate
|
314 |
+
for frame in range(
|
315 |
+
min_timestamp, int(min_timestamp + segment_duration * frame_rate)
|
316 |
+
)
|
317 |
+
]
|
318 |
+
|
319 |
+
frame_chords = []
|
320 |
+
current_chord = prev_chord
|
321 |
+
|
322 |
+
for frame in frames:
|
323 |
+
while chord_changes and frame >= chord_changes[0][0]:
|
324 |
+
current_chord = chord_changes.pop(0)[1]
|
325 |
+
current_chord = 'N' if current_chord in {None, ''} else current_chord
|
326 |
+
frame_chords.append(mapping_dict[current_chord])
|
327 |
+
|
328 |
+
return frame_chords
|
requirements.txt
CHANGED
@@ -10,14 +10,16 @@ soundfile
|
|
10 |
huggingface_hub
|
11 |
hf_xet
|
12 |
tqdm
|
13 |
-
transformers
|
14 |
xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
|
15 |
demucs
|
16 |
librosa==0.11.0
|
17 |
soundfile
|
18 |
gradio[oauth]
|
19 |
pillow
|
|
|
20 |
torchmetrics
|
|
|
21 |
encodec
|
22 |
protobuf>=3.20.1
|
23 |
filetype
|
|
|
10 |
huggingface_hub
|
11 |
hf_xet
|
12 |
tqdm
|
13 |
+
transformers==4.43.4 # need Encodec there.
|
14 |
xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
|
15 |
demucs
|
16 |
librosa==0.11.0
|
17 |
soundfile
|
18 |
gradio[oauth]
|
19 |
pillow
|
20 |
+
torchdiffeq
|
21 |
torchmetrics
|
22 |
+
nnAudio
|
23 |
encodec
|
24 |
protobuf>=3.20.1
|
25 |
filetype
|