alibabasglab committed
Commit 8e8cd3e · verified · 1 Parent(s): af5b0c7

Upload 161 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. __pycache__/clearvoice.cpython-37.pyc +0 -0
  2. __pycache__/clearvoice.cpython-38.pyc +0 -0
  3. __pycache__/network_wrapper.cpython-37.pyc +0 -0
  4. __pycache__/network_wrapper.cpython-38.pyc +0 -0
  5. __pycache__/networks.cpython-38.pyc +0 -0
  6. checkpoints/AV_MossFormer2_TSE_16K/config.yaml +55 -0
  7. checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint +1 -0
  8. checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_enhance.pt +3 -0
  9. checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_separate.pt +3 -0
  10. checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_tmp.pt +3 -0
  11. checkpoints/FRCRN_SE_16K/config.yaml +33 -0
  12. checkpoints/FRCRN_SE_16K/last_best_checkpoint +1 -0
  13. checkpoints/FRCRN_SE_16K/last_checkpoint +1 -0
  14. checkpoints/FRCRN_SE_16K/model.ckpt-88-8491630.pt +3 -0
  15. clearvoice.py +62 -0
  16. config/inference/AV_MossFormer2_TSE_16K.yaml +41 -0
  17. config/inference/FRCRN_SE_16K.yaml +20 -0
  18. config/inference/MossFormer2_SE_48K.yaml +22 -0
  19. config/inference/MossFormer2_SS_16K.yaml +21 -0
  20. config/inference/MossFormerGAN_SE_16K.yaml +22 -0
  21. config/inference/SpEx_plus_TSE_8K.yaml +18 -0
  22. dataloader/__pycache__/dataloader.cpython-38.pyc +0 -0
  23. dataloader/__pycache__/misc.cpython-38.pyc +0 -0
  24. dataloader/dataloader.py +496 -0
  25. dataloader/misc.py +84 -0
  26. demo.py +70 -0
  27. demo_with_detailed_comments.py +61 -0
  28. input.wav +0 -0
  29. models/.DS_Store +0 -0
  30. models/__pycache__/__init__.cpython-36.pyc +0 -0
  31. models/__pycache__/__init__.cpython-37.pyc +0 -0
  32. models/__pycache__/__init__.cpython-38.pyc +0 -0
  33. models/__pycache__/complex_nn.cpython-36.pyc +0 -0
  34. models/__pycache__/complex_nn.cpython-37.pyc +0 -0
  35. models/__pycache__/complex_nn.cpython-38.pyc +0 -0
  36. models/__pycache__/constant.cpython-36.pyc +0 -0
  37. models/__pycache__/constant.cpython-37.pyc +0 -0
  38. models/__pycache__/constant.cpython-38.pyc +0 -0
  39. models/__pycache__/conv_stft.cpython-38.pyc +0 -0
  40. models/__pycache__/criterion.cpython-36.pyc +0 -0
  41. models/__pycache__/criterion.cpython-37.pyc +0 -0
  42. models/__pycache__/criterion.cpython-38.pyc +0 -0
  43. models/__pycache__/frcrn.cpython-38.pyc +0 -0
  44. models/__pycache__/metric.cpython-36.pyc +0 -0
  45. models/__pycache__/noisedataset.cpython-36.pyc +0 -0
  46. models/__pycache__/noisedataset.cpython-37.pyc +0 -0
  47. models/__pycache__/noisedataset.cpython-38.pyc +0 -0
  48. models/__pycache__/phasen_dccrn.cpython-36.pyc +0 -0
  49. models/__pycache__/phasen_dccrn.cpython-37.pyc +0 -0
  50. models/__pycache__/phasen_dccrn.cpython-38.pyc +0 -0
__pycache__/clearvoice.cpython-37.pyc ADDED
Binary file (2.13 kB)

__pycache__/clearvoice.cpython-38.pyc ADDED
Binary file (2.14 kB)

__pycache__/network_wrapper.cpython-37.pyc ADDED
Binary file (8.18 kB)

__pycache__/network_wrapper.cpython-38.pyc ADDED
Binary file (7.39 kB)

__pycache__/networks.cpython-38.pyc ADDED
Binary file (13.2 kB)
 
checkpoints/AV_MossFormer2_TSE_16K/config.yaml ADDED
@@ -0,0 +1,55 @@
+ ## Config file
+
+ # Log
+ seed: 777
+ use_cuda: 1 # 1 for True, 0 for False
+
+ # dataset
+ speaker_no: 2
+ mix_lst_path: ./data/allData/voxceleb2/mixture_data_list_2mix_pretrain.csv
+ audio_direc: /mnt/nas_sg/wulanchabu/zexu.pan/datasets/voxceleb2/audio_clean
+ reference_direc: /mnt/nas_sg/wulanchabu/zexu.pan/datasets/ # not used
+ audio_sr: 16000
+ ref_sr: 25
+
+ # dataloader
+ num_workers: 4
+ batch_size: 2 # 4-GPU training with a total effective batch size of 8
+ accu_grad: 0
+ effec_batch_size: 4 # per GPU, only used if accu_grad is set to 1; must be a multiple of batch_size
+ max_length: 5 # truncate utterances in the dataloader, in seconds
+
+ # network settings
+ init_from: checkpoints/log_2024-09-30(09:49:14) # 'None' or a log name 'log_2024-07-22(18:12:13)'
+ causal: 0 # 1 for True, 0 for False
+ network_reference:
+   cue: lip # lip, speech, gesture, or EEG
+   backbone: resnet18 # resnet18, shufflenetV2, or blazenet64
+   emb_size: 256 # resnet18: 256
+ network_audio:
+   backbone: mossformer2
+   encoder_kernel_size: 16
+   encoder_out_nchannels: 512
+   encoder_in_nchannels: 1
+
+   masknet_numspks: 1
+   masknet_chunksize: 250
+   masknet_numlayers: 1
+   masknet_norm: "ln"
+   masknet_useextralinearlayer: False
+   masknet_extraskipconnection: True
+
+   intra_numlayers: 24
+   intra_nhead: 8
+   intra_dffn: 1024
+   intra_dropout: 0
+   intra_use_positional: True
+   intra_norm_before: True
+
+
+ # optimizer
+ loss_type: hybrid # "snr", "sisdr", "hybrid"
+ init_learning_rate: 0.00015
+ max_epoch: 150
+ clip_grad_norm: 5
+
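As a side note, configs like the one above appear to be consumed as an attribute-style args object (the code added below accesses args.network_reference.cue and similar chains). A minimal loading sketch, assuming PyYAML is installed; the dict_to_namespace helper is hypothetical and not part of this commit:

import yaml
from argparse import Namespace

def dict_to_namespace(d):
    # Recursively convert nested dicts (e.g., network_reference, network_audio)
    # into attribute-accessible Namespace objects.
    if isinstance(d, dict):
        return Namespace(**{k: dict_to_namespace(v) for k, v in d.items()})
    return d

with open('checkpoints/AV_MossFormer2_TSE_16K/config.yaml') as f:
    args = dict_to_namespace(yaml.safe_load(f))

print(args.network_audio.backbone)   # mossformer2
print(args.network_reference.cue)    # lip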
checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint ADDED
@@ -0,0 +1 @@
+ last_best_checkpoint_tmp.pt
checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_enhance.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6b416073f66c7a9faa84ad8088bf4ae69c946f6c2ea3db2e7c6ead1a1fca088
+ size 134
checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_separate.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb45b197224686bbe11c3898b32b9da84533572bb93751f75586087fda43193b
+ size 134
checkpoints/AV_MossFormer2_TSE_16K/last_best_checkpoint_tmp.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:981fe2b4a3e912e10919a41e674606870b61b8cb00e4f15ca97984b0144fc61a
+ size 134
checkpoints/FRCRN_SE_16K/config.yaml ADDED
@@ -0,0 +1,33 @@
+ #!/bin/bash
+ mode: 'train'
+ use_cuda: 1 # 1 for True, 0 for False
+
+ sampling_rate: 16000
+ network: "FRCRN_SE_16K" ## network type
+ ## FFT parameters
+ win_type: hanning
+ win_len: 640
+ win_inc: 320
+ fft_len: 640
+
+ # Train
+ #tr_list: 'datasets/tr_tts_16k_noise_0to10db_p13_p20.lst_dur'
+ tr_list: 'data/cv_webrtc_test_set_20200521_16k.lst'
+ cv_list: 'data/cv_webrtc_test_set_20200521_16k.lst'
+ init_learning_rate: 0.001 # learning rate for a new training
+ finetune_learning_rate: 0.0001 # learning rate for fine-tuning
+ max_epoch: 100
+
+ weight_decay: 0.00001
+ clip_grad_norm: 10.
+
+ # Log
+ seed: 777
+
+ # dataset
+ num_workers: 4
+ batch_size: 4
+ accu_grad: 1 # accumulate gradients over multiple batches before one back-propagation update
+ effec_batch_size: 12 # per GPU, only used if accu_grad is set to 1; must be a multiple of batch_size
+ max_length: 1 # truncate utterances in the dataloader, in seconds
+
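For orientation, the FFT parameters above are given in samples; at the 16 kHz sampling rate they correspond to a 40 ms analysis window with a 20 ms hop (50% overlap). A quick check in Python:

sampling_rate = 16000
win_len, win_inc, fft_len = 640, 320, 640    # values from the config above

print(win_len / sampling_rate * 1000)   # 40.0 -> 40 ms analysis window
print(win_inc / sampling_rate * 1000)   # 20.0 -> 20 ms hop
print(fft_len // 2 + 1)                 # 321 frequency bins per STFT frame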
checkpoints/FRCRN_SE_16K/last_best_checkpoint ADDED
@@ -0,0 +1 @@
+ model.ckpt-88-8491630.pt
checkpoints/FRCRN_SE_16K/last_checkpoint ADDED
@@ -0,0 +1 @@
+ model.ckpt-88-8491630.pt
checkpoints/FRCRN_SE_16K/model.ckpt-88-8491630.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b22256adbb91b68cf5a3db8f6657a4fb17066eecd5f069803e59c186c1cf3ebb
+ size 161053751
clearvoice.py ADDED
@@ -0,0 +1,62 @@
+ from network_wrapper import network_wrapper
+
+ class ClearVoice:
+     """ The main interface class exposed to end users for performing speech processing.
+     This class provides the desired models to perform the given task.
+     """
+     def __init__(self, task, model_names):
+         """ Load the desired models for the specified task. All given models are run and all results are returned.
+
+         Parameters:
+         ----------
+         task: str
+             the task, matching one of the provided tasks:
+             'speech_enhancement'
+             'speech_separation'
+             'target_speaker_extraction'
+         model_names: str or list of str
+             the model names, matching any of the provided models:
+             'FRCRN_SE_16K'
+             'MossFormer2_SE_48K'
+             'MossFormerGAN_SE_16K'
+             'MossFormer2_SS_16K'
+             'AV_MossFormer2_TSE_16K'
+
+         Returns:
+         --------
+         A ClearVoice object holding the loaded models, which can be called to get the desired results
+         """
+         self.network_wrapper = network_wrapper()
+         self.models = []
+         for model_name in model_names:
+             model = self.network_wrapper(task, model_name)
+             self.models += [model]
+
+     def __call__(self, input_path, online_write=False, output_path=None):
+         results = {}
+         for model in self.models:
+             result = model.process(input_path, online_write, output_path)
+             if not online_write:
+                 results[model.name] = result
+
+         if not online_write:
+             if len(results) == 1:
+                 return results[model.name]
+             else:
+                 return results
+
+     def write(self, results, output_path):
+         add_subdir = False
+         use_key = False
+         if len(self.models) > 1: add_subdir = True  # multi_model is True
+         for model in self.models:
+             if isinstance(results, dict):
+                 if model.name in results:
+                     if len(results[model.name]) > 1: use_key = True
+
+             else:
+                 if len(results) > 1: use_key = True  # multi_input is True
+             break
+
+         for model in self.models:
+             model.write(output_path, add_subdir, use_key)
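A minimal usage sketch of the interface defined above, mirroring the demo scripts added later in this commit (the paths are placeholders):

from clearvoice import ClearVoice

# Load one speech enhancement model; model_names is expected to be a list.
myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])

# Either get the enhanced waveform back and write it explicitly ...
output_wav = myClearVoice(input_path='input.wav', online_write=False)
myClearVoice.write(output_wav, output_path='output.wav')

# ... or let the models write results directly while processing a whole directory.
myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')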
config/inference/AV_MossFormer2_TSE_16K.yaml ADDED
@@ -0,0 +1,41 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 16000
+ network: "AV_MossFormer2_TSE_16K" # network type
+ checkpoint_dir: "checkpoints/AV_MossFormer2_TSE_16K"
+
+ input_path: "scp/video_samples.scp" # an input dir or input scp file
+ output_dir: "path_to_output_videos_tse" # output dir to store processed audio
+
+ # decode parameters
+ one_time_decode_length: 3 # maximum segment length for one-pass decoding (seconds); longer audio will use segmented decoding
+ decode_window: 3 # one-pass decoding length (seconds)
+
+
+ # Model-specific settings for target speaker extraction
+ network_reference:
+   cue: lip
+   backbone: resnet18
+   emb_size: 256
+ network_audio:
+   backbone: mossformer2
+   encoder_kernel_size: 16
+   encoder_out_nchannels: 512
+   encoder_in_nchannels: 1
+
+   masknet_numspks: 1
+   masknet_chunksize: 250
+   masknet_numlayers: 1
+   masknet_norm: "ln"
+   masknet_useextralinearlayer: False
+   masknet_extraskipconnection: True
+
+   intra_numlayers: 24
+   intra_nhead: 8
+   intra_dffn: 1024
+   intra_dropout: 0
+   intra_use_positional: True
+   intra_norm_before: True
+
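The one_time_decode_length and decode_window entries above control chunked inference for long inputs; the actual decoder lives in the model code, which is outside this 50-file view. A simplified sketch of the general idea only (no overlap or cross-fading, and process_chunk is a hypothetical stand-in for the model's one-pass inference):

import numpy as np

def segmented_decode(wav, sampling_rate, decode_window, process_chunk):
    # Split a long waveform into fixed-length windows, process each window,
    # and concatenate the outputs.
    hop = int(decode_window * sampling_rate)
    out = [process_chunk(wav[start:start + hop]) for start in range(0, len(wav), hop)]
    return np.concatenate(out)

# Example with the settings above: 3-second windows at 16 kHz on a 10-second dummy signal.
wav = np.random.randn(16000 * 10).astype(np.float32)
processed = segmented_decode(wav, 16000, decode_window=3, process_chunk=lambda x: x)
print(processed.shape)  # (160000,)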
config/inference/FRCRN_SE_16K.yaml ADDED
@@ -0,0 +1,20 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 16000
+ network: "FRCRN_SE_16K" ## network type
+ checkpoint_dir: "checkpoints/FRCRN_SE_16K"
+ #input_path: "data/cv_webrtc_test_set_20200521_16k.scp" # an input dir or input scp file
+ input_path: "/home/shengkui.zhao/DingTalk_NS/data/webrtc_test_set_20200521_16k/noisy"
+ output_dir: "outputs/FRCRN_SE_16K" ## output dir to store processed audio
+
+ # decode parameters
+ one_time_decode_length: 120 # maximum segment length for one-pass decoding (seconds), longer audio will use segmented decoding
+ decode_window: 1 # one-pass decoding length
+ #
+ # FFT parameters
+ win_type: 'hanning'
+ win_len: 640
+ win_inc: 320
+ fft_len: 640
config/inference/MossFormer2_SE_48K.yaml ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 48000
+ network: "MossFormer2_SE_48K" ## network type
+ checkpoint_dir: "checkpoints/MossFormer2_SE_48K"
+
+ #input_path: supports a wav dir, a wav scp file, or a single wav file
+ input_path: "/mnt/nas/mit_sg/shengkui.zhao/DNS-Challenge/datasets/test_set/synthetic/no_reverb/noisy"
+ output_dir: "outputs/MossFormer2_SE_48K_dns_2020_noreverb"
+
+ # decode parameters
+ one_time_decode_length: 20 # maximum segment length for one-pass decoding (seconds), longer audio will use segmented decoding
+ decode_window: 4 # one-pass decoding length
+
+ # FFT parameters
+ win_type: 'hamming'
+ win_len: 1920
+ win_inc: 384
+ fft_len: 1920
+ num_mels: 60
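The FFT and mel settings above (1920-sample window, 384-sample hop, 60 mel bins at 48 kHz) are consumed as filterbank parameters: the Fbank_Processor in dataloader/dataloader.py, added later in this diff, converts them to milliseconds for torchaudio's Kaldi-compatible fbank. A small sketch of that conversion on dummy audio:

import torch
import torchaudio

sampling_rate, win_len, win_inc, num_mels, win_type = 48000, 1920, 384, 60, 'hamming'

fbank_config = {
    "dither": 1.0,
    "frame_length": int(win_len / sampling_rate * 1000),  # 40 ms
    "frame_shift": int(win_inc / sampling_rate * 1000),   # 8 ms
    "num_mel_bins": num_mels,
    "sample_frequency": sampling_rate,
    "window_type": win_type,
}

wav = torch.randn(1, 48000)  # one second of dummy audio, shape (channels, samples)
fbank = torchaudio.compliance.kaldi.fbank(wav, **fbank_config)
print(fbank.shape)  # roughly (number of frames, 60)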
config/inference/MossFormer2_SS_16K.yaml ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 16000
+ network: "MossFormer2_SS_16K" ## network type
+ checkpoint_dir: "checkpoints/MossFormer2_SS_16K"
+ input_path: "data/wsj0_2mix_16k_fullpath.lst" # an input dir or input scp file
+ #input_path: "/home/shengkui.zhao/DingTalk_NS/data/webrtc_test_set_20200521_16k/noisy"
+ #input_path: "/mnt/nas_sg/mit_sg/shengkui.zhao/ComplexNN/audio/youtube_testset_16k/noisy_long/noisy"
+ output_dir: "outputs/MossFormer2_SS_16K_wsj0_2mix" ## output dir to store processed audio
+
+ # decode parameters
+ one_time_decode_length: 30 # maximum segment length for one-pass decoding (seconds), longer audio will use segmented decoding
+ decode_window: 10 # one-pass decoding length
+
+ num_spks: 2
+ encoder_kernel_size: 16
+ encoder_embedding_dim: 512
+ mossformer_sequence_dim: 512
+ num_mossformer_layer: 24
config/inference/MossFormerGAN_SE_16K.yaml ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 16000
+ network: "MossFormerGAN_SE_16K" ## network type
+ checkpoint_dir: "checkpoints/MossFormerGAN_SE_16K"
+
+ #input_path: "data/cv_webrtc_test_set_20200521_16k.scp" # an input dir or input scp file
+ #input_path: "/home/shengkui.zhao/DingTalk_NS/data/webrtc_test_set_20200521_16k/noisy"
+ input_path: "/mnt/nas_sg/mit_sg/shengkui.zhao/ComplexNN/audio/youtube_testset_16k/noisy_long/noisy"
+ output_dir: "outputs/MossFormerGAN_SE_16K" ## output dir to store processed audio
+
+ # decode parameters
+ one_time_decode_length: 10 # maximum segment length for one-pass decoding (seconds), longer audio will use segmented decoding
+ decode_window: 10 # one-pass decoding length
+
+ # FFT parameters
+ win_type: 'hamming'
+ win_len: 400
+ win_inc: 100
+ fft_len: 400
config/inference/SpEx_plus_TSE_8K.yaml ADDED
@@ -0,0 +1,18 @@
+ #!/bin/bash
+ mode: 'inference'
+ use_cuda: 1 # 1 for True, 0 for False
+ num_gpu: 1
+ sampling_rate: 8000
+ network: "SpEx_plus_TSE_8K" ## network type
+ checkpoint_dir: "checkpoints/SpEx_plus_TSE_8K"
+ input_path: "data/wsj0_2mix_16k_fullpath.lst" # an input dir or input scp file
+ #input_path: "/home/shengkui.zhao/DingTalk_NS/data/webrtc_test_set_20200521_16k/noisy"
+ #input_path: "/mnt/nas_sg/mit_sg/shengkui.zhao/ComplexNN/audio/youtube_testset_16k/noisy_long/noisy"
+ output_dir: "outputs/MossFormer2_SS_16K_wsj0_2mix" ## output dir to store processed audio
+
+ # decode parameters
+ one_time_decode_length: 5 # maximum segment length for one-pass decoding (seconds), longer audio will use segmented decoding
+ decode_window: 1 # one-pass decoding length
+
+
+
dataloader/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (14.2 kB)

dataloader/__pycache__/misc.cpython-38.pyc ADDED
Binary file (2.15 kB)
 
dataloader/dataloader.py ADDED
@@ -0,0 +1,496 @@
+ import numpy as np
+ import math, os, csv
+ import torchaudio
+ import torch
+ import torch.nn as nn
+ import torch.utils.data as data
+ import torch.distributed as dist
+ import soundfile as sf
+ from torch.utils.data import Dataset
+ import torch.utils.data as data
+ import os
+ import sys
+ sys.path.append(os.path.dirname(__file__))
+
+ from dataloader.misc import read_and_config_file
+ import librosa
+ import random
+ EPS = 1e-6
+ MAX_WAV_VALUE = 32768.0
+
+ def audioread(path, sampling_rate):
+     """
+     Reads an audio file from the specified path, normalizes the audio,
+     resamples it to the desired sampling rate (if necessary), and ensures it is single-channel.
+
+     Parameters:
+     path (str): The file path of the audio file to be read.
+     sampling_rate (int): The target sampling rate for the audio.
+
+     Returns:
+     numpy.ndarray: The processed audio data, normalized, resampled (if necessary),
+                    and converted to mono (if the input audio has multiple channels).
+     """
+
+     # Read audio data and its sample rate from the file.
+     data, fs = sf.read(path)
+
+     # Normalize the audio data.
+     data = audio_norm(data)
+
+     # Resample the audio if the sample rate is different from the target sampling rate.
+     if fs != sampling_rate:
+         data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)
+
+     # Convert to mono by selecting the first channel if the audio has multiple channels.
+     if len(data.shape) > 1:
+         data = data[:, 0]
+
+     # Return the processed audio data.
+     return data
+
+ def audio_norm(x):
+     """
+     Normalizes the input audio signal to a target Root Mean Square (RMS) level,
+     applying two stages of scaling. This ensures the audio signal is neither too quiet
+     nor too loud, keeping its amplitude consistent.
+
+     Parameters:
+     x (numpy.ndarray): Input audio signal to be normalized.
+
+     Returns:
+     numpy.ndarray: Normalized audio signal.
+     """
+
+     # Compute the root mean square (RMS) of the input audio signal.
+     rms = (x ** 2).mean() ** 0.5
+
+     # Calculate the scalar to adjust the signal to the target level (-25 dB).
+     scalar = 10 ** (-25 / 20) / (rms + EPS)
+
+     # Scale the input audio by the computed scalar.
+     x = x * scalar
+
+     # Compute the power of the scaled audio signal.
+     pow_x = x ** 2
+
+     # Calculate the average power of the audio signal.
+     avg_pow_x = pow_x.mean()
+
+     # Compute RMS only for audio segments with higher-than-average power.
+     rmsx = pow_x[pow_x > avg_pow_x].mean() ** 0.5
+
+     # Calculate another scalar to further normalize based on higher-power segments.
+     scalarx = 10 ** (-25 / 20) / (rmsx + EPS)
+
+     # Apply the second scalar to the audio.
+     x = x * scalarx
+
+     # Return the doubly normalized audio signal.
+     return x
+
+ class DataReader(object):
+     """
+     A class for reading audio data from a list of files, normalizing it,
+     and extracting features for further processing. It supports extracting
+     features from each file, reshaping the data, and returning metadata
+     like utterance ID and data length.
+
+     Parameters:
+     args: Arguments containing the input path and target sampling rate.
+
+     Attributes:
+     file_list (list): A list of audio file paths to process.
+     sampling_rate (int): The target sampling rate for audio files.
+     """
+
+     def __init__(self, args):
+         # Read and configure the file list from the input path provided in the arguments.
+         # The file list is decoded, if necessary.
+         self.file_list = read_and_config_file(args, args.input_path, decode=True)
+
+         # Store the target sampling rate.
+         self.sampling_rate = args.sampling_rate
+
+         # Store the args object.
+         self.args = args
+
+     def __len__(self):
+         """
+         Returns the number of audio files in the file list.
+
+         Returns:
+         int: Number of files to process.
+         """
+         return len(self.file_list)
+
+     def __getitem__(self, index):
+         """
+         Retrieves the features of the audio file at the given index.
+
+         Parameters:
+         index (int): Index of the file in the file list.
+
+         Returns:
+         tuple: Features (inputs, utterance ID, data length) for the selected audio file.
+         """
+         if self.args.task == 'target_speaker_extraction':
+             if self.args.network_reference.cue == 'lip':
+                 return self.file_list[index]
+         return self.extract_feature(self.file_list[index])
+
+     def extract_feature(self, path):
+         """
+         Extracts features from the given audio file path.
+
+         Parameters:
+         path (str): The file path of the audio file.
+
+         Returns:
+         inputs (numpy.ndarray): Reshaped audio data for further processing.
+         utt_id (str): The unique identifier of the audio file, usually the filename.
+         length (int): The length of the original audio data.
+         """
+         # Extract the utterance ID from the file path (usually the filename).
+         utt_id = path.split('/')[-1]
+
+         # Read and normalize the audio data, converting it to float32 for processing.
+         data = audioread(path, self.sampling_rate).astype(np.float32)
+
+         # Reshape the data to ensure it's in the format [1, data_length].
+         inputs = np.reshape(data, [1, data.shape[0]])
+
+         # Return the reshaped audio data, utterance ID, and the length of the original data.
+         return inputs, utt_id, data.shape[0]
+
+ class Wave_Processor(object):
+     """
+     A class for processing audio data, specifically for reading input and label audio files,
+     segmenting them into fixed-length segments, and applying padding or trimming as necessary.
+
+     Methods:
+     process(path, segment_length, sampling_rate):
+         Processes audio data by reading, padding, or segmenting it to match the specified segment length.
+
+     Parameters:
+     path (dict): A dictionary containing file paths for 'inputs' and 'labels' audio files.
+     segment_length (int): The desired length of audio segments to extract.
+     sampling_rate (int): The target sampling rate for reading the audio files.
+     """
+
+     def process(self, path, segment_length, sampling_rate):
+         """
+         Reads input and label audio files, and ensures the audio is segmented into
+         the desired length, padding if necessary or extracting random segments if
+         the audio is longer than the target segment length.
+
+         Parameters:
+         path (dict): Dictionary containing the paths to 'inputs' and 'labels' audio files.
+         segment_length (int): Desired length of the audio segment in samples.
+         sampling_rate (int): Target sample rate for the audio.
+
+         Returns:
+         tuple: A pair of numpy arrays representing the processed input and label audio,
+                either padded to the segment length or trimmed.
+         """
+         # Read the input and label audio files using the target sampling rate.
+         wave_inputs = audioread(path['inputs'], sampling_rate)
+         wave_labels = audioread(path['labels'], sampling_rate)
+
+         # Get the length of the label audio (assumed both inputs and labels have similar lengths).
+         len_wav = wave_labels.shape[0]
+
+         # If the input audio is shorter than the desired segment length, pad it with zeros.
+         if wave_inputs.shape[0] < segment_length:
+             # Create zero-padded arrays for inputs and labels.
+             padded_inputs = np.zeros(segment_length, dtype=np.float32)
+             padded_labels = np.zeros(segment_length, dtype=np.float32)
+
+             # Copy the original audio into the padded arrays.
+             padded_inputs[:wave_inputs.shape[0]] = wave_inputs
+             padded_labels[:wave_labels.shape[0]] = wave_labels
+         else:
+             # Randomly select a start index for segmenting the audio if it's longer than the segment length.
+             st_idx = random.randint(0, len_wav - segment_length)
+
+             # Extract a segment of the desired length from the inputs and labels.
+             padded_inputs = wave_inputs[st_idx:st_idx + segment_length]
+             padded_labels = wave_labels[st_idx:st_idx + segment_length]
+
+         # Return the processed (padded or segmented) input and label audio.
+         return padded_inputs, padded_labels
+
+ class Fbank_Processor(object):
+     """
+     A class for processing input audio data into mel-filterbank (Fbank) features,
+     including the computation of delta and delta-delta features.
+
+     Methods:
+     process(inputs, args):
+         Processes the raw audio input and returns the mel-filterbank features
+         along with delta and delta-delta features.
+     """
+
+     def process(self, inputs, args):
+         # Convert frame length and shift from samples to milliseconds.
+         frame_length = int(args.win_len / args.sampling_rate * 1000)
+         frame_shift = int(args.win_inc / args.sampling_rate * 1000)
+
+         # Set up configuration for the mel-filterbank computation.
+         fbank_config = {
+             "dither": 1.0,
+             "frame_length": frame_length,
+             "frame_shift": frame_shift,
+             "num_mel_bins": args.num_mels,
+             "sample_frequency": args.sampling_rate,
+             "window_type": args.win_type
+         }
+
+         # Convert the input audio to a FloatTensor and scale it to match the expected input range.
+         inputs = torch.FloatTensor(inputs * MAX_WAV_VALUE)
+
+         # Compute the mel-filterbank features using Kaldi's fbank function.
+         fbank = torchaudio.compliance.kaldi.fbank(inputs.unsqueeze(0), **fbank_config)
+
+         # Add delta and delta-delta features.
+         fbank_tr = torch.transpose(fbank, 0, 1)
+         fbank_delta = torchaudio.functional.compute_deltas(fbank_tr)
+         fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta)
+         fbank_delta = torch.transpose(fbank_delta, 0, 1)
+         fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1)
+
+         # Concatenate the original Fbank, delta, and delta-delta features.
+         fbanks = torch.cat([fbank, fbank_delta, fbank_delta_delta], dim=1)
+
+         return fbanks.numpy()
+
+ class AudioDataset(Dataset):
+     """
+     A dataset class for loading and processing audio data from different data types
+     (train, validation, test). Supports audio processing and feature extraction
+     (e.g., waveform processing, Fbank feature extraction).
+
+     Parameters:
+     args: Arguments containing dataset configuration (paths, sampling rate, etc.).
+     data_type (str): The type of data to load (train, val, test).
+     """
+
+     def __init__(self, args, data_type):
+         self.args = args
+         self.sampling_rate = args.sampling_rate
+
+         # Read the list of audio files based on the data type.
+         if data_type == 'train':
+             self.wav_list = read_and_config_file(args.tr_list)
+         elif data_type == 'val':
+             self.wav_list = read_and_config_file(args.cv_list)
+         elif data_type == 'test':
+             self.wav_list = read_and_config_file(args.tt_list)
+         else:
+             print(f'Data type: {data_type} is unknown!')
+
+         # Initialize processors for waveform and Fbank features.
+         self.wav_processor = Wave_Processor()
+         self.fbank_processor = Fbank_Processor()
+
+         # Clip data to a fixed segment length based on the sampling rate and max length.
+         self.segment_length = self.sampling_rate * self.args.max_length
+         print(f'No. {data_type} files: {len(self.wav_list)}')
+
+     def __len__(self):
+         # Return the number of audio files in the dataset.
+         return len(self.wav_list)
+
+     def __getitem__(self, index):
+         # Get the input and label paths from the list.
+         data_info = self.wav_list[index]
+
+         # Process the waveform inputs and labels.
+         inputs, labels = self.wav_processor.process(
+             {'inputs': data_info['inputs'], 'labels': data_info['labels']},
+             self.segment_length,
+             self.sampling_rate
+         )
+
+         # Optionally load Fbank features if specified.
+         if self.args.load_fbank is not None:
+             fbanks = self.fbank_processor.process(inputs, self.args)
+             return inputs * MAX_WAV_VALUE, labels * MAX_WAV_VALUE, fbanks
+
+         return inputs, labels
+
+     def zero_pad_concat(self, inputs):
+         """
+         Concatenates a list of input arrays, applying zero-padding as needed to ensure
+         they all match the length of the longest input.
+
+         Parameters:
+         inputs (list of numpy arrays): List of input arrays to be concatenated.
+
+         Returns:
+         numpy.ndarray: A zero-padded array with concatenated inputs.
+         """
+
+         # Get the maximum length among all inputs.
+         max_t = max(inp.shape[0] for inp in inputs)
+
+         # Determine the shape of the output based on the input dimensions.
+         shape = None
+         if len(inputs[0].shape) == 1:
+             shape = (len(inputs), max_t)
+         elif len(inputs[0].shape) == 2:
+             shape = (len(inputs), max_t, inputs[0].shape[1])
+
+         # Initialize an array with zeros to hold the concatenated inputs.
+         input_mat = np.zeros(shape, dtype=np.float32)
+
+         # Copy the input data into the zero-padded array.
+         for e, inp in enumerate(inputs):
+             if len(inp.shape) == 1:
+                 input_mat[e, :inp.shape[0]] = inp
+             elif len(inp.shape) == 2:
+                 input_mat[e, :inp.shape[0], :] = inp
+
+         return input_mat
+
+ def collate_fn_2x_wavs(data):
+     """
+     A custom collate function for combining batches of waveform input and label pairs.
+
+     Parameters:
+     data (list): List of tuples (inputs, labels).
+
+     Returns:
+     tuple: Batched inputs and labels as torch.FloatTensors.
+     """
+     inputs, labels = zip(*data)
+     x = torch.FloatTensor(inputs)
+     y = torch.FloatTensor(labels)
+     return x, y
+
+ def collate_fn_2x_wavs_fbank(data):
+     """
+     A custom collate function for combining batches of waveform inputs, labels, and Fbank features.
+
+     Parameters:
+     data (list): List of tuples (inputs, labels, fbanks).
+
+     Returns:
+     tuple: Batched inputs, labels, and Fbank features as torch.FloatTensors.
+     """
+     inputs, labels, fbanks = zip(*data)
+     x = torch.FloatTensor(inputs)
+     y = torch.FloatTensor(labels)
+     z = torch.FloatTensor(fbanks)
+     return x, y, z
+
+ class DistributedSampler(data.Sampler):
+     """
+     Sampler for distributed training. Divides the dataset among multiple replicas (processes),
+     ensuring that each process gets a unique subset of the data. It also supports shuffling
+     and managing epochs.
+
+     Parameters:
+     dataset (Dataset): The dataset to sample from.
+     num_replicas (int): Number of processes participating in the training.
+     rank (int): Rank of the current process.
+     shuffle (bool): Whether to shuffle the data or not.
+     seed (int): Random seed for reproducibility.
+     """
+
+     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+         self.total_size = self.num_samples * self.num_replicas
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def __iter__(self):
+         # Shuffle the indices based on the epoch and seed.
+         if self.shuffle:
+             g = torch.Generator()
+             g.manual_seed(self.seed + self.epoch)
+             ind = torch.randperm(int(len(self.dataset) / self.num_replicas), generator=g) * self.num_replicas
+             indices = []
+             for i in range(self.num_replicas):
+                 indices = indices + (ind + i).tolist()
+         else:
+             indices = list(range(len(self.dataset)))
+
+         # Add extra samples to make the dataset evenly divisible.
+         indices += indices[:(self.total_size - len(indices))]
+         assert len(indices) == self.total_size
+
+         # Subsample for the current process.
+         indices = indices[self.rank * self.num_samples:(self.rank + 1) * self.num_samples]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+ def get_dataloader(args, data_type):
+     """
+     Creates and returns a data loader and sampler for the specified dataset type (train, validation, or test).
+
+     Parameters:
+     args (Namespace): Configuration arguments containing details such as batch size, sampling rate,
+                       network type, and whether distributed training is used.
+     data_type (str): The type of dataset to load ('train', 'val', 'test').
+
+     Returns:
+     sampler (DistributedSampler or None): The sampler for distributed training, or None if not used.
+     generator (DataLoader): The PyTorch DataLoader for the specified dataset.
+     """
+
+     # Initialize the dataset based on the given arguments and dataset type (train, val, or test).
+     datasets = AudioDataset(args=args, data_type=data_type)
+
+     # Create a distributed sampler if distributed training is enabled; otherwise, use no sampler.
+     sampler = DistributedSampler(
+         datasets,
+         num_replicas=args.world_size,  # Number of replicas in distributed training.
+         rank=args.local_rank  # Rank of the current process.
+     ) if args.distributed else None
+
+     # Select the appropriate collate function based on the network type.
+     if args.network == 'FRCRN_SE_16K' or args.network == 'MossFormerGAN_SE_16K':
+         # Use the collate function for waveform (inputs and labels) pairs.
+         collate_fn = collate_fn_2x_wavs
+     elif args.network == 'MossFormer2_SE_48K':
+         # Use the collate function for waveforms along with Fbank features.
+         collate_fn = collate_fn_2x_wavs_fbank
+     else:
+         # Print an error message if the network type is unknown.
+         print('in dataloader, please specify a correct network type using args.network!')
+         return
+
+     # Create a DataLoader with the specified dataset, batch size, and worker configuration.
+     generator = data.DataLoader(
+         datasets,
+         batch_size=args.batch_size,  # Batch size for training.
+         shuffle=(sampler is None),  # Shuffle the data only if no sampler is used.
+         collate_fn=collate_fn,  # Use the selected collate function for batching data.
+         num_workers=args.num_workers,  # Number of workers for data loading.
+         sampler=sampler  # Use the distributed sampler if applicable.
+     )
+
+     # Return both the sampler and DataLoader (generator).
+     return sampler, generator
+
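For reference, the -25 dB target used twice in audio_norm above corresponds to a gain factor of 10 ** (-25 / 20) ≈ 0.056. A standalone check of the first-stage scaling on a dummy signal (mirroring the function above, not a new API):

import numpy as np

EPS = 1e-6
x = np.random.randn(16000).astype(np.float32)   # dummy 1-second signal at 16 kHz

rms = (x ** 2).mean() ** 0.5                    # RMS of the raw signal
scalar = 10 ** (-25 / 20) / (rms + EPS)         # gain that moves the RMS to the -25 dB target
y = x * scalar

print(round(10 ** (-25 / 20), 4))               # 0.0562, the target RMS
print(round(float((y ** 2).mean() ** 0.5), 4))  # ~0.0562 after scaling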
dataloader/misc.py ADDED
@@ -0,0 +1,84 @@
+
+ #!/usr/bin/env python -u
+ # -*- coding: utf-8 -*-
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import os
+ import sys
+ import librosa
+
+ def read_and_config_file(args, input_path, decode=0):
+     """
+     Reads and processes the input file or directory to extract audio file paths or configuration data.
+
+     Parameters:
+     args: the configuration arguments (used for task- and cue-specific handling)
+     input_path (str): Path to a file or directory containing audio data or file paths.
+     decode (bool): If True (decode=1), for decoding: process the input as audio files directly (find .wav or .flac files) or from a .scp file.
+                    If False (decode=0), for training: assume the input file contains lines with paths to audio files.
+
+     Returns:
+     processed_list (list): A list of processed file paths or a list of dictionaries containing input
+                            and optional condition audio paths.
+     """
+     processed_list = []  # Initialize list to hold processed file paths or configurations
+
+     if decode:
+         if args.task == 'target_speaker_extraction':
+             if args.network_reference.cue == 'lip':
+                 # If decode is True, find video files in a directory or single file
+                 if os.path.isdir(input_path):
+                     # Find all .mp4, .avi, .mov files in the input directory
+                     processed_list = librosa.util.find_files(input_path, ext="mp4")
+                     processed_list += librosa.util.find_files(input_path, ext="avi")
+                     processed_list += librosa.util.find_files(input_path, ext="mov")
+                     processed_list += librosa.util.find_files(input_path, ext="MOV")
+                 else:
+                     # If it's a single video file (.mp4, .avi or .mov), add it to the processed list
+                     if input_path.lower().endswith(".mp4") or input_path.lower().endswith(".avi") or input_path.lower().endswith(".mov"):
+                         processed_list.append(input_path)
+                     else:
+                         # Read file paths from the input text file (one path per line)
+                         with open(input_path) as fid:
+                             for line in fid:
+                                 path_s = line.strip().split()  # Split paths (space-separated)
+                                 processed_list.append(path_s[0])  # Add the first path (input video path)
+                 return processed_list
+
+         # If decode is True, find audio files in a directory or single file
+         if os.path.isdir(input_path):
+             # Find all .wav files in the input directory
+             processed_list = librosa.util.find_files(input_path, ext="wav")
+             if len(processed_list) == 0:
+                 # If no .wav files, look for .flac files
+                 processed_list = librosa.util.find_files(input_path, ext="flac")
+         else:
+             # If it's a single file and it's a .wav or .flac, add to processed list
+             if input_path.lower().endswith(".wav") or input_path.lower().endswith(".flac"):
+                 processed_list.append(input_path)
+             else:
+                 # Read file paths from the input text file (one path per line)
+                 with open(input_path) as fid:
+                     for line in fid:
+                         path_s = line.strip().split()  # Split paths (space-separated)
+                         processed_list.append(path_s[0])  # Add the first path (input audio path)
+         return processed_list
+
+     # If decode is False, treat the input file as a configuration file
+     with open(input_path) as fid:
+         for line in fid:
+             tmp_paths = line.strip().split()  # Split paths (space-separated)
+             if len(tmp_paths) == 2:
+                 # If two paths per line, treat the second as 'condition_audio'
+                 sample = {'inputs': tmp_paths[0], 'condition_audio': tmp_paths[1]}
+             elif len(tmp_paths) == 1:
+                 # If only one path per line, treat it as 'inputs'
+                 sample = {'inputs': tmp_paths[0]}
+             processed_list.append(sample)  # Append processed sample to list
+     return processed_list
+
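A small sketch of the two list formats read_and_config_file accepts (all paths below are placeholders): in training mode (decode=0) each line holds an input path optionally followed by a condition-audio path, while in decode mode (decode=1) a .scp file holds one path per line, as in the .scp files referenced by the inference configs above. Assuming this is run from the repository root:

# Training-style list (decode=0): "<inputs> [<condition_audio>]" per line.
with open('example_train.lst', 'w') as f:
    f.write('data/noisy/utt1.wav data/clean/utt1.wav\n')
    f.write('data/noisy/utt2.wav data/clean/utt2.wav\n')

from dataloader.misc import read_and_config_file
print(read_and_config_file(None, 'example_train.lst', decode=0))
# [{'inputs': 'data/noisy/utt1.wav', 'condition_audio': 'data/clean/utt1.wav'},
#  {'inputs': 'data/noisy/utt2.wav', 'condition_audio': 'data/clean/utt2.wav'}]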
demo.py ADDED
@@ -0,0 +1,70 @@
+ from clearvoice import ClearVoice
+
+ ##----------------- demo one: use one model ----------------------------------
+ if False:
+     myClearVoice = ClearVoice(task='speech_enhancement', model_names=['MossFormer2_SE_48K'])
+
+     ##1st calling method: process an input waveform and return output waveform, then write to output.wav
+     #output_wav = myClearVoice(input_path='input.wav', online_write=False)
+     #myClearVoice.write(output_wav, output_path='output.wav')
+
+     ##2nd calling method: process all wav files in 'path_to_input_wavs/' and write outputs to 'path_to_output_wavs'
+     myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')
+
+     ##3rd calling method: process wav files listed in .scp file, and write outputs to 'path_to_output_waves/'
+     #myClearVoice(input_path='scp/cv_webrtc_test_set_20200521_16k.scp', online_write=True, output_path='path_to_output_scp')
+
+
+ ##---------------- Demo two: use multiple models -----------------------------------
+ if False:
+     myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K']) #, 'MossFormerGAN_SE_16K'])
+
+     ##1st calling method: process the waveform from input.wav and return output waveform, then write to output.wav
+     #output_wav = myClearVoice(input_path='input.wav', online_write=False)
+     #myClearVoice.write(output_wav, output_path='output.wav')
+
+     ##2nd calling method: process all wav files in 'path_to_input_wavs/' and write outputs to 'path_to_output_wavs'
+     myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')
+
+     ##3rd calling method: process wav files listed in .scp file, and write outputs to 'path_to_output_waves/'
+     #myClearVoice(input_path='scp/cv_webrtc_test_set_20200521_16k.scp', online_write=True, output_path='path_to_output_scp')
+
+ if False:
+     myClearVoice = ClearVoice(task='speech_enhancement', model_names=['MossFormerGAN_SE_16K'])
+
+     ##1st calling method: process the waveform from input.wav and return output waveform, then write to output.wav
+     #output_wav = myClearVoice(input_path='input.wav', online_write=False)
+     #myClearVoice.write(output_wav, output_path='output.wav')
+
+     ##2nd calling method: process all wav files in 'path_to_input_wavs/' and write outputs to 'path_to_output_wavs'
+     myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')
+
+     ##3rd calling method: process wav files listed in .scp file, and write outputs to 'path_to_output_waves/'
+     #myClearVoice(input_path='scp/cv_webrtc_test_set_20200521_16k.scp', online_write=True, output_path='path_to_output_scp')
+
+ ##---------------- Demo three: use one model for speech separation -----------------------------------
+ if True:
+     myClearVoice = ClearVoice(task='speech_separation', model_names=['MossFormer2_SS_16K'])
+
+     ##1st calling method: process an input waveform and return output waveform, then write to output.wav
+     #output_wav = myClearVoice(input_path='input.wav', online_write=False)
+     #myClearVoice.write(output_wav, output_path='output.wav')
+
+     #2nd calling method: process all wav files in 'path_to_input_wavs/' and write outputs to 'path_to_output_wavs'
+     #myClearVoice(input_path='path_to_input_wavs_ss', online_write=True, output_path='path_to_output_wavs')
+
+     ##3rd calling method: process wav files listed in .scp file, and write outputs to 'path_to_output_waves/'
+     myClearVoice(input_path='scp/libri_2mix_tt.scp', online_write=True, output_path='path_to_output_scp')
+
+ ##---------------- Demo four: use one model for audio-visual target speaker extraction -----------------------------------
+ if False:
+     myClearVoice = ClearVoice(task='target_speaker_extraction', model_names=['AV_MossFormer2_TSE_16K'])
+
+     # #1st calling method: process an input video and return output video, then write outputs to 'path_to_output_videos_tse'
+     # output_wav = myClearVoice(input_path='path_to_input_videos_tse/004.MOV', online_write=True, output_path='path_to_output_videos_tse')
+
+     #2nd calling method: process all video files in 'path_to_input_videos/' and write outputs to 'path_to_output_videos_tse'
+     myClearVoice(input_path='path_to_input_videos_tse', online_write=True, output_path='path_to_output_videos_tse')
+
+     # #3rd calling method: process video files listed in .scp file, and write outputs to 'path_to_output_videos_tse/'
+     # myClearVoice(input_path='scp/video_samples.scp', online_write=True, output_path='path_to_output_videos_tse')
demo_with_detailed_comments.py ADDED
@@ -0,0 +1,61 @@
+ from clearvoice import ClearVoice  # Import the ClearVoice class for speech processing tasks
+
+ if __name__ == '__main__':
+     ## ----------------- Demo One: Using a Single Model ----------------------
+     if True:  # This block demonstrates how to use a single model for speech enhancement
+         # Initialize ClearVoice for the task of speech enhancement using the MossFormerGAN_SE_16K model
+         myClearVoice = ClearVoice(task='speech_enhancement', model_names=['MossFormerGAN_SE_16K'])
+
+         # 1st calling method:
+         # Process an input waveform and return the enhanced output waveform
+         # - input_path: Path to the input noisy audio file (input.wav)
+         # - The returned value is the enhanced output waveform
+         output_wav = myClearVoice(input_path='input.wav')
+         # Write the processed waveform to an output file
+         # - output_wav: The enhanced waveform data
+         # - output_path: Path to save the enhanced audio file (output.wav)
+         myClearVoice.write(output_wav, output_path='output.wav')
+
+         # 2nd calling method:
+         # Process and write audio files directly
+         # - input_path: Directory of input noisy audio files
+         # - online_write=True: Enables writing the enhanced audio directly to files during processing
+         # - output_path: Directory where the enhanced audio files will be saved
+         myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')
+
+         # 3rd calling method:
+         # Use an .scp file to specify input audio paths
+         # - input_path: Path to an .scp file listing multiple audio file paths
+         # - online_write=True: Directly writes the enhanced output during processing
+         # - output_path: Directory to save the enhanced output files
+         myClearVoice(input_path='data/cv_webrtc_test_set_20200521_16k.scp', online_write=True, output_path='path_to_output_waves')
+
+
+     ## ---------------- Demo Two: Using Multiple Models -----------------------
+     if False:  # This block demonstrates using multiple models for speech enhancement
+         # Initialize ClearVoice for the task of speech enhancement using two models: FRCRN_SE_16K and MossFormerGAN_SE_16K
+         myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K', 'MossFormerGAN_SE_16K'])
+
+         # 1st calling method:
+         # Process an input waveform using the multiple models and return the enhanced output waveform
+         # - input_path: Path to the input noisy audio file (input.wav)
+         # - The returned value is the enhanced output waveform after being processed by the models
+         output_wav = myClearVoice(input_path='input.wav')
+         # Write the processed waveform to an output file
+         # - output_wav: The enhanced waveform data
+         # - output_path: Path to save the enhanced audio file (output.wav)
+         myClearVoice.write(output_wav, output_path='output.wav')
+
+         # 2nd calling method:
+         # Process and write audio files directly using multiple models
+         # - input_path: Directory of input noisy audio files
+         # - online_write=True: Enables writing the enhanced audio directly to files during processing
+         # - output_path: Directory where the enhanced audio files will be saved
+         myClearVoice(input_path='path_to_input_wavs', online_write=True, output_path='path_to_output_wavs')
+
+         # 3rd calling method:
+         # Use an .scp file to specify input audio paths for multiple models
+         # - input_path: Path to an .scp file listing multiple audio file paths
+         # - online_write=True: Directly writes the enhanced output during processing
+         # - output_path: Directory to save the enhanced output files
+         myClearVoice(input_path='data/cv_webrtc_test_set_20200521_16k.scp', online_write=True, output_path='path_to_output_waves')
input.wav ADDED
Binary file (76.8 kB)

models/.DS_Store ADDED
Binary file (6.15 kB)

models/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (309 Bytes)

models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (332 Bytes)

models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (314 Bytes)

models/__pycache__/complex_nn.cpython-36.pyc ADDED
Binary file (7.86 kB)

models/__pycache__/complex_nn.cpython-37.pyc ADDED
Binary file (7.76 kB)

models/__pycache__/complex_nn.cpython-38.pyc ADDED
Binary file (7.42 kB)

models/__pycache__/constant.cpython-36.pyc ADDED
Binary file (394 Bytes)

models/__pycache__/constant.cpython-37.pyc ADDED
Binary file (417 Bytes)

models/__pycache__/constant.cpython-38.pyc ADDED
Binary file (419 Bytes)

models/__pycache__/conv_stft.cpython-38.pyc ADDED
Binary file (5.05 kB)

models/__pycache__/criterion.cpython-36.pyc ADDED
Binary file (1.63 kB)

models/__pycache__/criterion.cpython-37.pyc ADDED
Binary file (1.63 kB)

models/__pycache__/criterion.cpython-38.pyc ADDED
Binary file (1.65 kB)

models/__pycache__/frcrn.cpython-38.pyc ADDED
Binary file (7.17 kB)

models/__pycache__/metric.cpython-36.pyc ADDED
Binary file (1.07 kB)

models/__pycache__/noisedataset.cpython-36.pyc ADDED
Binary file (2.82 kB)

models/__pycache__/noisedataset.cpython-37.pyc ADDED
Binary file (2.78 kB)

models/__pycache__/noisedataset.cpython-38.pyc ADDED
Binary file (2.79 kB)

models/__pycache__/phasen_dccrn.cpython-36.pyc ADDED
Binary file (14.4 kB)

models/__pycache__/phasen_dccrn.cpython-37.pyc ADDED
Binary file (14.2 kB)

models/__pycache__/phasen_dccrn.cpython-38.pyc ADDED
Binary file (14.1 kB)