"""Gradio demo for GenSE speech enhancement.

On import, the script downloads the config and the two checkpoints from the
Hugging Face Hub and builds the N2S and S2S models. When run as a script it
serves a web UI that takes a noisy recording and returns the enhanced audio.
"""

import tempfile

import gradio as gr
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download

from models.gense_wavlm import N2S, S2S

REPO_ID = "yaoxunji/gense"
# Sample rate used both when reading the input and when writing the output.
TARGET_SR = 16000


class AttrDict(dict):
    """Dictionary whose keys are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share storage between items and attributes: d["k"] is d.k.
        self.__dict__ = self


def get_firstchannel_read(path, target_sr=16000):
    """Load an audio file, keep its first channel and resample it.

    Args:
        path: Path of the audio file, passed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        The waveform tensor with one extra leading dimension added
        (presumably the batch axis expected by the models -- confirm
        against ``models.gense_wavlm``).
    """
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        # Multi-channel input: keep only the first channel, preserving the
        # channel axis.
        wav = wav[:1]
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        wav = resampler(wav)
    return wav.unsqueeze(0)


def inference(noisy_path):
    """Enhance one noisy recording and return the path of the result.

    Args:
        noisy_path: Path of the uploaded noisy audio file.

    Returns:
        Path of a newly written WAV file containing the enhanced audio.
    """
    noisy_wav = get_firstchannel_read(noisy_path, target_sr=TARGET_SR).to(device)

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        noisy_s, clean_s = n2s_model.generate(noisy_wav)
        enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s)

    # A unique file per request, so concurrent users cannot overwrite each
    # other's output (the previous fixed 'enhanced2.wav' was shared by all).
    # The handle is closed before saving so the path can be reopened on every
    # platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name

    # NOTE(review): torchaudio.save is given the model output unchanged apart
    # from detaching it and moving it to the CPU; its shape is assumed to be
    # one torchaudio.save accepts -- confirm against S2S.generate.
    torchaudio.save(out_path, enhanced_wav.detach().cpu(), sample_rate=TARGET_SR)
    return out_path


def _load_model(model_cls, ckpt_path):
    """Build ``model_cls(config)``, load its weights, return it in eval mode."""
    model = model_cls(config)
    # map_location="cpu" lets a checkpoint saved from a GPU deserialize on a
    # CPU-only host; the model is moved to the target device afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return model.eval().to(device)


config_path = hf_hub_download(repo_id=REPO_ID, filename="gense.yaml")
n2s_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="n2s_wavlm.ckpt")
s2s_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="s2s_wavlm.ckpt")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(config_path, "r", encoding="utf-8") as f:
    config = AttrDict(yaml.safe_load(f))

n2s_model = _load_model(N2S, n2s_ckpt_path)
s2s_model = _load_model(S2S, s2s_ckpt_path)

# Web UI: one audio upload in, one enhanced audio clip out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(label="Upload Noisy Wav", type="filepath"),
    ],
    outputs=gr.Audio(label="Enhanced Audio"),
    title="GenSE Demo",
    description=(
        "[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)]"
        "(https://arxiv.org/pdf/2502.02942)"
    ),
)

if __name__ == "__main__":
    demo.launch()