|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from .. import FairseqDataset |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class RawAudioDataset(FairseqDataset):
    """Base dataset for 1-D raw audio tensors.

    This class holds the shared configuration and implements
    post-processing, random cropping, batch collation and index ordering.
    ``__getitem__`` is left unimplemented here, and ``self.sizes`` starts
    out empty; both are expected to be provided by a subclass.
    """

    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
    ):
        """Store the dataset configuration.

        Args:
            sample_rate: sample rate every item must have (checked in
                :meth:`postprocess`).
            max_sample_size: upper bound on the length of collated items;
                ``None`` is replaced by ``sys.maxsize``.
            min_sample_size: stored as-is on the instance.
            shuffle: whether :meth:`ordered_indices` breaks size ties with a
                random permutation.
            pad: whether :meth:`collater` pads short items instead of
                cropping everything to the shortest one.
            normalize: whether :meth:`postprocess` applies layer norm.
        """
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        # A missing cap is represented by the largest practical integer so
        # that later ``min(...)`` calls need no special-casing.
        if max_sample_size is None:
            max_sample_size = sys.maxsize
        self.max_sample_size = max_sample_size
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize

    def __getitem__(self, index):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of entries in ``self.sizes``."""
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        """Validate and optionally normalize one audio tensor.

        A 2-D input is averaged over its last dimension (presumably the
        channel axis -- confirm against the loader). The sample rate must
        equal ``self.sample_rate``. When ``self.normalize`` is set, layer
        norm over the whole tensor is applied without tracking gradients.

        Raises:
            Exception: if ``curr_sample_rate`` differs from
                ``self.sample_rate``.
        """
        mono = feats.mean(-1) if feats.dim() == 2 else feats

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        n_dims = mono.dim()
        assert n_dims == 1, n_dims

        if not self.normalize:
            return mono
        with torch.no_grad():
            return F.layer_norm(mono, mono.shape)

    def crop_to_max_size(self, wav, target_size):
        """Return a random contiguous window of ``target_size`` items.

        ``wav`` is returned unchanged when it is not longer than
        ``target_size``.
        """
        excess = len(wav) - target_size
        if excess <= 0:
            return wav

        # Any offset in [0, excess] keeps the window inside ``wav``.
        offset = np.random.randint(0, excess + 1)
        return wav[offset : offset + target_size]

    def collater(self, samples):
        """Merge a list of samples into a mini-batch.

        Samples whose ``"source"`` is ``None`` are dropped; an empty dict is
        returned when nothing remains. With ``self.pad`` the batch width is
        the longest source (capped by ``self.max_sample_size``), shorter
        sources are zero-padded on the right and a boolean ``padding_mask``
        marks the padded positions. Without it, the batch width is the
        shortest source (same cap) and longer sources are randomly cropped.

        Returns:
            dict with ``"id"`` (LongTensor of sample ids) and ``"net_input"``
            (``"source"`` plus ``"padding_mask"`` when padding is enabled).
        """
        kept = [s for s in samples if s["source"] is not None]
        if not kept:
            return {}

        sources = [s["source"] for s in kept]
        lengths = [len(src) for src in sources]

        reference_length = max(lengths) if self.pad else min(lengths)
        target_size = min(reference_length, self.max_sample_size)

        batch = sources[0].new_zeros(len(sources), target_size)
        padding_mask = None
        if self.pad:
            padding_mask = torch.BoolTensor(batch.shape).fill_(False)

        for row, (source, length) in enumerate(zip(sources, lengths)):
            shortfall = target_size - length
            if shortfall > 0:
                # Only reachable when padding: otherwise target_size never
                # exceeds the shortest source.
                assert self.pad
                filler = source.new_full((shortfall,), 0.0)
                batch[row] = torch.cat([source, filler])
                padding_mask[row, -shortfall:] = True
            elif shortfall < 0:
                batch[row] = self.crop_to_max_size(source, target_size)
            else:
                batch[row] = source

        net_input = {"source": batch}
        if self.pad:
            net_input["padding_mask"] = padding_mask
        sample_ids = torch.LongTensor([s["id"] for s in kept])
        return {"id": sample_ids, "net_input": net_input}

    def num_tokens(self, index):
        """Return the token count of an example; same value as :meth:`size`."""
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        full_length = self.sizes[index]
        if self.pad:
            return full_length
        # Without padding, items get cropped, so the cap applies.
        return min(full_length, self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        count = len(self)
        tie_breaker = (
            np.random.permutation(count) if self.shuffle else np.arange(count)
        )
        # lexsort uses the LAST key as primary: sort by size, break ties with
        # ``tie_breaker``, then reverse so the largest sizes come first.
        return np.lexsort([tie_breaker, self.sizes])[::-1]
|
|
|
|
|
class FileAudioDataset(RawAudioDataset):
    """Raw audio dataset whose items are listed in a manifest file.

    Manifest layout, as parsed by ``__init__``: the first line is the root
    directory; each following line is ``<file name><TAB><size>``, where
    ``<size>`` is an integer.
    """

    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
    ):
        """Read the manifest and record file names and sizes.

        Args:
            manifest_path: path of the manifest file to read.
            sample_rate, max_sample_size, min_sample_size, shuffle, pad,
                normalize: forwarded unchanged to ``RawAudioDataset``.
                Entries whose size is below ``min_sample_size`` are skipped
                (no filtering when it is ``None``).

        Raises:
            AssertionError: if a manifest line does not consist of exactly
                two tab-separated fields.
            ValueError: if the size field is not an integer (from ``int``).
        """
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
        )

        self.fnames = []
        # Zero-based indices (counted after the root-directory line) of the
        # manifest lines that were kept.
        self.line_inds = set()

        skipped = 0
        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                if len(items) != 2:
                    # Raised explicitly instead of via ``assert`` so that
                    # malformed manifests are still rejected under
                    # ``python -O``; the exception type and payload are the
                    # same as the assert statement would produce.
                    raise AssertionError(line)
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    continue
                self.fnames.append(items[0])
                self.line_inds.add(i)
                self.sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

    def __getitem__(self, index):
        """Load one audio file and return ``{"id": index, "source": tensor}``.

        The file is read with ``soundfile``, converted to a float tensor and
        passed through ``postprocess`` together with the sample rate that
        ``soundfile`` reports.
        """
        # Imported lazily so that soundfile is only needed once an item is
        # actually read.
        import soundfile as sf

        fname = os.path.join(self.root_dir, self.fnames[index])
        wav, curr_sample_rate = sf.read(fname)
        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}
|
|