alibabasglab committed on
Commit e1ed673 · verified · 1 Parent(s): d94221b

Upload 3 files

dataloader/dataloader.py ADDED
@@ -0,0 +1,593 @@
1
+ import numpy as np
2
+ import math, os, csv
3
+ import torchaudio
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.utils.data as data
7
+ import torch.distributed as dist
8
+ import soundfile as sf
9
+ from torch.utils.data import Dataset
12
+ import sys
13
+ sys.path.append(os.path.dirname(__file__))
14
+ from pydub import AudioSegment
15
+ from dataloader.misc import read_and_config_file, get_file_extension
16
+ import librosa
17
+ import random
18
+ EPS = 1e-6
19
+ MAX_WAV_VALUE_16B = 32768.0
20
+ MAX_WAV_VALUE_32B = 2147483648.0
+ # Full-scale value used below when rescaling normalized audio back to the 16-bit range.
+ MAX_WAV_VALUE = MAX_WAV_VALUE_16B
21
+
22
+ def audioread_archieved(path, sampling_rate):
23
+ """
24
+ Reads an audio file from the specified path, normalizes the audio,
25
+ resamples it to the desired sampling rate (if necessary), and ensures it is single-channel.
26
+
27
+ Parameters:
28
+ path (str): The file path of the audio file to be read.
29
+ sampling_rate (int): The target sampling rate for the audio.
30
+
31
+ Returns:
32
+ numpy.ndarray: The processed audio data, normalized, resampled (if necessary),
33
+ and converted to mono (if the input audio has multiple channels).
34
+ """
35
+
36
+ # Read audio data and its sample rate from the file.
37
+ data, fs = sf.read(path)
38
+
39
+ # convert to mono channel
40
+ if len(data.shape) >1:
41
+ if data.shape[0] > data.shape[1]:
42
+ data = data[:, 0]
43
+ else:
44
+ data = data[0, :]
45
+
46
+ # Normalize the audio data.
47
+ data, scalar = audio_norm(data)
48
+
49
+ # Resample the audio if the sample rate is different from the target sampling rate.
50
+ if fs != sampling_rate:
51
+ data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)
52
+
53
+ # Convert to mono by selecting the first channel if the audio has multiple channels.
54
+ if len(data.shape) > 1:
55
+ data = data[:, 0]
56
+
57
+ # Return the processed audio data.
58
+ return data, scalar
59
+
60
+ def read_audio(file_path):
61
+ """
62
+ Use AudioSegment to load audio from all supported audio input format
63
+ """
64
+
65
+ try:
66
+ audio = AudioSegment.from_file(file_path)
67
+ return audio
68
+ except Exception as e:
69
+ print(f"Error loading file: {e}")
70
+ return None
71
+
72
+ def audioread(path, sampling_rate, use_norm):
73
+ """
74
+ Reads an audio file from the specified path, normalizes the audio,
75
+ resamples it to the desired sampling rate (if necessary), and ensures it is single-channel.
76
+
77
+ Parameters:
78
+ path (str): The file path of the audio file to be read.
79
+ sampling_rate (int): The target sampling rate for the audio.
80
+ use_norm (bool): The flag for specifying whether using input audio normalization
81
+
82
+ Returns:
83
+ tuple: A list of per-channel numpy arrays (normalized if requested and resampled if necessary),
84
+ a list of the corresponding normalization scalars, and a dict of audio metadata (ext, sample_rate, channels, sample_width).
85
+ """
86
+
87
+ # Read audio data and its sample rate from the file.
88
+ audio_info = {}
89
+ ext = get_file_extension(path).replace('.', '')
90
+ audio_info['ext']=ext
91
+
92
+ try:
93
+ data = AudioSegment.from_file(path)
94
+ except Exception as e:
95
+ print(f"Error loading file: {e}")
96
+ return None
97
+
99
+
100
+ audio_info['sample_rate'] = data.frame_rate
101
+ audio_info['channels'] = data.channels
102
+ audio_info['sample_width'] = data.sample_width
103
+
104
+ data_array = np.array(data.get_array_of_samples())
105
+ if max(data_array) > MAX_WAV_VALUE_16B:
106
+ audio_np = data_array / MAX_WAV_VALUE_32B
107
+ else:
108
+ audio_np = data_array / MAX_WAV_VALUE_16B
109
+
110
+ audios = []
111
+ # Check if the audio is stereo
112
+ if audio_info['channels'] == 2:
113
+ audios.append(audio_np[::2]) # Even indices (left channel)
114
+ audios.append(audio_np[1::2]) # Odd indices (right channel)
115
+ else:
116
+ audios.append(audio_np)
117
+
118
+ # Normalize the audio data.
119
+ audios_normed = []
120
+ scalars = []
121
+ for audio in audios:
122
+ if use_norm:
123
+ audio_normed, scalar = audio_norm(audio)
124
+ audios_normed.append(audio_normed)
125
+ scalars.append(scalar)
126
+ else:
127
+ audios_normed.append(audio)
128
+ scalars.append(1)
129
+ # Resample the audio if the sample rate is different from the target sampling rate.
130
+ if audio_info['sample_rate'] != sampling_rate:
131
+ index = 0
132
+ for audio_normed in audios_normed:
133
+ audios_normed[index] = librosa.resample(audio_normed, orig_sr=audio_info['sample_rate'], target_sr=sampling_rate)
134
+ index = index + 1
135
+
136
+ # Return the processed audio data.
137
+ return audios_normed, scalars, audio_info
138
+
139
+ def audio_norm(x):
140
+ """
141
+ Normalizes the input audio signal to a target Root Mean Square (RMS) level,
142
+ applying two stages of scaling. This ensures the audio signal is neither too quiet
143
+ nor too loud, keeping its amplitude consistent.
144
+
145
+ Parameters:
146
+ x (numpy.ndarray): Input audio signal to be normalized.
147
+
148
+ Returns:
149
+ tuple: The normalized audio signal and the inverse of the combined scaling factor (for restoring the original level).
150
+ """
151
+
152
+ # Compute the root mean square (RMS) of the input audio signal.
153
+ rms = (x ** 2).mean() ** 0.5
154
+
155
+ # Calculate the scalar to adjust the signal to the target level (-25 dB).
156
+ scalar = 10 ** (-25 / 20) / (rms + EPS)
157
+
158
+ # Scale the input audio by the computed scalar.
159
+ x = x * scalar
160
+
161
+ # Compute the power of the scaled audio signal.
162
+ pow_x = x ** 2
163
+
164
+ # Calculate the average power of the audio signal.
165
+ avg_pow_x = pow_x.mean()
166
+
167
+ # Compute RMS only for audio segments with higher-than-average power.
168
+ rmsx = pow_x[pow_x > avg_pow_x].mean() ** 0.5
169
+
170
+ # Calculate another scalar to further normalize based on higher-power segments.
171
+ scalarx = 10 ** (-25 / 20) / (rmsx + EPS)
172
+
173
+ # Apply the second scalar to the audio.
174
+ x = x * scalarx
175
+
176
+ # Return the doubly normalized audio signal.
177
+ return x, 1/(scalar * scalarx + EPS)
178
+
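A minimal usage sketch for audio_norm (assuming the repository root is on PYTHONPATH so this file imports as dataloader.dataloader; the tone used here is a placeholder): the two-stage scaling pulls the signal toward the -25 dB RMS target, and the returned scalar undoes it.

import numpy as np
from dataloader.dataloader import audio_norm

# A quiet 1 kHz tone, one second at 16 kHz.
t = np.arange(16000) / 16000.0
x = 0.01 * np.sin(2 * np.pi * 1000 * t)

x_norm, inv_scalar = audio_norm(x)

# The normalized RMS lands near 10 ** (-25 / 20); the second stage pulls it slightly lower.
print((x_norm ** 2).mean() ** 0.5)
# Multiplying by the returned scalar restores the original level.
print(np.allclose(x_norm * inv_scalar, x, atol=1e-6))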
179
+ class DataReader(object):
180
+ """
181
+ A class for reading audio data from a list of files, normalizing it,
182
+ and extracting features for further processing. It supports extracting
183
+ features from each file, reshaping the data, and returning metadata
184
+ like utterance ID and data length.
185
+
186
+ Parameters:
187
+ args: Arguments containing the input path and target sampling rate.
188
+
189
+ Attributes:
190
+ file_list (list): A list of audio file paths to process.
191
+ sampling_rate (int): The target sampling rate for audio files.
192
+ """
193
+
194
+ def __init__(self, args):
195
+ # Read and configure the file list from the input path provided in the arguments.
196
+ # The file list is decoded, if necessary.
197
+ self.file_list = read_and_config_file(args, args.input_path, decode=True)
198
+
199
+ # Store the target sampling rate.
200
+ self.sampling_rate = args.sampling_rate
201
+
202
+ # Store the args file
203
+ self.args = args
204
+
205
+ def __len__(self):
206
+ """
207
+ Returns the number of audio files in the file list.
208
+
209
+ Returns:
210
+ int: Number of files to process.
211
+ """
212
+ return len(self.file_list)
213
+
214
+ def __getitem__(self, index):
215
+ """
216
+ Retrieves the features of the audio file at the given index.
217
+
218
+ Parameters:
219
+ index (int): Index of the file in the file list.
220
+
221
+ Returns:
222
+ tuple: Features (inputs, utterance ID, data length) for the selected audio file.
223
+ """
224
+ if self.args.task == 'target_speaker_extraction':
225
+ if self.args.network_reference.cue == 'lip':
226
+ return self.file_list[index]
227
+ return self.extract_feature(self.file_list[index])
228
+
229
+ def extract_feature(self, path):
230
+ """
231
+ Extracts features from the given audio file path.
232
+
233
+ Parameters:
234
+ path (str): The file path of the audio file.
235
+
236
+ Returns:
237
+ inputs (numpy.ndarray): Reshaped audio data for further processing.
238
+ utt_id (str): The unique identifier of the audio file, usually the filename.
239
+ length (int): The length of the original audio data.
240
+ """
241
+ # Extract the utterance ID from the file path (usually the filename).
242
+ utt_id = path.split('/')[-1]
243
+ use_norm = False
244
+
245
+ # We suggest using normalization for the 'FRCRN_SE_16K' and 'MossFormer2_SS_16K' models
246
+ if self.args.network in ['FRCRN_SE_16K','MossFormer2_SS_16K'] :
247
+ use_norm = True
248
+
249
+ # Read and normalize the audio data, converting it to float32 for processing.
250
+ audios_norm, scalars, audio_info = audioread(path, self.sampling_rate, use_norm)
251
+
252
+ if self.args.network in ['MossFormer2_SR_48K']:
253
+ audio_info['sample_rate'] = self.sampling_rate
254
+
255
+ for i in range(len(audios_norm)):
256
+ audios_norm[i] = audios_norm[i].astype(np.float32)
257
+ # Reshape the data to ensure it's in the format [1, data_length].
258
+ audios_norm[i] = np.reshape(audios_norm[i], [1, audios_norm[i].shape[0]])
259
+
260
+ # Return the reshaped audio data, utterance ID, and the length of the original data.
261
+ return audios_norm, utt_id, audios_norm[0].shape[1], scalars, audio_info
262
+
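A sketch of decode-time usage for DataReader; the args fields and file path below are placeholders standing in for the project's real argument parser.

from types import SimpleNamespace
from dataloader.dataloader import DataReader

args = SimpleNamespace(
    input_path='examples/noisy.wav',   # a single audio file, a directory, or a .scp list (placeholder)
    sampling_rate=16000,
    network='FRCRN_SE_16K',            # one of the networks for which use_norm is enabled
    task='speech_enhancement',         # any task other than target_speaker_extraction
)

reader = DataReader(args)
audios, utt_id, length, scalars, info = reader[0]
print(utt_id, length, info['sample_rate'], info['channels'])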
263
+ class Wave_Processor(object):
264
+ """
265
+ A class for processing audio data, specifically for reading input and label audio files,
266
+ segmenting them into fixed-length segments, and applying padding or trimming as necessary.
267
+
268
+ Methods:
269
+ process(path, segment_length, sampling_rate):
270
+ Processes audio data by reading, padding, or segmenting it to match the specified segment length.
271
+
272
+ Parameters:
273
+ path (dict): A dictionary containing file paths for 'inputs' and 'labels' audio files.
274
+ segment_length (int): The desired length of audio segments to extract.
275
+ sampling_rate (int): The target sampling rate for reading the audio files.
276
+ """
277
+
278
+ def process(self, path, segment_length, sampling_rate):
279
+ """
280
+ Reads input and label audio files, and ensures the audio is segmented into
281
+ the desired length, padding if necessary or extracting random segments if
282
+ the audio is longer than the target segment length.
283
+
284
+ Parameters:
285
+ path (dict): Dictionary containing the paths to 'inputs' and 'labels' audio files.
286
+ segment_length (int): Desired length of the audio segment in samples.
287
+ sampling_rate (int): Target sample rate for the audio.
288
+
289
+ Returns:
290
+ tuple: A pair of numpy arrays representing the processed input and label audio,
291
+ either padded to the segment length or trimmed.
292
+ """
293
+ # Read the input and label audio files using the target sampling rate.
294
+ # Use the archived single-channel reader here: its (path, sampling_rate) signature and
+ # (data, scalar) return value match what this training path expects.
+ wave_inputs, _ = audioread_archieved(path['inputs'], sampling_rate)
295
+ wave_labels, _ = audioread_archieved(path['labels'], sampling_rate)
296
+
297
+ # Get the length of the label audio (assumed both inputs and labels have similar lengths).
298
+ len_wav = wave_labels.shape[0]
299
+
300
+ # If the input audio is shorter than the desired segment length, pad it with zeros.
301
+ if wave_inputs.shape[0] < segment_length:
302
+ # Create zero-padded arrays for inputs and labels.
303
+ padded_inputs = np.zeros(segment_length, dtype=np.float32)
304
+ padded_labels = np.zeros(segment_length, dtype=np.float32)
305
+
306
+ # Copy the original audio into the padded arrays.
307
+ padded_inputs[:wave_inputs.shape[0]] = wave_inputs
308
+ padded_labels[:wave_labels.shape[0]] = wave_labels
309
+ else:
310
+ # Randomly select a start index for segmenting the audio if it's longer than the segment length.
311
+ st_idx = random.randint(0, len_wav - segment_length)
312
+
313
+ # Extract a segment of the desired length from the inputs and labels.
314
+ padded_inputs = wave_inputs[st_idx:st_idx + segment_length]
315
+ padded_labels = wave_labels[st_idx:st_idx + segment_length]
316
+
317
+ # Return the processed (padded or segmented) input and label audio.
318
+ return padded_inputs, padded_labels
319
+
320
+ class Fbank_Processor(object):
321
+ """
322
+ A class for processing input audio data into mel-filterbank (Fbank) features,
323
+ including the computation of delta and delta-delta features.
324
+
325
+ Methods:
326
+ process(inputs, args):
327
+ Processes the raw audio input and returns the mel-filterbank features
328
+ along with delta and delta-delta features.
329
+ """
330
+
331
+ def process(self, inputs, args):
332
+ # Convert frame length and shift from samples to milliseconds.
333
+ frame_length = int(args.win_len / args.sampling_rate * 1000)
334
+ frame_shift = int(args.win_inc / args.sampling_rate * 1000)
335
+
336
+ # Set up configuration for the mel-filterbank computation.
337
+ fbank_config = {
338
+ "dither": 1.0,
339
+ "frame_length": frame_length,
340
+ "frame_shift": frame_shift,
341
+ "num_mel_bins": args.num_mels,
342
+ "sample_frequency": args.sampling_rate,
343
+ "window_type": args.win_type
344
+ }
345
+
346
+ # Convert the input audio to a FloatTensor and scale it to match the expected input range.
347
+ inputs = torch.FloatTensor(inputs * MAX_WAV_VALUE)
348
+
349
+ # Compute the mel-filterbank features using Kaldi's fbank function.
350
+ fbank = torchaudio.compliance.kaldi.fbank(inputs.unsqueeze(0), **fbank_config)
351
+
352
+ # Add delta and delta-delta features.
353
+ fbank_tr = torch.transpose(fbank, 0, 1)
354
+ fbank_delta = torchaudio.functional.compute_deltas(fbank_tr)
355
+ fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta)
356
+ fbank_delta = torch.transpose(fbank_delta, 0, 1)
357
+ fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1)
358
+
359
+ # Concatenate the original Fbank, delta, and delta-delta features.
360
+ fbanks = torch.cat([fbank, fbank_delta, fbank_delta_delta], dim=1)
361
+
362
+ return fbanks.numpy()
363
+
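A sketch of Fbank_Processor on synthetic audio; the window and mel settings here are illustrative placeholders, not values prescribed by this repository.

import numpy as np
from types import SimpleNamespace
from dataloader.dataloader import Fbank_Processor

args = SimpleNamespace(
    sampling_rate=16000,
    win_len=400,       # 25 ms window, in samples
    win_inc=160,       # 10 ms hop, in samples
    num_mels=60,
    win_type='hamming',
)

wave = (0.01 * np.random.randn(16000)).astype(np.float32)   # one second of low-level noise
fbanks = Fbank_Processor().process(wave, args)
print(fbanks.shape)   # [num_frames, 3 * num_mels]: static, delta, and delta-delta features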
364
+ class AudioDataset(Dataset):
365
+ """
366
+ A dataset class for loading and processing audio data from different data types
367
+ (train, validation, test). Supports audio processing and feature extraction
368
+ (e.g., waveform processing, Fbank feature extraction).
369
+
370
+ Parameters:
371
+ args: Arguments containing dataset configuration (paths, sampling rate, etc.).
372
+ data_type (str): The type of data to load (train, val, test).
373
+ """
374
+
375
+ def __init__(self, args, data_type):
376
+ self.args = args
377
+ self.sampling_rate = args.sampling_rate
378
+
379
+ # Read the list of audio files based on the data type.
380
+ if data_type == 'train':
381
+ self.wav_list = read_and_config_file(args, args.tr_list)
382
+ elif data_type == 'val':
383
+ self.wav_list = read_and_config_file(args, args.cv_list)
384
+ elif data_type == 'test':
385
+ self.wav_list = read_and_config_file(args, args.tt_list)
386
+ else:
387
+ print(f'Data type: {data_type} is unknown!')
388
+
389
+ # Initialize processors for waveform and Fbank features.
390
+ self.wav_processor = Wave_Processor()
391
+ self.fbank_processor = Fbank_Processor()
392
+
393
+ # Clip data to a fixed segment length based on the sampling rate and max length.
394
+ self.segment_length = self.sampling_rate * self.args.max_length
395
+ print(f'No. {data_type} files: {len(self.wav_list)}')
396
+
397
+ def __len__(self):
398
+ # Return the number of audio files in the dataset.
399
+ return len(self.wav_list)
400
+
401
+ def __getitem__(self, index):
402
+ # Get the input and label paths from the list.
403
+ data_info = self.wav_list[index]
404
+
405
+ # Process the waveform inputs and labels.
406
+ inputs, labels = self.wav_processor.process(
407
+ {'inputs': data_info['inputs'], 'labels': data_info['labels']},
408
+ self.segment_length,
409
+ self.sampling_rate
410
+ )
411
+
412
+ # Optionally load Fbank features if specified.
413
+ if self.args.load_fbank is not None:
414
+ fbanks = self.fbank_processor.process(inputs, self.args)
415
+ return inputs * MAX_WAV_VALUE, labels * MAX_WAV_VALUE, fbanks
416
+
417
+ return inputs, labels
418
+
419
+ def zero_pad_concat(self, inputs):
420
+ """
421
+ Concatenates a list of input arrays, applying zero-padding as needed to ensure
422
+ they all match the length of the longest input.
423
+
424
+ Parameters:
425
+ inputs (list of numpy arrays): List of input arrays to be concatenated.
426
+
427
+ Returns:
428
+ numpy.ndarray: A zero-padded array with concatenated inputs.
429
+ """
430
+
431
+ # Get the maximum length among all inputs.
432
+ max_t = max(inp.shape[0] for inp in inputs)
433
+
434
+ # Determine the shape of the output based on the input dimensions.
435
+ shape = None
436
+ if len(inputs[0].shape) == 1:
437
+ shape = (len(inputs), max_t)
438
+ elif len(inputs[0].shape) == 2:
439
+ shape = (len(inputs), max_t, inputs[0].shape[1])
440
+
441
+ # Initialize an array with zeros to hold the concatenated inputs.
442
+ input_mat = np.zeros(shape, dtype=np.float32)
443
+
444
+ # Copy the input data into the zero-padded array.
445
+ for e, inp in enumerate(inputs):
446
+ if len(inp.shape) == 1:
447
+ input_mat[e, :inp.shape[0]] = inp
448
+ elif len(inp.shape) == 2:
449
+ input_mat[e, :inp.shape[0], :] = inp
450
+
451
+ return input_mat
452
+
453
+ def collate_fn_2x_wavs(data):
454
+ """
455
+ A custom collate function for combining batches of waveform input and label pairs.
456
+
457
+ Parameters:
458
+ data (list): List of tuples (inputs, labels).
459
+
460
+ Returns:
461
+ tuple: Batched inputs and labels as torch.FloatTensors.
462
+ """
463
+ inputs, labels = zip(*data)
464
+ x = torch.FloatTensor(inputs)
465
+ y = torch.FloatTensor(labels)
466
+ return x, y
467
+
468
+ def collate_fn_2x_wavs_fbank(data):
469
+ """
470
+ A custom collate function for combining batches of waveform inputs, labels, and Fbank features.
471
+
472
+ Parameters:
473
+ data (list): List of tuples (inputs, labels, fbanks).
474
+
475
+ Returns:
476
+ tuple: Batched inputs, labels, and Fbank features as torch.FloatTensors.
477
+ """
478
+ inputs, labels, fbanks = zip(*data)
479
+ x = torch.FloatTensor(inputs)
480
+ y = torch.FloatTensor(labels)
481
+ z = torch.FloatTensor(fbanks)
482
+ return x, y, z
483
+
484
+ class DistributedSampler(data.Sampler):
485
+ """
486
+ Sampler for distributed training. Divides the dataset among multiple replicas (processes),
487
+ ensuring that each process gets a unique subset of the data. It also supports shuffling
488
+ and managing epochs.
489
+
490
+ Parameters:
491
+ dataset (Dataset): The dataset to sample from.
492
+ num_replicas (int): Number of processes participating in the training.
493
+ rank (int): Rank of the current process.
494
+ shuffle (bool): Whether to shuffle the data or not.
495
+ seed (int): Random seed for reproducibility.
496
+ """
497
+
498
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
499
+ if num_replicas is None:
500
+ if not dist.is_available():
501
+ raise RuntimeError("Requires distributed package to be available")
502
+ num_replicas = dist.get_world_size()
503
+ if rank is None:
504
+ if not dist.is_available():
505
+ raise RuntimeError("Requires distributed package to be available")
506
+ rank = dist.get_rank()
507
+
508
+ self.dataset = dataset
509
+ self.num_replicas = num_replicas
510
+ self.rank = rank
511
+ self.epoch = 0
512
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
513
+ self.total_size = self.num_samples * self.num_replicas
514
+ self.shuffle = shuffle
515
+ self.seed = seed
516
+
517
+ def __iter__(self):
518
+ # Shuffle the indices based on the epoch and seed.
519
+ if self.shuffle:
520
+ g = torch.Generator()
521
+ g.manual_seed(self.seed + self.epoch)
522
+ ind = torch.randperm(int(len(self.dataset) / self.num_replicas), generator=g) * self.num_replicas
523
+ indices = []
524
+ for i in range(self.num_replicas):
525
+ indices = indices + (ind + i).tolist()
526
+ else:
527
+ indices = list(range(len(self.dataset)))
528
+
529
+ # Add extra samples to make the dataset evenly divisible.
530
+ indices += indices[:(self.total_size - len(indices))]
531
+ assert len(indices) == self.total_size
532
+
533
+ # Subsample for the current process.
534
+ indices = indices[self.rank * self.num_samples:(self.rank + 1) * self.num_samples]
535
+ assert len(indices) == self.num_samples
536
+
537
+ return iter(indices)
538
+
539
+ def __len__(self):
540
+ return self.num_samples
541
+
542
+ def set_epoch(self, epoch):
543
+ self.epoch = epoch
544
+
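The sampler above shuffles with a per-epoch seed and hands each rank an equally sized slice of the index list. A small sketch of the partitioning, passing num_replicas and rank explicitly so no torch.distributed process group is required:

from dataloader.dataloader import DistributedSampler

dataset = list(range(10))   # only len() is used when generating indices
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True, seed=0)
    sampler.set_epoch(0)    # call once per epoch so the shuffle changes between epochs
    print(rank, list(iter(sampler)))   # two disjoint index subsets of length 5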
545
+ def get_dataloader(args, data_type):
546
+ """
547
+ Creates and returns a data loader and sampler for the specified dataset type (train, validation, or test).
548
+
549
+ Parameters:
550
+ args (Namespace): Configuration arguments containing details such as batch size, sampling rate,
551
+ network type, and whether distributed training is used.
552
+ data_type (str): The type of dataset to load ('train', 'val', 'test').
553
+
554
+ Returns:
555
+ sampler (DistributedSampler or None): The sampler for distributed training, or None if not used.
556
+ generator (DataLoader): The PyTorch DataLoader for the specified dataset.
557
+ """
558
+
559
+ # Initialize the dataset based on the given arguments and dataset type (train, val, or test).
560
+ datasets = AudioDataset(args=args, data_type=data_type)
561
+
562
+ # Create a distributed sampler if distributed training is enabled; otherwise, use no sampler.
563
+ sampler = DistributedSampler(
564
+ datasets,
565
+ num_replicas=args.world_size, # Number of replicas in distributed training.
566
+ rank=args.local_rank # Rank of the current process.
567
+ ) if args.distributed else None
568
+
569
+ # Select the appropriate collate function based on the network type.
570
+ if args.network == 'FRCRN_SE_16K' or args.network == 'MossFormerGAN_SE_16K':
571
+ # Use the collate function for two-channel waveform data (inputs and labels).
572
+ collate_fn = collate_fn_2x_wavs
573
+ elif args.network == 'MossFormer2_SE_48K':
574
+ # Use the collate function for waveforms along with Fbank features.
575
+ collate_fn = collate_fn_2x_wavs_fbank
576
+ else:
577
+ # Print an error message if the network type is unknown.
578
+ print('In dataloader, please specify a correct network type using args.network!')
579
+ return
580
+
581
+ # Create a DataLoader with the specified dataset, batch size, and worker configuration.
582
+ generator = data.DataLoader(
583
+ datasets,
584
+ batch_size=args.batch_size, # Batch size for training.
585
+ shuffle=(sampler is None), # Shuffle the data only if no sampler is used.
586
+ collate_fn=collate_fn, # Use the selected collate function for batching data.
587
+ num_workers=args.num_workers, # Number of workers for data loading.
588
+ sampler=sampler # Use the distributed sampler if applicable.
589
+ )
590
+
591
+ # Return both the sampler and DataLoader (generator).
592
+ return sampler, generator
593
+
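A hypothetical single-GPU wiring of get_dataloader; every field below is a placeholder for the project's real argument parser, and the list files are assumed to exist on disk.

from types import SimpleNamespace
from dataloader.dataloader import get_dataloader

args = SimpleNamespace(
    tr_list='data/tr_list.txt', cv_list='data/cv_list.txt', tt_list='data/tt_list.txt',
    sampling_rate=16000, max_length=4,    # 4-second training segments
    batch_size=8, num_workers=4, load_fbank=None,
    distributed=False, world_size=1, local_rank=0,
    network='FRCRN_SE_16K',               # selects collate_fn_2x_wavs
)

sampler, train_loader = get_dataloader(args, 'train')
print(sampler, len(train_loader))         # sampler is None when distributed training is disabled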
dataloader/meldataset.py ADDED
@@ -0,0 +1,383 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ import numpy as np
7
+ from librosa.util import normalize
8
+ from scipy.io.wavfile import read
9
+ import scipy
10
+ import librosa
11
+ import wave
12
+ from pydub import AudioSegment
13
+
14
+ MAX_WAV_VALUE = 32768.0
15
+
16
+
17
+ def load_wav(full_path):
18
+ try:
19
+ sampling_rate, data = read(full_path)
20
+ if max(data.shape) / sampling_rate < 0.5:
21
+ return None, None
22
+ except FileNotFoundError:
23
+ print(f"File not found: {file_path}")
24
+ return None, None
25
+ except Exception as e:
26
+ print(f"An unexpected error occurred: {e}")
27
+ return None, None
28
+
29
+ if len(data.shape) > 1:
30
+ if data.shape[1] <= 2:
31
+ data = data[...,0]
32
+ else:
33
+ data = data[0,...]
34
+ return data / MAX_WAV_VALUE, sampling_rate
35
+
36
+ def get_wave_duration(file_path):
37
+ """
38
+ Gets the duration of a WAV file in seconds.
39
+
40
+ :param file_path: Path to the WAV file.
41
+ :return: Duration of the WAV file in seconds.
42
+ """
43
+ try:
44
+ with wave.open(file_path, 'rb') as wf:
45
+ # Get the number of frames
46
+ num_frames = wf.getnframes()
47
+ # Get the frame rate
48
+ frame_rate = wf.getframerate()
49
+ # Calculate duration
50
+ duration = num_frames / float(frame_rate)
51
+ return duration, frame_rate, num_frames
52
+ except wave.Error as e:
53
+ print(f"Error reading {file_path}: {e}")
54
+ return None, None, None
55
+ except FileNotFoundError:
56
+ print(f"File not found: {file_path}")
57
+ return None, None, None
58
+ except Exception as e:
59
+ print(f"An unexpected error occurred: {e}")
60
+ return None, None, None
61
+
62
+ def read_audio_segment(file_path, start_ms, end_ms):
63
+ """
64
+ Reads a segment from a WAV file and returns the raw data and its properties.
65
+
66
+ :param file_path: Path to the WAV file.
67
+ :param start_ms: Start time of the segment in milliseconds.
68
+ :param end_ms: End time of the segment in milliseconds.
69
+ :return: A tuple containing the raw audio data, frame rate, sample width, and number of channels.
70
+ """
71
+ #start_time = time.time()
72
+ try:
73
+ # Load the audio file
74
+ audio = AudioSegment.from_wav(file_path)
75
+ # Extract the segment
76
+ segment = audio[start_ms:end_ms]
77
+ # Get raw audio data
78
+ raw_data = segment.raw_data
79
+ # Get audio properties
80
+ frame_rate = segment.frame_rate
81
+ sample_width = segment.sample_width
82
+ channels = segment.channels
83
+ # Create NumPy array from the raw audio data
84
+ audio_array = np.frombuffer(raw_data, dtype=np.int16)
85
+
86
+ # If stereo, reshape the array to have a second dimension
87
+ if channels > 1:
88
+ audio_array = audio_array.reshape((-1, channels))
89
+ audio_array = audio_array[...,0]
90
+ '''
91
+ if frame_rate !=48000:
92
+ audio_array = audio_array/MAX_WAV_VALUE
93
+ audio_array = librosa.resample(audio_array, frame_rate, 48000)
94
+ audio_array = audio_array * MAX_WAV_VALUE
95
+ frame_rate = 48000
96
+ '''
97
+ #end_time = time.time()
98
+ #time_taken = end_time - start_time
99
+
100
+ #print(f"Successfully read segment from {start_ms}ms to {end_ms}ms in {time_taken:.4f} seconds")
101
+ return audio_array / MAX_WAV_VALUE#, frame_rate #, sample_width, channels
102
+ except Exception as e:
103
+ print(f"An error occurred: {e}")
104
+ return None#, None #, None, None
105
+
106
+ def resample(audio, sr_in, sr_out, target_len=None):
107
+ #audio = audio / MAX_WAV_VALUE
108
+ #audio = normalize(audio) * 0.95
109
+ if target_len is not None:
110
+ audio = scipy.signal.resample(audio, target_len)
111
+ return audio
112
+ resample_factor = sr_out / sr_in
113
+ new_samples = int(len(audio) * resample_factor)
114
+ audio = scipy.signal.resample(audio, new_samples)
115
+ return audio
116
+
117
+ def load_segment(full_path, target_sampling_rate=None, segment_size=None):
118
+
119
+ if segment_size is not None:
120
+ dur,sampling_rate,len_data = get_wave_duration(full_path)
121
+ if sampling_rate is None: return None, None
122
+ if sampling_rate < 44100: return None, None
123
+
124
+ target_dur = segment_size / target_sampling_rate
125
+ if dur < target_dur:
126
+ data, sampling_rate = load_wav(full_path)
127
+ #print(f'data_read: {data.shape}, sampling_rate: {sampling_rate}')
128
+ if data is None: return None, None
129
+
130
+ if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
131
+ data = resample(data, sampling_rate, target_sampling_rate)
132
+ sampling_rate = target_sampling_rate
133
+ data = torch.FloatTensor(data)
134
+ data = data.unsqueeze(0)
135
+ data = torch.nn.functional.pad(data, (0, segment_size - data.size(1)), 'constant')
136
+ data = data.squeeze(0)
137
+ return data.numpy(), sampling_rate
138
+ else:
139
+ dur,sampling_rate,len_data = get_wave_duration(full_path)
140
+ if sampling_rate < 44100: return None, None
141
+
142
+ target_dur = segment_size / target_sampling_rate
143
+ target_len = int(target_dur * sampling_rate)
144
+ start_idx = random.randint(0, (len_data - target_len))
145
+ start_ms = start_idx / sampling_rate * 1000
146
+ end_ms = start_ms + target_dur * 1000
147
+ data = read_audio_segment(full_path, start_ms, end_ms)
148
+ #print(f'data_read: {data.shape}, sampling_rate: {sampling_rate}')
149
+ if data is None: return None, None
150
+ if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
151
+ data = resample(data, sampling_rate, target_sampling_rate)
152
+ sampling_rate = target_sampling_rate
153
+ if len(data) < segment_size:
154
+ data = torch.FloatTensor(data)
155
+ data = data.unsqueeze(0)
156
+ data = torch.nn.functional.pad(data, (0, segment_size - data.size(1)), 'constant')
157
+ data = data.squeeze(0)
158
+ data = data.numpy()
159
+ else:
160
+ start_idx = random.randint(0, (len(data) - segment_size))
161
+ data = data[start_idx:start_idx+segment_size]
162
+ #print(f'data_cut: {data.shape}')
163
+ return data, sampling_rate
164
+ else:
165
+ dur,sampling_rate,len_data = get_wave_duration(full_path)
166
+ if sampling_rate is None: return None, None
167
+ if sampling_rate < 44100: return None, None
168
+ data, sampling_rate = load_wav(full_path)
169
+ if data is None: return None, None
170
+ if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
171
+ data = resample(data, sampling_rate, target_sampling_rate)
172
+ sampling_rate = target_sampling_rate
173
+ return data, sampling_rate
174
+
175
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
176
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
177
+
178
+
179
+ def dynamic_range_decompression(x, C=1):
180
+ return np.exp(x) / C
181
+
182
+
183
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
184
+ return torch.log(torch.clamp(x, min=clip_val) * C)
185
+
186
+
187
+ def dynamic_range_decompression_torch(x, C=1):
188
+ return torch.exp(x) / C
189
+
190
+
191
+ def spectral_normalize_torch(magnitudes):
192
+ output = dynamic_range_compression_torch(magnitudes)
193
+ return output
194
+
195
+
196
+ def spectral_de_normalize_torch(magnitudes):
197
+ output = dynamic_range_decompression_torch(magnitudes)
198
+ return output
199
+
200
+
201
+ mel_basis = {}
202
+ hann_window = {}
203
+
204
+
205
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
206
+ '''
207
+ if torch.min(y) < -1.:
208
+ print('min value is ', torch.min(y))
209
+ if torch.max(y) > 1.:
210
+ print('max value is ', torch.max(y))
211
+ '''
212
+ global mel_basis, hann_window
213
+ if str(fmax) + '_' + str(y.device) not in mel_basis:
214
+ #mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
215
+ # sr, n_fft, n_mels=128, fmin=0.0, fmax
216
+ mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
217
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
218
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
219
+
220
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
221
+ y = y.squeeze(1)
222
+
223
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
224
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
225
+
226
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
227
+
228
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
229
+ spec = spectral_normalize_torch(spec)
230
+
231
+ return spec
232
+
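A sketch of mel_spectrogram on random audio; the STFT settings below are placeholders rather than values fixed by this file.

import torch
from dataloader.meldataset import mel_spectrogram

y = 0.05 * torch.randn(1, 48000)          # [batch, samples], roughly within [-1, 1]
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=48000,
                      hop_size=240, win_size=1024, fmin=0, fmax=24000, center=False)
print(mel.shape)                          # [1, 80, num_frames] of log-compressed mel energies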
233
+
234
+ def get_dataset_filelist_org(a):
235
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
236
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
237
+ for x in fi.read().split('\n') if len(x) > 0]
238
+
239
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
240
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
241
+ for x in fi.read().split('\n') if len(x) > 0]
242
+ return training_files, validation_files
243
+
244
+ def get_dataset_filelist(a):
245
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
246
+ training_files = [x for x in fi.read().split('\n') if len(x) > 0]
247
+
248
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
249
+ validation_files = [x for x in fi.read().split('\n') if len(x) > 0]
250
+
251
+ return training_files, validation_files
252
+
253
+ class MelDataset(torch.utils.data.Dataset):
254
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
255
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
256
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
257
+ self.audio_files = training_files
258
+ random.seed(1234)
259
+ if shuffle:
260
+ random.shuffle(self.audio_files)
261
+ self.segment_size = segment_size
262
+ self.sampling_rate = sampling_rate
263
+ self.split = split
264
+ self.n_fft = n_fft
265
+ self.num_mels = num_mels
266
+ self.hop_size = hop_size
267
+ self.win_size = win_size
268
+ self.fmin = fmin
269
+ self.fmax = fmax
270
+ self.fmax_loss = fmax_loss
271
+ self.cached_wav = None
272
+ self.n_cache_reuse = n_cache_reuse
273
+ self._cache_ref_count = 0
274
+ self.device = device
275
+ self.fine_tuning = fine_tuning
276
+ self.base_mels_path = base_mels_path
277
+ self.supported_samples = [16000, 22050, 24000] #[4000, 8000, 16000, 22050, 24000, 32000]
278
+ #self.supported_samples = [4000, 8000] #, 16000, 22050, 24000, 32000]
279
+
280
+ def __getitem__(self, index):
281
+ filename = self.audio_files[index]
282
+ while 1:
283
+ #audio, sampling_rate = load_wav(filename)
284
+ audio, sampling_rate = load_segment(filename, self.sampling_rate, self.segment_size)
285
+ if audio is not None: break
286
+ else:
287
+ filename = self.audio_files[random.randint(0,index)]
288
+ #audio, sampling_rate = load_wav(filename)
289
+ #audio, sampling_rate = load_segment(filename, self.sampling_rate, self.segment_size)
290
+
291
+ #audio = audio / MAX_WAV_VALUE
292
+ if not self.fine_tuning:
293
+ audio = normalize(audio) * 0.95
294
+
295
+ sr_out = random.choice(self.supported_samples)
296
+ audio_down = resample(audio, self.sampling_rate, sr_out)
297
+
298
+ target_len = len(audio) #/ downsample_factor
299
+ audio_up = resample(audio_down, None, None, target_len)
300
+
301
+ audio = torch.FloatTensor(audio)
302
+ audio = audio.unsqueeze(0)
303
+ audio_up = torch.FloatTensor(audio_up)
304
+ audio_up = audio_up.unsqueeze(0)
305
+
306
+ mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
307
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
308
+ center=False)
309
+
310
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
311
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
312
+ center=False)
313
+
314
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
315
+
316
+ def __getitem__org(self, index):
317
+ filename = self.audio_files[index]
318
+ if self._cache_ref_count == 0:
319
+ while 1:
320
+ audio, sampling_rate = load_wav(filename)
321
+ if audio is not None: break
322
+ else:
323
+ filename = self.audio_files[random.randint(0,index)]
324
+ audio, sampling_rate = load_wav(filename)
325
+
326
+ audio = audio / MAX_WAV_VALUE
327
+ if not self.fine_tuning:
328
+ audio = normalize(audio) * 0.95
329
+ #self.cached_wav = audio
330
+ if sampling_rate != self.sampling_rate:
331
+ resample_factor = self.sampling_rate / sampling_rate
332
+ new_samples = int(len(audio) * resample_factor)
333
+ audio = scipy.signal.resample(audio, new_samples)#.astype(np.int16)
334
+ #raise ValueError("{} SR doesn't match target {} SR".format(
335
+ # sampling_rate, self.sampling_rate))
336
+
337
+ downsample_factor = 16000 / self.sampling_rate
338
+ new_samples = int(len(audio) * downsample_factor)
339
+ audio_down = scipy.signal.resample(audio, new_samples)
340
+
341
+ new_samples = len(audio) #/ downsample_factor
342
+ audio_up = scipy.signal.resample(audio_down, new_samples)
343
+ #print(f'audio: {audio.shape}, audio_up: {audio_up.shape}')
344
+ #min_idx = min(len(audio), len(audio_up))
345
+ #audio = audio[:min_idx]
346
+ #audio_up = audio_up[:min_idx]
347
+
348
+ self.cached_wav = audio
349
+ self.cached_wav_up = audio_up
350
+ self._cache_ref_count = self.n_cache_reuse
351
+ else:
352
+ audio = self.cached_wav
353
+ audio_up = self.cached_wav_up
354
+ self._cache_ref_count -= 1
355
+
356
+ audio = torch.FloatTensor(audio)
357
+ audio = audio.unsqueeze(0)
358
+ audio_up = torch.FloatTensor(audio_up)
359
+ audio_up = audio_up.unsqueeze(0)
360
+
361
+ if True:
362
+ if self.split:
363
+ if audio.size(1) >= self.segment_size:
364
+ max_audio_start = audio.size(1) - self.segment_size
365
+ audio_start = random.randint(0, max_audio_start)
366
+ audio = audio[:, audio_start:audio_start+self.segment_size]
367
+ audio_up = audio_up[:, audio_start:audio_start+self.segment_size]
368
+ else:
369
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
370
+ audio_up = torch.nn.functional.pad(audio_up, (0, self.segment_size - audio_up.size(1)), 'constant')
371
+
372
+ mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
373
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
374
+ center=False)
375
+
376
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
377
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
378
+ center=False)
379
+
380
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
381
+
382
+ def __len__(self):
383
+ return len(self.audio_files)
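A hypothetical wiring of MelDataset into a PyTorch DataLoader; the file lists and STFT values are placeholders standing in for the training script's configuration.

import torch
from types import SimpleNamespace
from dataloader.meldataset import MelDataset, get_dataset_filelist

a = SimpleNamespace(input_training_file='data/train_files.txt',
                    input_validation_file='data/val_files.txt')
training_files, validation_files = get_dataset_filelist(a)

trainset = MelDataset(training_files, segment_size=32768, n_fft=1024, num_mels=80,
                      hop_size=240, win_size=1024, sampling_rate=48000,
                      fmin=0, fmax=24000, fmax_loss=None)
loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

mel, audio, filename, mel_loss = next(iter(loader))
print(mel.shape, audio.shape)             # [4, 80, frames] and [4, 32768]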
dataloader/misc.py ADDED
@@ -0,0 +1,111 @@
1
+
2
+ #!/usr/bin/env python -u
3
+ # -*- coding: utf-8 -*-
4
+
5
+ from __future__ import absolute_import
6
+ from __future__ import division
7
+ from __future__ import print_function
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import os
12
+ import sys
13
+ import librosa
14
+ import mimetypes
15
+
16
+ def get_file_extension(file_path):
17
+ """
18
+ Return an audio file extension
19
+ """
20
+
21
+ _, ext = os.path.splitext(file_path)
22
+ return ext
23
+
24
+ def is_audio_file(file_path):
25
+ """
26
+ Check if the given file_path is an audio file
27
+ Return True if it is an audio file, otherwise, return False
28
+ """
29
+ file_ext = ["wav", "aac", "ac3", "aiff", "flac", "m4a", "mp3", "ogg", "opus", "wma", "webm"]
30
+
31
+ ext = get_file_extension(file_path)
32
+ if ext.replace('.','') in file_ext:
33
+ return True
34
+
35
+ mime_type, _ = mimetypes.guess_type(file_path)
36
+ if mime_type and mime_type.startswith('audio'):
37
+ return True
38
+ return False
39
+
40
+ def read_and_config_file(args, input_path, decode=0):
41
+ """
42
+ Reads and processes the input file or directory to extract audio file paths or configuration data.
43
+
44
+ Parameters:
45
+ args: Parsed command-line arguments (used for task and network settings when decode is true).
46
+ input_path (str): Path to a file or directory containing audio data or file paths.
47
+ decode (bool): If True (decode=1) for decoding, process the input as audio files directly (find .wav or .flac files) or from a .scp file.
48
+ If False (decode=0) for training, assume the input file contains lines with paths to audio files.
49
+
50
+ Returns:
51
+ processed_list (list): A list of processed file paths or a list of dictionaries containing input
52
+ and optional condition audio paths.
53
+ """
54
+ processed_list = [] # Initialize list to hold processed file paths or configurations
55
+
56
+ #The supported audio types are listed below (tested), but not limited to.
57
+ file_ext = ["wav", "aac", "ac3", "aiff", "flac", "m4a", "mp3", "ogg", "opus", "wma", "webm"]
58
+
59
+ if decode:
60
+ if args.task == 'target_speaker_extraction':
61
+ if args.network_reference.cue == 'lip':
62
+ # If decode is True, find video files in a directory or single file
63
+ if os.path.isdir(input_path):
64
+ # Find all .mp4, .avi, .mov, .MOV, and .webm files in the input directory
65
+ processed_list = librosa.util.find_files(input_path, ext="mp4")
66
+ processed_list += librosa.util.find_files(input_path, ext="avi")
67
+ processed_list += librosa.util.find_files(input_path, ext="mov")
68
+ processed_list += librosa.util.find_files(input_path, ext="MOV")
69
+ processed_list += librosa.util.find_files(input_path, ext="webm")
70
+ else:
71
+ # If it's a single video file (.mp4, .avi, .mov, or .webm), add it to the processed list
72
+ if input_path.lower().endswith(".mp4") or input_path.lower().endswith(".avi") or input_path.lower().endswith(".mov") or input_path.lower().endswith(".webm"):
73
+ processed_list.append(input_path)
74
+ else:
75
+ # Read file paths from the input text file (one path per line)
76
+ with open(input_path) as fid:
77
+ for line in fid:
78
+ path_s = line.strip().split() # Split paths (space-separated)
79
+ processed_list.append(path_s[0]) # Add the first path (input audio path)
80
+ return processed_list
81
+
82
+ # If decode is True, find audio files in a directory or single file
83
+ if os.path.isdir(input_path):
84
+ # Find all supported audio files in the input directory
85
+ processed_list = librosa.util.find_files(input_path, ext=file_ext)
86
+ else:
87
+ # If it's a single supported audio file, add it to the processed list
88
+ #if input_path.lower().endswith(".wav") or input_path.lower().endswith(".flac"):
89
+ if is_audio_file(input_path):
90
+ processed_list.append(input_path)
91
+ else:
92
+ # Read file paths from the input text file (one path per line)
93
+ with open(input_path) as fid:
94
+ for line in fid:
95
+ path_s = line.strip().split() # Split paths (space-separated)
96
+ processed_list.append(path_s[0]) # Add the first path (input audio path)
97
+ return processed_list
98
+
99
+ # If decode is False, treat the input file as a configuration file
100
+ with open(input_path) as fid:
101
+ for line in fid:
102
+ tmp_paths = line.strip().split() # Split paths (space-separated)
103
+ if len(tmp_paths) == 2:
104
+ # If two paths per line, treat the second as 'condition_audio'
105
+ sample = {'inputs': tmp_paths[0], 'condition_audio': tmp_paths[1]}
106
+ elif len(tmp_paths) == 1:
107
+ # If only one path per line, treat it as 'inputs'
108
+ sample = {'inputs': tmp_paths[0]}
109
+ processed_list.append(sample) # Append processed sample to list
110
+ return processed_list
111
+
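A sketch of both modes of read_and_config_file; the paths are placeholders and the list-file layout described in the comments mirrors the parsing above.

from types import SimpleNamespace
from dataloader.misc import read_and_config_file

args = SimpleNamespace(task='speech_enhancement')   # task is only consulted when decode is true

# Training/config mode (decode=0): each line holds one or two space-separated paths,
# yielding dicts like {'inputs': ...} or {'inputs': ..., 'condition_audio': ...}.
entries = read_and_config_file(args, 'data/tr_list.txt')
print(entries[0]['inputs'])

# Decode mode (decode=1): scan a directory (or accept a single audio file) directly.
wav_paths = read_and_config_file(args, 'wav_inputs/', decode=1)
print(len(wav_paths))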