|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
import yaml |
|
from models.gense_wavlm import N2S, S2S |
|
|
|
|
|
class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance ``__dict__`` is rebound to the mapping itself, so
    ``obj.key`` and ``obj["key"]`` read and write the same storage.
    Only the top level gets this behaviour: nested dict values are stored
    as-is and are not converted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share one namespace between item access and attribute access.
        self.__dict__ = self
|
|
|
def get_firstchannel_read(path, target_sr=16000):
    """Read an audio file, keep its first channel and resample it.

    Args:
        path: Audio file path accepted by ``torchaudio.load``.
        target_sr: Desired sample rate in Hz (default 16000).

    Returns:
        Tensor of shape ``(1, 1, num_samples)``: the first channel at
        ``target_sr``, with an extra leading axis prepended.
    """
    waveform, orig_sr = torchaudio.load(path)
    # Slicing with [:1] keeps the channel axis, so the result is
    # (1, num_samples) for both mono and multi-channel input.
    waveform = waveform[:1]
    if orig_sr != target_sr:
        to_target = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        waveform = to_target(waveform)
    return waveform[None]
|
|
|
def inference(noisy_path):
    """Enhance one uploaded recording and return the path of the result.

    Args:
        noisy_path: Path of the uploaded audio file, as delivered by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        Path of a newly created WAV file (written at 16000 Hz) holding the
        enhanced audio.

    Raises:
        gr.Error: if no audio file was provided.
    """
    import tempfile

    if noisy_path is None:
        # Show a readable message in the UI instead of a torchaudio traceback.
        raise gr.Error("Please upload a noisy audio file first.")

    # Pure inference: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        noisy_wav = get_firstchannel_read(noisy_path).to(device)
        # The first model's two outputs are fed, together with the input
        # waveform, into the second model.
        noisy_s, clean_s = n2s_model.generate(noisy_wav)
        enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s)

    # Use a unique file per request: a fixed name (previously 'enhanced2.wav')
    # is overwritten when requests overlap. The file is closed before
    # torchaudio writes to it, which also keeps this working on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    # torchaudio.save needs a CPU tensor; .cpu() is a no-op if already on CPU.
    torchaudio.save(out_path, enhanced_wav.cpu(), sample_rate=16000)
    return out_path
|
|
|
|
|
from huggingface_hub import hf_hub_download

# Fetch the YAML config and both model checkpoints from the "yaoxunji/gense"
# Hub repository; each call returns the local path of the downloaded file.
# The generator is consumed in order, so the downloads happen in the same
# sequence as before.
config_path, n2s_ckpt_path, s2s_ckpt_path = (
    hf_hub_download(repo_id="yaoxunji/gense", filename=fname)
    for fname in ("gense.yaml", "n2s_wavlm.ckpt", "s2s_wavlm.ckpt")
)
|
|
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Wrap the parsed config so its top-level keys are also readable as attributes.
config = AttrDict(config)


def _load_model(model_cls, ckpt_path):
    """Build ``model_cls(config)``, load its checkpoint and return it on ``device`` in eval mode."""
    model = model_cls(config)
    # Deserialize onto the CPU first: without map_location, tensors saved from
    # a GPU cannot be loaded on a CPU-only host (the 'cpu' fallback above).
    # NOTE(review): recent torch versions default to weights_only=True; if the
    # checkpoint stores non-tensor objects, this call may need adjusting.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return model.eval().to(device)


n2s_model = _load_model(N2S, n2s_ckpt_path)
s2s_model = _load_model(S2S, s2s_ckpt_path)
|
|
|
|
|
|
|
# Single-input / single-output UI: upload a noisy file and get back the
# enhanced audio produced by `inference`.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(label="Upload Noisy Wav", type="filepath"),
    ],
    outputs=gr.Audio(label="Enhanced Audio"),
    title="GenSE Demo",
    # Fixed markdown: the link previously had empty text and a stray ')'.
    description="""
[Paper on arXiv (2502.02942)](https://arxiv.org/pdf/2502.02942)
""",
)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    demo.launch()
|
|