# audio_denoiser / app.py
# Commit 3b84cc8 by wrice: "Handle audio_path being None"
"""Gradio demo for denoisers."""
import functools
from typing import Optional

import gradio as gr
import torch
import torchaudio
from denoisers import UNet1DModel, WaveUNetModel
from tqdm import tqdm
# Checkpoint identifiers offered in the UI dropdown and passed verbatim to
# ``from_pretrained``. The first entry is the dropdown's default selection.
MODELS = [
    "wrice/unet1d-vctk-48khz",
    "wrice/waveunet-vctk-48khz",
    "wrice/waveunet-vctk-24khz",
]
@functools.lru_cache(maxsize=4)
def _load_model(model_name: str):
    """Load the pretrained denoiser for ``model_name`` once and reuse it.

    Caching avoids calling ``from_pretrained`` (and moving the weights to the
    GPU) on every request. A bound of four entries covers every checkpoint
    listed in the UI dropdown.
    """
    if "unet1d" in model_name:
        model = UNet1DModel.from_pretrained(model_name)
    else:
        model = WaveUNetModel.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.cuda()
    # Inference only: force eval-mode behaviour for any dropout/normalization
    # layers, whatever mode the loader left the model in.
    model.eval()
    return model


def denoise(model_name: str, audio_path: Optional[str]) -> Optional[str]:
    """Denoise an audio file with the selected pretrained model.

    The input is decoded as mono at the model's sample rate, processed in
    chunks of ``model.config.max_length`` samples, and the denoised chunks
    are written to ``denoised.wav``.

    Args:
        model_name: Checkpoint id passed to ``from_pretrained``. Ids that
            contain ``"unet1d"`` load a ``UNet1DModel``; all others load a
            ``WaveUNetModel``.
        audio_path: Path of the uploaded audio file, or ``None``/empty when
            no audio was submitted.

    Returns:
        Path of the denoised WAV file, or ``None`` when no audio was given.
        Returning ``None`` clears the output instead of pointing at a stale
        or missing file.
    """
    if not audio_path:
        # Nothing to process. Returning "denoised.wav" here would hand back
        # the output of a previous request, or a file that does not exist.
        return None

    model = _load_model(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    chunk_size = model.config.max_length
    sample_rate = model.config.sample_rate
    # NOTE(review): this fixed path is shared by every request; concurrent
    # calls would overwrite each other's output unless Gradio serializes
    # this handler.
    output_path = "denoised.wav"

    stream_reader = torchaudio.io.StreamReader(audio_path)
    stream_reader.add_basic_audio_stream(
        frames_per_chunk=chunk_size,
        sample_rate=sample_rate,
        num_channels=1,
    )
    stream_writer = torchaudio.io.StreamWriter(output_path)
    stream_writer.add_audio_stream(sample_rate=sample_rate, num_channels=1)

    with stream_writer.open(), torch.no_grad():
        for (audio_chunk,) in tqdm(stream_reader.stream()):
            if audio_chunk is None:
                break
            # Chunks arrive as (frames, channels); the model consumes
            # (channels, frames), and the writer expects (frames, channels).
            audio_chunk = audio_chunk.permute(1, 0)
            original_chunk_size = audio_chunk.size(-1)
            # The last chunk may be shorter than chunk_size: zero-pad it to
            # the model's fixed input length, then trim the output back.
            if original_chunk_size < chunk_size:
                audio_chunk = torch.nn.functional.pad(
                    audio_chunk, (0, chunk_size - original_chunk_size)
                )
            denoised_chunk = model(audio_chunk[None].to(device)).audio
            denoised_chunk = denoised_chunk[..., :original_chunk_size]
            stream_writer.write_audio_chunk(
                0, denoised_chunk.squeeze(0).permute(1, 0).cpu()
            )
    return output_path
# Build the input components up front (same order as before): a model picker
# defaulting to the first checkpoint, then the uploaded audio as a file path.
_model_picker = gr.Dropdown(choices=MODELS, value=MODELS[0])
_audio_in = gr.Audio(type="filepath")

# ``denoise`` returns a file path, so the output component reads a file path.
iface = gr.Interface(
    fn=denoise,
    inputs=[_model_picker, _audio_in],
    outputs=gr.Audio(type="filepath"),
)

# Start the web server as soon as the module runs.
iface.launch()