# gense / app.py
# yaoxunji's picture
# Update app.py
# 98cc349 verified
import tempfile

import gradio as gr
import torch
import torchaudio
import yaml

from models.gense_wavlm import N2S, S2S
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance's ``__dict__`` is aliased to the mapping itself, so
    ``d.key`` and ``d["key"]`` always refer to the same value.
    """

    def __init__(self, *args, **kwargs):
        # Populate the mapping exactly as a plain ``dict`` would.
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the mapping so both views stay in sync.
        self.__dict__ = self
def get_firstchannel_read(path, target_sr=16000):
    """Load an audio file, keep its first channel and resample if needed.

    Args:
        path: Location of the audio file handed to ``torchaudio.load``.
        target_sr: Sample rate (Hz) the returned waveform should have.

    Returns:
        The waveform with one extra leading dimension added
        (presumably ``(1, 1, num_frames)``, assuming ``torchaudio.load``
        returns a channels-first tensor -- TODO confirm).
    """
    waveform, source_sr = torchaudio.load(path)

    # Multi-channel input: drop everything except the first channel while
    # keeping the channel dimension.
    if waveform.shape[0] > 1:
        waveform = waveform[:1]

    # Only resample when the file's rate differs from the requested one.
    if source_sr != target_sr:
        to_target = torchaudio.transforms.Resample(
            orig_freq=source_sr, new_freq=target_sr
        )
        waveform = to_target(waveform)

    return waveform[None]
def inference(noisy_path):
    """Enhance a noisy recording and return the path of the enhanced wav.

    The noisy waveform is first passed through ``n2s_model.generate`` and the
    resulting token pair is then fed, together with the waveform, to
    ``s2s_model.generate``. The output is written as a 16 kHz wav file.

    Args:
        noisy_path: Filesystem path of the uploaded noisy audio, or ``None``
            when the user submitted without providing audio.

    Returns:
        Path of a newly created ``.wav`` file holding the enhanced audio.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if noisy_path is None:
        # Gradio passes None when the audio input is left empty; surface a
        # readable message instead of an opaque loader failure.
        raise gr.Error("Please upload a noisy wav file first.")

    noisy_wav = get_firstchannel_read(noisy_path).to(device)

    # No gradients are needed at inference time; skipping autograd bookkeeping
    # saves memory.
    with torch.no_grad():
        noisy_s, clean_s = n2s_model.generate(noisy_wav)
        enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s)

    # Use a unique file per request so concurrent users do not overwrite each
    # other's results. The file is intentionally not deleted here because the
    # caller still has to read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name

    # torchaudio.save needs a CPU tensor that is detached from autograd.
    torchaudio.save(out_path, enhanced_wav.detach().cpu(), sample_rate=16000)
    return out_path
from huggingface_hub import hf_hub_download

# Fetch the configuration and both checkpoints from the Hugging Face Hub.
config_path = hf_hub_download(repo_id="yaoxunji/gense", filename="gense.yaml")
n2s_ckpt_path = hf_hub_download(repo_id="yaoxunji/gense", filename="n2s_wavlm.ckpt")
s2s_ckpt_path = hf_hub_download(repo_id="yaoxunji/gense", filename="s2s_wavlm.ckpt")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# Wrap the parsed mapping so its entries are reachable as attributes.
config = AttrDict(config)


def _load_model(model_cls, model_config, ckpt_path, target_device):
    """Build ``model_cls(model_config)``, load its weights, return it in eval mode."""
    model = model_cls(model_config)
    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # CPU-only host; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model = model.eval()
    return model.to(target_device)


n2s_model = _load_model(N2S, config, n2s_ckpt_path, device)
s2s_model = _load_model(S2S, config, s2s_ckpt_path, device)
#
# Gradio UI: one audio upload in, one enhanced audio clip out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        # type="filepath" makes Gradio hand `inference` a path on disk.
        gr.Audio(label="Upload Noisy Wav", type="filepath"),
    ],
    outputs=gr.Audio(label="Enhanced Audio"),
    title="GenSE Demo",
    # The badge markdown previously ended with an extra ")" that rendered as a
    # stray character after the link. Content is kept at column 0 so the
    # markdown is not interpreted as an indented code block.
    description="""
[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/pdf/2502.02942)
""",
)
demo.launch()