Tzktz's picture
Upload 7664 files
6fc683c verified
import os.path as op
from typing import BinaryIO, Optional, Tuple, Union
import numpy as np
def get_waveform(
    path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): Normalize values to [-1, 1] (Default: True)

    Returns:
        A ``(waveform, sample_rate)`` tuple. The waveform is read as float32;
        when ``normalization`` is False it is scaled by ``2 ** 15`` so the
        values span the 16-bit signed integer range.

    Raises:
        ValueError: if ``path_or_fp`` is a string whose extension is neither
            ``.flac`` nor ``.wav`` (compared case-insensitively).
        ImportError: if the ``soundfile`` package is not installed.
    """
    if isinstance(path_or_fp, str):
        ext = op.splitext(op.basename(path_or_fp))[1]
        # Compare case-insensitively so "clip.WAV" / "clip.Flac" are accepted;
        # the error message still reports the extension exactly as given.
        if ext.lower() not in {".flac", ".wav"}:
            raise ValueError(f"Unsupported audio format: {ext}")

    # Imported lazily so the module stays importable without soundfile.
    try:
        import soundfile as sf
    except ImportError as err:
        raise ImportError(
            "Please install soundfile to load WAV/FLAC file"
        ) from err

    waveform, sample_rate = sf.read(path_or_fp, dtype="float32")
    if not normalization:
        waveform *= 2 ** 15  # denormalized to 16-bit signed integers
    return waveform, sample_rate
def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.fbank import FbankOptions, Fbank
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torch
import torchaudio.compliance.kaldi as ta_kaldi
import torchaudio.sox_effects as ta_sox
waveform = torch.from_numpy(waveform)
if len(waveform.shape) == 1:
# Mono channel: D -> 1 x D
waveform = waveform.unsqueeze(0)
else:
# Merge multiple channels to one: C x D -> 1 x D
waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1'])
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio.

    PyKaldi (faster CPP implementation) is tried first and TorchAudio
    (Python implementation) is the fallback. Kaldi/TorchAudio requires inputs
    in the 16-bit signed integer range, so the waveform is loaded without
    normalization.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        n_bins (int): number of mel bins (Default: 80)

    Raises:
        ImportError: if neither back-end produced features.
    """
    sound, sample_rate = get_waveform(path_or_fp, normalization=False)

    # Try the back-ends in order of preference; each returns None when its
    # library is unavailable.
    for extract in (_get_kaldi_fbank, _get_torchaudio_fbank):
        features = extract(sound, sample_rate, n_bins)
        if features is not None:
            return features

    raise ImportError(
        "Please install pyKaldi or torchaudio to enable "
        "online filterbank feature extraction"
    )