alibabasglab committed on
Commit 8bcecef · verified · 1 Parent(s): 5245e8c

Delete dataloader/dataloader.py

Files changed (1)
  1. dataloader/dataloader.py +0 -498
dataloader/dataloader.py DELETED
@@ -1,498 +0,0 @@
- import numpy as np
- import math, os, csv
- import torchaudio
- import torch
- import torch.nn as nn
- import torch.utils.data as data
- import torch.distributed as dist
- import soundfile as sf
- from torch.utils.data import Dataset
- import sys
- sys.path.append(os.path.dirname(__file__))
-
- from dataloader.misc import read_and_config_file
- import librosa
- import random
- EPS = 1e-6
- MAX_WAV_VALUE = 32768.0
-
- def audioread(path, sampling_rate):
-     """
-     Reads an audio file from the specified path, converts it to single-channel,
-     normalizes the audio, and resamples it to the target sampling rate (if necessary).
-
-     Parameters:
-         path (str): The file path of the audio file to be read.
-         sampling_rate (int): The target sampling rate for the audio.
-
-     Returns:
-         tuple: The processed audio data (numpy.ndarray), converted to mono,
-         normalized, and resampled (if necessary), plus the scalar that undoes
-         the normalization.
-     """
-
-     # Read the audio data and its sample rate from the file.
-     data, fs = sf.read(path)
-
-     # Convert to mono by selecting the first channel if the audio has multiple channels.
-     if len(data.shape) > 1:
-         data = data[:, 0]
-
-     # Normalize the audio data.
-     data, scalar = audio_norm(data)
-
-     # Resample the audio if the file's sample rate differs from the target sampling rate.
-     if fs != sampling_rate:
-         data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)
-
-     # Return the processed audio data and the normalization scalar.
-     return data, scalar
-
- def audio_norm(x):
-     """
-     Normalizes the input audio signal to a target Root Mean Square (RMS) level,
-     applying two stages of scaling. This ensures the audio signal is neither too quiet
-     nor too loud, keeping its amplitude consistent.
-
-     Parameters:
-         x (numpy.ndarray): Input audio signal to be normalized.
-
-     Returns:
-         tuple: The normalized audio signal and the inverse of the combined scaling
-         factor, which can be used to restore the original amplitude.
-     """
-
-     # Compute the root mean square (RMS) of the input audio signal.
-     rms = (x ** 2).mean() ** 0.5
-
-     # Calculate the scalar to adjust the signal to the target level (-25 dB).
-     scalar = 10 ** (-25 / 20) / (rms + EPS)
-
-     # Scale the input audio by the computed scalar.
-     x = x * scalar
-
-     # Compute the power of the scaled audio signal.
-     pow_x = x ** 2
-
-     # Calculate the average power of the audio signal.
-     avg_pow_x = pow_x.mean()
-
-     # Compute the RMS only over samples with higher-than-average power.
-     rmsx = pow_x[pow_x > avg_pow_x].mean() ** 0.5
-
-     # Calculate a second scalar to normalize based on the higher-power segments.
-     scalarx = 10 ** (-25 / 20) / (rmsx + EPS)
-
-     # Apply the second scalar to the audio.
-     x = x * scalarx
-
-     # Return the doubly normalized audio signal and the inverse combined scalar.
-     return x, 1 / (scalar * scalarx + EPS)
-
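# Illustrative sketch, not from dataloader.py: a quick check of audio_norm's
# contract. The first stage scales the full-signal RMS to 10 ** (-25 / 20)
# (about 0.056), the second stage re-scales using only the higher-power
# samples, and the returned scalar undoes both stages.
import numpy as np

sig = 0.5 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
y, inv = audio_norm(sig)
print((y ** 2).mean() ** 0.5)                # well below the original RMS of ~0.35
print(np.allclose(y * inv, sig, atol=1e-5))  # True: `inv` restores the amplitude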
- class DataReader(object):
-     """
-     A class for reading audio data from a list of files, normalizing it,
-     and extracting features for further processing. It supports extracting
-     features from each file, reshaping the data, and returning metadata
-     such as the utterance ID and data length.
-
-     Parameters:
-         args: Arguments containing the input path and target sampling rate.
-
-     Attributes:
-         file_list (list): A list of audio file paths to process.
-         sampling_rate (int): The target sampling rate for audio files.
-     """
-
-     def __init__(self, args):
-         # Read and configure the file list from the input path provided in the
-         # arguments. The file list is decoded, if necessary.
-         self.file_list = read_and_config_file(args, args.input_path, decode=True)
-
-         # Store the target sampling rate.
-         self.sampling_rate = args.sampling_rate
-
-         # Store the args object.
-         self.args = args
-
-     def __len__(self):
-         """
-         Returns the number of audio files in the file list.
-
-         Returns:
-             int: Number of files to process.
-         """
-         return len(self.file_list)
-
-     def __getitem__(self, index):
-         """
-         Retrieves the features of the audio file at the given index.
-
-         Parameters:
-             index (int): Index of the file in the file list.
-
-         Returns:
-             tuple: Features (inputs, utterance ID, data length, scalar) for the
-             selected audio file.
-         """
-         if self.args.task == 'target_speaker_extraction':
-             if self.args.network_reference.cue == 'lip':
-                 return self.file_list[index]
-         return self.extract_feature(self.file_list[index])
-
-     def extract_feature(self, path):
-         """
-         Extracts features from the given audio file path.
-
-         Parameters:
-             path (str): The file path of the audio file.
-
-         Returns:
-             inputs (numpy.ndarray): Reshaped audio data for further processing.
-             utt_id (str): The unique identifier of the audio file, usually the filename.
-             length (int): The length of the original audio data.
-             scalar (float): The scalar that undoes the normalization applied on read.
-         """
-         # Extract the utterance ID from the file path (usually the filename).
-         utt_id = path.split('/')[-1]
-
-         # Read and normalize the audio data, converting it to float32 for processing.
-         data, scalar = audioread(path, self.sampling_rate)
-         data = data.astype(np.float32)
-
-         # Reshape the data to ensure it's in the format [1, data_length].
-         inputs = np.reshape(data, [1, data.shape[0]])
-
-         # Return the reshaped audio, utterance ID, original length, and scalar.
-         return inputs, utt_id, data.shape[0], scalar
-
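# Illustrative sketch, not from dataloader.py: a typical decode loop. The
# Namespace fields and the .scp path are hypothetical stand-ins for the real
# run-time arguments.
from argparse import Namespace

args = Namespace(input_path='wavs/test.scp', sampling_rate=16000,
                 task='speech_enhancement')
reader = DataReader(args)
for i in range(len(reader)):
    inputs, utt_id, length, scalar = reader[i]
    # `inputs` has shape [1, length]; multiply the model output by `scalar`
    # to restore the original amplitude before writing the enhanced file.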
- class Wave_Processor(object):
-     """
-     A class for processing audio data, specifically for reading input and label audio files,
-     segmenting them into fixed-length segments, and applying padding or trimming as necessary.
-
-     Methods:
-         process(path, segment_length, sampling_rate):
-             Processes audio data by reading, padding, or segmenting it to match the specified segment length.
-
-     Parameters:
-         path (dict): A dictionary containing file paths for 'inputs' and 'labels' audio files.
-         segment_length (int): The desired length of audio segments to extract.
-         sampling_rate (int): The target sampling rate for reading the audio files.
-     """
-
-     def process(self, path, segment_length, sampling_rate):
-         """
-         Reads input and label audio files, and ensures the audio is segmented into
-         the desired length, padding if necessary or extracting a random segment if
-         the audio is longer than the target segment length.
-
-         Parameters:
-             path (dict): Dictionary containing the paths to 'inputs' and 'labels' audio files.
-             segment_length (int): Desired length of the audio segment in samples.
-             sampling_rate (int): Target sample rate for the audio.
-
-         Returns:
-             tuple: A pair of numpy arrays representing the processed input and label audio,
-             either padded to the segment length or trimmed.
-         """
-         # Read the input and label audio files using the target sampling rate,
-         # discarding the normalization scalars returned by audioread.
-         wave_inputs, _ = audioread(path['inputs'], sampling_rate)
-         wave_labels, _ = audioread(path['labels'], sampling_rate)
-
-         # Get the length of the label audio (inputs and labels are assumed to have similar lengths).
-         len_wav = wave_labels.shape[0]
-
-         # If the input audio is shorter than the desired segment length, pad it with zeros.
-         if wave_inputs.shape[0] < segment_length:
-             # Create zero-padded arrays for inputs and labels.
-             padded_inputs = np.zeros(segment_length, dtype=np.float32)
-             padded_labels = np.zeros(segment_length, dtype=np.float32)
-
-             # Copy the original audio into the padded arrays.
-             padded_inputs[:wave_inputs.shape[0]] = wave_inputs
-             padded_labels[:wave_labels.shape[0]] = wave_labels
-         else:
-             # Randomly select a start index if the audio is longer than the segment length.
-             st_idx = random.randint(0, len_wav - segment_length)
-
-             # Extract a segment of the desired length from the inputs and labels.
-             padded_inputs = wave_inputs[st_idx:st_idx + segment_length]
-             padded_labels = wave_labels[st_idx:st_idx + segment_length]
-
-         # Return the processed (padded or segmented) input and label audio.
-         return padded_inputs, padded_labels
-
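# Illustrative sketch, not from dataloader.py: the pad-or-crop rule restated as
# a hypothetical standalone helper. A 3-sample signal is padded to length 5; an
# 8-sample signal is randomly cropped to length 5.
import random
import numpy as np

def pad_or_crop(wav, segment_length):
    if wav.shape[0] < segment_length:
        out = np.zeros(segment_length, dtype=np.float32)
        out[:wav.shape[0]] = wav
        return out
    st = random.randint(0, wav.shape[0] - segment_length)
    return wav[st:st + segment_length]

print(pad_or_crop(np.ones(3, dtype=np.float32), 5))    # [1. 1. 1. 0. 0.]
print(pad_or_crop(np.arange(8, dtype=np.float32), 5))  # 5 consecutive values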
- class Fbank_Processor(object):
-     """
-     A class for processing input audio data into mel-filterbank (Fbank) features,
-     including the computation of delta and delta-delta features.
-
-     Methods:
-         process(inputs, args):
-             Processes the raw audio input and returns the mel-filterbank features
-             along with delta and delta-delta features.
-     """
-
-     def process(self, inputs, args):
-         # Convert the frame length and shift from samples to milliseconds.
-         frame_length = int(args.win_len / args.sampling_rate * 1000)
-         frame_shift = int(args.win_inc / args.sampling_rate * 1000)
-
-         # Set up the configuration for the mel-filterbank computation.
-         fbank_config = {
-             "dither": 1.0,
-             "frame_length": frame_length,
-             "frame_shift": frame_shift,
-             "num_mel_bins": args.num_mels,
-             "sample_frequency": args.sampling_rate,
-             "window_type": args.win_type
-         }
-
-         # Convert the input audio to a FloatTensor and scale it to the expected input range.
-         inputs = torch.FloatTensor(inputs * MAX_WAV_VALUE)
-
-         # Compute the mel-filterbank features using Kaldi's fbank function.
-         fbank = torchaudio.compliance.kaldi.fbank(inputs.unsqueeze(0), **fbank_config)
-
-         # Add delta and delta-delta features.
-         fbank_tr = torch.transpose(fbank, 0, 1)
-         fbank_delta = torchaudio.functional.compute_deltas(fbank_tr)
-         fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta)
-         fbank_delta = torch.transpose(fbank_delta, 0, 1)
-         fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1)
-
-         # Concatenate the original Fbank, delta, and delta-delta features.
-         fbanks = torch.cat([fbank, fbank_delta, fbank_delta_delta], dim=1)
-
-         return fbanks.numpy()
-
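# Illustrative sketch, not from dataloader.py: win_len and win_inc are given in
# samples, while Kaldi's fbank expects milliseconds. The 48 kHz window sizes
# below are hypothetical stand-ins for the real configuration.
import torch
import torchaudio

sr, win_len, win_inc = 48000, 1920, 384
frame_length = int(win_len / sr * 1000)  # 1920 samples -> 40 ms
frame_shift = int(win_inc / sr * 1000)   # 384 samples  ->  8 ms
wav = torch.randn(1, sr)                 # one second of noise, shape (1, n)
fbank = torchaudio.compliance.kaldi.fbank(
    wav * 32768.0, dither=1.0, frame_length=frame_length,
    frame_shift=frame_shift, num_mel_bins=60,
    sample_frequency=sr, window_type='hamming')
print(fbank.shape)                       # (num_frames, 60)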
- class AudioDataset(Dataset):
-     """
-     A dataset class for loading and processing audio data for the different data
-     splits (train, validation, test). Supports waveform processing and feature
-     extraction (e.g., Fbank features).
-
-     Parameters:
-         args: Arguments containing dataset configuration (paths, sampling rate, etc.).
-         data_type (str): The type of data to load ('train', 'val', 'test').
-     """
-
-     def __init__(self, args, data_type):
-         self.args = args
-         self.sampling_rate = args.sampling_rate
-
-         # Read the list of audio files based on the data type.
-         if data_type == 'train':
-             self.wav_list = read_and_config_file(args.tr_list)
-         elif data_type == 'val':
-             self.wav_list = read_and_config_file(args.cv_list)
-         elif data_type == 'test':
-             self.wav_list = read_and_config_file(args.tt_list)
-         else:
-             print(f'Data type: {data_type} is unknown!')
-
-         # Initialize processors for waveform and Fbank features.
-         self.wav_processor = Wave_Processor()
-         self.fbank_processor = Fbank_Processor()
-
-         # Clip data to a fixed segment length based on the sampling rate and max length.
-         self.segment_length = self.sampling_rate * self.args.max_length
-         print(f'No. {data_type} files: {len(self.wav_list)}')
-
-     def __len__(self):
-         # Return the number of audio files in the dataset.
-         return len(self.wav_list)
-
-     def __getitem__(self, index):
-         # Get the input and label paths from the list.
-         data_info = self.wav_list[index]
-
-         # Process the waveform inputs and labels.
-         inputs, labels = self.wav_processor.process(
-             {'inputs': data_info['inputs'], 'labels': data_info['labels']},
-             self.segment_length,
-             self.sampling_rate
-         )
-
-         # Optionally compute Fbank features if specified.
-         if self.args.load_fbank is not None:
-             fbanks = self.fbank_processor.process(inputs, self.args)
-             return inputs * MAX_WAV_VALUE, labels * MAX_WAV_VALUE, fbanks
-
-         return inputs, labels
-
-     def zero_pad_concat(self, inputs):
-         """
-         Stacks a list of input arrays into a batch, applying zero-padding as
-         needed so that they all match the length of the longest input.
-
-         Parameters:
-             inputs (list of numpy arrays): List of input arrays to be batched.
-
-         Returns:
-             numpy.ndarray: A zero-padded array containing the stacked inputs.
-         """
-
-         # Get the maximum length among all inputs.
-         max_t = max(inp.shape[0] for inp in inputs)
-
-         # Determine the shape of the output based on the input dimensions.
-         shape = None
-         if len(inputs[0].shape) == 1:
-             shape = (len(inputs), max_t)
-         elif len(inputs[0].shape) == 2:
-             shape = (len(inputs), max_t, inputs[0].shape[1])
-
-         # Initialize an array of zeros to hold the stacked inputs.
-         input_mat = np.zeros(shape, dtype=np.float32)
-
-         # Copy the input data into the zero-padded array.
-         for e, inp in enumerate(inputs):
-             if len(inp.shape) == 1:
-                 input_mat[e, :inp.shape[0]] = inp
-             elif len(inp.shape) == 2:
-                 input_mat[e, :inp.shape[0], :] = inp
-
-         return input_mat
-
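# Illustrative sketch, not from dataloader.py: zero_pad_concat in isolation.
# Two 1-D arrays of lengths 2 and 4 become a (2, 4) batch, with the shorter
# row zero-padded on the right (the method never touches self, so it can be
# called unbound here for the demo).
import numpy as np

batch = [np.array([1., 2.], dtype=np.float32),
         np.array([3., 4., 5., 6.], dtype=np.float32)]
print(AudioDataset.zero_pad_concat(None, batch))
# [[1. 2. 0. 0.]
#  [3. 4. 5. 6.]]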
- def collate_fn_2x_wavs(data):
-     """
-     A custom collate function for combining batches of waveform input and label pairs.
-
-     Parameters:
-         data (list): List of tuples (inputs, labels).
-
-     Returns:
-         tuple: Batched inputs and labels as torch.FloatTensors.
-     """
-     inputs, labels = zip(*data)
-     x = torch.FloatTensor(inputs)
-     y = torch.FloatTensor(labels)
-     return x, y
-
- def collate_fn_2x_wavs_fbank(data):
-     """
-     A custom collate function for combining batches of waveform inputs, labels, and Fbank features.
-
-     Parameters:
-         data (list): List of tuples (inputs, labels, fbanks).
-
-     Returns:
-         tuple: Batched inputs, labels, and Fbank features as torch.FloatTensors.
-     """
-     inputs, labels, fbanks = zip(*data)
-     x = torch.FloatTensor(inputs)
-     y = torch.FloatTensor(labels)
-     z = torch.FloatTensor(fbanks)
-     return x, y, z
-
- class DistributedSampler(data.Sampler):
-     """
-     Sampler for distributed training. Divides the dataset among multiple replicas (processes),
-     ensuring that each process gets a unique subset of the data. It also supports shuffling
-     and epoch management.
-
-     Parameters:
-         dataset (Dataset): The dataset to sample from.
-         num_replicas (int): Number of processes participating in the training.
-         rank (int): Rank of the current process.
-         shuffle (bool): Whether to shuffle the data or not.
-         seed (int): Random seed for reproducibility.
-     """
-
-     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
-         if num_replicas is None:
-             if not dist.is_available():
-                 raise RuntimeError("Requires distributed package to be available")
-             num_replicas = dist.get_world_size()
-         if rank is None:
-             if not dist.is_available():
-                 raise RuntimeError("Requires distributed package to be available")
-             rank = dist.get_rank()
-
-         self.dataset = dataset
-         self.num_replicas = num_replicas
-         self.rank = rank
-         self.epoch = 0
-         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
-         self.total_size = self.num_samples * self.num_replicas
-         self.shuffle = shuffle
-         self.seed = seed
-
-     def __iter__(self):
-         if self.shuffle:
-             # Deterministically shuffle based on the epoch and seed: draw one base
-             # permutation of the multiples of num_replicas, then offset it once per
-             # replica, producing a striped index pattern across the dataset.
-             g = torch.Generator()
-             g.manual_seed(self.seed + self.epoch)
-             ind = torch.randperm(int(len(self.dataset) / self.num_replicas), generator=g) * self.num_replicas
-             indices = []
-             for i in range(self.num_replicas):
-                 indices = indices + (ind + i).tolist()
-         else:
-             indices = list(range(len(self.dataset)))
-
-         # Add extra samples to make the index list evenly divisible.
-         indices += indices[:(self.total_size - len(indices))]
-         assert len(indices) == self.total_size
-
-         # Subsample the block belonging to the current process.
-         indices = indices[self.rank * self.num_samples:(self.rank + 1) * self.num_samples]
-         assert len(indices) == self.num_samples
-
-         return iter(indices)
-
-     def __len__(self):
-         return self.num_samples
-
-     def set_epoch(self, epoch):
-         self.epoch = epoch
-
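# Illustrative sketch, not from dataloader.py: the striped shuffle with a toy
# dataset of 8 items and 2 replicas. One base permutation of the multiples of
# num_replicas is drawn, then offset once per rank, so the two processes
# together cover all 8 items each epoch.
import torch

g = torch.Generator()
g.manual_seed(0)                               # corresponds to seed + epoch
ind = torch.randperm(8 // 2, generator=g) * 2  # a permutation of [0, 2, 4, 6]
indices = []
for i in range(2):
    indices += (ind + i).tolist()
print(indices)  # rank 0 reads indices[0:4] (even items), rank 1 indices[4:8] (odd)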
- def get_dataloader(args, data_type):
-     """
-     Creates and returns a data loader and sampler for the specified dataset type (train, validation, or test).
-
-     Parameters:
-         args (Namespace): Configuration arguments containing details such as batch size, sampling rate,
-                           network type, and whether distributed training is used.
-         data_type (str): The type of dataset to load ('train', 'val', 'test').
-
-     Returns:
-         sampler (DistributedSampler or None): The sampler for distributed training, or None if not used.
-         generator (DataLoader): The PyTorch DataLoader for the specified dataset.
-     """
-
-     # Initialize the dataset based on the given arguments and dataset type (train, val, or test).
-     datasets = AudioDataset(args=args, data_type=data_type)
-
-     # Create a distributed sampler if distributed training is enabled; otherwise, use no sampler.
-     sampler = DistributedSampler(
-         datasets,
-         num_replicas=args.world_size,  # Number of replicas in distributed training.
-         rank=args.local_rank           # Rank of the current process.
-     ) if args.distributed else None
-
-     # Select the appropriate collate function based on the network type.
-     if args.network in ('FRCRN_SE_16K', 'MossFormerGAN_SE_16K'):
-         # Use the collate function for paired waveform data (inputs and labels).
-         collate_fn = collate_fn_2x_wavs
-     elif args.network == 'MossFormer2_SE_48K':
-         # Use the collate function for waveforms along with Fbank features.
-         collate_fn = collate_fn_2x_wavs_fbank
-     else:
-         # Report an error if the network type is unknown.
-         print('In dataloader, please specify a correct network type using args.network!')
-         return
-
-     # Create a DataLoader with the specified dataset, batch size, and worker configuration.
-     generator = data.DataLoader(
-         datasets,
-         batch_size=args.batch_size,    # Batch size for training.
-         shuffle=(sampler is None),     # Shuffle the data only if no sampler is used.
-         collate_fn=collate_fn,         # Use the selected collate function for batching.
-         num_workers=args.num_workers,  # Number of workers for data loading.
-         sampler=sampler                # Use the distributed sampler if applicable.
-     )
-
-     # Return both the sampler and the DataLoader (generator).
-     return sampler, generator
-
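# Illustrative sketch, not from dataloader.py: a minimal single-process
# training setup. The Namespace fields and list-file path are hypothetical
# stand-ins for the real configuration consumed by read_and_config_file.
from argparse import Namespace

args = Namespace(tr_list='data/tr.lst', sampling_rate=16000, max_length=4,
                 network='FRCRN_SE_16K', load_fbank=None, batch_size=8,
                 num_workers=4, distributed=False)
sampler, loader = get_dataloader(args, 'train')
for epoch in range(10):
    if sampler is not None:
        sampler.set_epoch(epoch)  # re-seed the striped shuffle each epoch
    for inputs, labels in loader:
        pass                      # forward/backward steps go here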