import spaces
from pathlib import Path
import yaml
import time
import uuid
import numpy as np
import audiotools as at
import argbind
import shutil
import torch
from datetime import datetime
from pyharp import load_audio, save_audio, OutputLabel, LabelList, build_endpoint, ModelCard
import gradio as gr
from vampnet.interface import Interface, signal_concat
from vampnet import mask as pmask
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"using device {device}\n" * 10)

interface = Interface.default()

init_model_choice = open("DEFAULT_MODEL").read().strip()

# load the init model
interface.load_finetuned(init_model_choice)
interface.to(device)
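
# several functions below write debug artifacts (mask plots, mask text dumps) into scratch/;
# create it up front so those saves don't fail. assumes the working directory is writable.
Path("scratch").mkdir(exist_ok=True)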

def to_output(sig):
    return sig.sample_rate, sig.cpu().detach().numpy()[0][0]


MAX_DURATION_S = 10

def load_audio(file):
    # note: this shadows pyharp.load_audio; the Gradio widgets below use this local version.
    print(file)
    if isinstance(file, str):
        filepath = file
    elif isinstance(file, tuple):
        # already a (sample_rate, samples) tuple, not a file
        sr, samples = file
        samples = samples / np.iinfo(samples.dtype).max
        return sr, samples
    else:
        filepath = file.name

    # trim the upload to a salient excerpt of at most MAX_DURATION_S seconds
    sig = at.AudioSignal.salient_excerpt(
        filepath, duration=MAX_DURATION_S
    )
    return to_output(sig)

def load_example_audio():
    return load_audio("./assets/example.wav")


from torch_pitch_shift import pitch_shift, get_fast_shifts


def shift_pitch(signal, interval: int):
    signal.samples = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate
    )
    return signal

def onsets(sig: at.AudioSignal, hop_length: int):
    assert sig.batch_size == 1, "batch size must be 1"
    assert sig.num_channels == 1, "mono signals only"
    import librosa
    onset_frame_idxs = librosa.onset.onset_detect(
        y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
        hop_length=hop_length,
        backtrack=True,
    )
    return onset_frame_idxs
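
# NOTE: new_vampnet_mask is an alternative onset-aware mask builder; its only call site
# (inside _vamp_internal below) is currently commented out, so it is kept for reference.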
def new_vampnet_mask(self,
                     codes,
                     onset_idxs,
                     width: int = 5,
                     periodic_prompt=2,
                     upper_codebook_mask=1,
                     drop_amt: float = 0.1
                     ):
    from vampnet.newmask import mask_and, mask_or, onset_mask, periodic_mask, drop_ones, codebook_mask
    mask = mask_and(
        periodic_mask(codes, periodic_prompt, 1, random_roll=False),
        mask_or(  # this re-masks the onsets, according to a periodic schedule
            onset_mask(onset_idxs, codes, width=width),
            periodic_mask(codes, periodic_prompt, 1, random_roll=False),
        )
    ).int()

    # make sure the onset idxs themselves are unmasked
    # mask = 1 - mask
    mask[:, :, onset_idxs] = 0

    mask = mask.cpu()  # debug
    mask = 1 - drop_ones(1 - mask, drop_amt)
    mask = codebook_mask(mask, upper_codebook_mask)

    # save mask as txt (ints)
    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')

    mask = mask.to(self.device)
    return mask[:, :, :]

def mask_preview(periodic_p, n_mask_codebooks, onset_mask_width, dropout):
    # make a mask preview
    codes = torch.zeros((1, 14, 80)).to(device)
    mask = interface.build_mask(
        codes,
        periodic_prompt=periodic_p,
        # onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=n_mask_codebooks,
    )
    # mask = mask.cpu().numpy()
    import matplotlib.pyplot as plt
    plt.clf()
    interface.visualize_codes(mask)
    plt.title("mask preview")
    plt.savefig("scratch/mask-prev.png")
    return "scratch/mask-prev.png"
def _vamp_internal(
    seed, input_audio, model_choice,
    pitch_shift_amt, periodic_p,
    n_mask_codebooks, onset_mask_width,
    dropout, sampletemp, typical_filtering,
    typical_mass, typical_min_tokens, top_p,
    sample_cutoff, stretch_factor, sampling_steps, beat_mask_ms, num_feedback_steps, api=False, harp=False
):
    print("args!")
    print(f"seed: {seed}")
    print(f"input_audio: {input_audio}")
    print(f"model_choice: {model_choice}")
    print(f"pitch_shift_amt: {pitch_shift_amt}")
    print(f"periodic_p: {periodic_p}")
    print(f"n_mask_codebooks: {n_mask_codebooks}")
    print(f"onset_mask_width: {onset_mask_width}")
    print(f"dropout: {dropout}")
    print(f"sampletemp: {sampletemp}")
    print(f"typical_filtering: {typical_filtering}")
    print(f"typical_mass: {typical_mass}")
    print(f"typical_min_tokens: {typical_min_tokens}")
    print(f"top_p: {top_p}")
    print(f"sample_cutoff: {sample_cutoff}")
    print(f"stretch_factor: {stretch_factor}")
    print(f"sampling_steps: {sampling_steps}")
    print(f"api: {api}")
    print(f"beat_mask_ms: {beat_mask_ms}")
    print(f"using device {interface.device}")
    print(f"num feedback steps: {num_feedback_steps}")
    t0 = time.time()

    interface.to(device)
    print(f"using device {interface.device}")

    _seed = seed if seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    if input_audio is None:
        raise gr.Error("no input audio received!")

    sr, input_audio = input_audio
    input_audio = input_audio / np.iinfo(input_audio.dtype).max

    sig = at.AudioSignal(input_audio, sr).to_mono()
    loudness = sig.loudness()
    sig = interface._preprocess(sig)

    # reload the model if necessary
    interface.load_finetuned(model_choice)

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    codes = interface.encode(sig)
    # mask = new_vampnet_mask(
    #     interface,
    #     codes,
    #     onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
    #     width=onset_mask_width,
    #     periodic_prompt=periodic_p,
    #     upper_codebook_mask=n_mask_codebooks,
    #     drop_amt=dropout
    # ).long()
    mask = interface.build_mask(
        codes,
        sig=sig,
        periodic_prompt=periodic_p,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=n_mask_codebooks,
    )

    if beat_mask_ms > 0:
        # bm = pmask.mask_or(
        #     pmask.periodic_mask(
        #         codes, periodic_p, random_roll=False
        #     ),
        # )
        mask = pmask.mask_and(
            mask, interface.make_beat_mask(
                sig, after_beat_s=beat_mask_ms / 1000.,
            )
        )
        mask = pmask.codebook_mask(mask, n_mask_codebooks)

    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')

    interface.set_chunk_size(10.0)
    # treat a non-positive top_p as "disabled"
    if top_p is not None and top_p <= 0:
        top_p = None
    codes, mask_z = interface.vamp(
        codes, mask,
        batch_size=2,
        feedback_steps=num_feedback_steps,
        _sampling_steps=sampling_steps,
        time_stretch_factor=stretch_factor,
        return_mask=True,
        temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p,
        seed=_seed,
        sample_cutoff=sample_cutoff,
    )
    print(f"vamp took {time.time() - t0} seconds")

    sig = interface.decode(codes)
    sig = sig.normalize(loudness)

    import matplotlib.pyplot as plt
    plt.clf()
    # plt.imshow(mask_z[0].cpu().numpy(), aspect='auto')
    interface.visualize_codes(mask)
    plt.title("actual mask")
    plt.savefig("scratch/mask.png")
    plt.clf()

    if harp:
        return sig
    if not api:
        return to_output(sig[0]), to_output(sig[1]), "scratch/mask.png"
    else:
        return to_output(sig[0]), to_output(sig[1])
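
# Gradio-facing wrappers: vamp returns two audio outputs plus the mask image, while
# api_vamp (api=True) returns only the audio pair.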
def vamp(input_audio,
         sampletemp,
         top_p,
         periodic_p,
         dropout,
         stretch_factor,
         onset_mask_width,
         typical_filtering,
         typical_mass,
         typical_min_tokens,
         seed,
         model_choice,
         n_mask_codebooks,
         pitch_shift_amt,
         sample_cutoff,
         sampling_steps,
         beat_mask_ms,
         num_feedback_steps):
    return _vamp_internal(
        seed=seed,
        input_audio=input_audio,
        model_choice=model_choice,
        pitch_shift_amt=pitch_shift_amt,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        onset_mask_width=onset_mask_width,
        dropout=dropout,
        sampletemp=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p,
        sample_cutoff=sample_cutoff,
        stretch_factor=stretch_factor,
        sampling_steps=sampling_steps,
        beat_mask_ms=beat_mask_ms,
        num_feedback_steps=num_feedback_steps,
        api=False,
    )

def api_vamp(input_audio,
             sampletemp, top_p,
             periodic_p,
             dropout,
             stretch_factor,
             onset_mask_width,
             typical_filtering,
             typical_mass,
             typical_min_tokens,
             seed,
             model_choice,
             n_mask_codebooks,
             pitch_shift_amt,
             sample_cutoff,
             sampling_steps,
             beat_mask_ms, num_feedback_steps):
    return _vamp_internal(
        seed=seed,
        input_audio=input_audio,
        model_choice=model_choice,
        pitch_shift_amt=pitch_shift_amt,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        onset_mask_width=onset_mask_width,
        dropout=dropout,
        sampletemp=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p,
        sample_cutoff=sample_cutoff,
        stretch_factor=stretch_factor,
        sampling_steps=sampling_steps,
        beat_mask_ms=beat_mask_ms,
        num_feedback_steps=num_feedback_steps,
        api=True,
    )
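
# harp_vamp is the pyHARP entry point: it exposes a reduced set of controls and runs
# _vamp_internal with fixed defaults for everything else (no pitch shift, 36 sampling steps, etc.).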
def harp_vamp(input_audio, sampletemp, periodic_p, dropout, n_mask_codebooks, model_choice, beat_mask_ms):
    sig = at.AudioSignal(input_audio).to_mono()

    # repack the signal as the (sample_rate, int16 samples) tuple _vamp_internal expects
    input_audio = sig.cpu().detach().numpy()[0][0]
    input_audio = input_audio * np.iinfo(np.int16).max
    input_audio = input_audio.astype(np.int16)
    input_audio = input_audio.reshape(1, -1)
    input_audio = (sig.sample_rate, input_audio)

    sig = _vamp_internal(
        seed=0,
        input_audio=input_audio,
        model_choice=model_choice,
        pitch_shift_amt=0,
        periodic_p=int(periodic_p),
        n_mask_codebooks=int(n_mask_codebooks),
        onset_mask_width=0,
        dropout=dropout,
        sampletemp=sampletemp,
        typical_filtering=False,
        typical_mass=0.15,
        typical_min_tokens=1,
        top_p=None,
        sample_cutoff=1.0,
        stretch_factor=1.0,
        sampling_steps=36,
        beat_mask_ms=int(beat_mask_ms),
        num_feedback_steps=1,
        api=False,
        harp=True,
    )

    ll = LabelList()
    ll.append(OutputLabel(label='short label', t=0.0, description='longer description'))
    return save_audio(sig.detach().cpu()), ll
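
# ---- Gradio UI ----
# left column: audio input; middle column: mask and sampling controls; right column:
# model choice, the generate button, and the two output slots.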
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            manual_audio_upload = gr.File(
                label=f"upload some audio (will be trimmed to a salient excerpt of at most {MAX_DURATION_S}s)",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="numpy",
            )

            # audio_mask = gr.Audio(
            #     label="audio mask (listen to this to hear the mask hints)",
            #     interactive=False,
            #     type="numpy",
            # )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[input_audio]
            )
        # mask settings
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=13,
                    step=1,
                    value=7,
                )

                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10 ms; does not affect the mask preview)",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0, visible=True
                )
                beat_mask_ms = gr.Slider(
                    label="beat mask width (milliseconds; does not affect the mask preview)",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=0,
                    visible=True
                )
                n_mask_codebooks = gr.Slider(
                    label="compression prompt",
                    value=3,
                    minimum=1,
                    maximum=14,
                    step=1,
                )
                dropout = gr.Slider(
                    label="mask dropout",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0
                )

                num_feedback_steps = gr.Slider(
                    label="feedback steps (token telephone) -- turn it up for better timbre/rhythm transfer quality, but it's slower!",
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=1
                )

                preset_dropdown = gr.Dropdown(
                    label="preset",
                    choices=["timbre transfer", "small variation", "small variation (follow beat)", "medium variation", "medium variation (follow beat)", "large variation", "large variation (follow beat)", "unconditional"],
                    value="medium variation"
                )
                def change_preset(preset_dropdown):
                    # each preset sets (periodic prompt, codebooks masked, onset width, dropout, beat mask)
                    if preset_dropdown == "timbre transfer":
                        periodic_p = 2
                        n_mask_codebooks = 1
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 0
                    elif preset_dropdown == "small variation":
                        periodic_p = 5
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 0
                    elif preset_dropdown == "small variation (follow beat)":
                        periodic_p = 7
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 50
                    elif preset_dropdown == "medium variation":
                        periodic_p = 7
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 0
                    elif preset_dropdown == "medium variation (follow beat)":
                        periodic_p = 13
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 50
                    elif preset_dropdown == "large variation":
                        periodic_p = 13
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.2
                        beat_mask_ms = 0
                    elif preset_dropdown == "large variation (follow beat)":
                        periodic_p = 0
                        n_mask_codebooks = 4
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 80
                    elif preset_dropdown == "unconditional":
                        periodic_p = 0
                        n_mask_codebooks = 1
                        onset_mask_width = 0
                        dropout = 0.0
                        beat_mask_ms = 0  # was missing, which raised UnboundLocalError on return

                    return periodic_p, n_mask_codebooks, onset_mask_width, dropout, beat_mask_ms
                preset_dropdown.change(
                    fn=change_preset,
                    inputs=[preset_dropdown],
                    outputs=[periodic_p, n_mask_codebooks, onset_mask_width, dropout, beat_mask_ms]
                )

                maskimg = gr.Image(
                    label="mask image",
                    interactive=False,
                    type="filepath"
                )
with gr.Accordion("extras ", open=False): | |
pitch_shift_amt = gr.Slider( | |
label="pitch shift amount (semitones)", | |
minimum=-12, | |
maximum=12, | |
step=1, | |
value=0, | |
) | |
stretch_factor = gr.Slider( | |
label="time stretch factor", | |
minimum=0, | |
maximum=8, | |
step=1, | |
value=1, | |
) | |
with gr.Accordion("sampling settings", open=False): | |
sampletemp = gr.Slider( | |
label="sample temperature", | |
minimum=0.1, | |
maximum=10.0, | |
value=1.0, | |
step=0.001 | |
) | |
top_p = gr.Slider( | |
label="top p (0.0 = off)", | |
minimum=0.0, | |
maximum=1.0, | |
value=0.0 | |
) | |
typical_filtering = gr.Checkbox( | |
label="typical filtering ", | |
value=True | |
) | |
typical_mass = gr.Slider( | |
label="typical mass (should probably stay between 0.1 and 0.5)", | |
minimum=0.01, | |
maximum=0.99, | |
value=0.15 | |
) | |
typical_min_tokens = gr.Slider( | |
label="typical min tokens (should probably stay between 1 and 256)", | |
minimum=1, | |
maximum=256, | |
step=1, | |
value=64 | |
) | |
                sample_cutoff = gr.Slider(
                    label="sample cutoff",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01
                )

                sampling_steps = gr.Slider(
                    label="sampling steps",
                    minimum=1,
                    maximum=128,
                    step=1,
                    value=36
                )

                seed = gr.Number(
                    label="seed (0 for random)",
                    value=0,
                    precision=0,
                )
        # model choice + outputs
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(interface.available_models()),
                value=init_model_choice,
                visible=True
            )

            vamp_button = gr.Button("generate (vamp)!!!")

            audio_outs = []
            use_as_input_btns = []
            for i in range(2):
                with gr.Column():
                    audio_outs.append(gr.Audio(
                        label=f"output audio {i+1}",
                        interactive=False,
                        type="numpy"
                    ))
                    use_as_input_btns.append(
                        gr.Button("use as input (feedback)")
                    )

            thank_you = gr.Markdown("")

            # download all the outputs
            # download = gr.File(type="filepath", label="download outputs")
    # update the mask preview whenever a mask-related control changes
    for widget in (
        periodic_p, n_mask_codebooks,
        onset_mask_width, dropout
    ):
        widget.change(
            fn=mask_preview,
            inputs=[periodic_p, n_mask_codebooks,
                    onset_mask_width, dropout],
            outputs=[maskimg]
        )
    _inputs = [
        input_audio,
        sampletemp,
        top_p,
        periodic_p,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
        sampling_steps,
        beat_mask_ms,
        num_feedback_steps
    ]
    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[audio_outs[0], audio_outs[1], maskimg],
    )

    api_vamp_button = gr.Button("api vamp", visible=True)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,  # same widgets, in the same order, as the vamp button
        outputs=[audio_outs[0], audio_outs[1]],
        api_name="vamp"
    )
    # pyHARP endpoint: exposes a reduced control set through harp_vamp
    app = build_endpoint(
        model_card=ModelCard(
            name="vampnet",
            description="generating audio by filling in the blanks.",
            author="hugo flores garcía et al. (descript/northwestern)",
            tags=["sound", "generation"],
            midi_in=False,
            midi_out=False,
        ),
        components=[
            sampletemp, periodic_p, dropout, n_mask_codebooks, model_choice, beat_mask_ms
        ],
        process_fn=harp_vamp,
    )

try:
    demo.queue()
    demo.launch(share=True)
except KeyboardInterrupt:
    shutil.rmtree("gradio-outputs", ignore_errors=True)
    raise