Spaces: kgout / Running on Zero

kgout committed · Commit f0ca515 · verified · 1 Parent(s): abfd644

Upload 104 files

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. audiosr/__init__.py +2 -2
  2. audiosr/__main__.py +123 -123
  3. audiosr/clap/open_clip/__init__.py +25 -25
  4. audiosr/clap/open_clip/factory.py +276 -276
  5. audiosr/clap/open_clip/feature_fusion.py +192 -192
  6. audiosr/clap/open_clip/htsat.py +0 -0
  7. audiosr/clap/open_clip/loss.py +397 -397
  8. audiosr/clap/open_clip/model.py +931 -931
  9. audiosr/clap/open_clip/model_configs/HTSAT-base.json +22 -22
  10. audiosr/clap/open_clip/model_configs/HTSAT-large.json +22 -22
  11. audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +22 -22
  12. audiosr/clap/open_clip/model_configs/HTSAT-tiny.json +22 -22
  13. audiosr/clap/open_clip/model_configs/PANN-10.json +22 -22
  14. audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json +22 -22
  15. audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +22 -22
  16. audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +22 -22
  17. audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json +22 -22
  18. audiosr/clap/open_clip/model_configs/PANN-14.json +22 -22
  19. audiosr/clap/open_clip/model_configs/PANN-6.json +22 -22
  20. audiosr/clap/open_clip/model_configs/RN101-quickgelu.json +21 -21
  21. audiosr/clap/open_clip/model_configs/RN101.json +20 -20
  22. audiosr/clap/open_clip/model_configs/RN50-quickgelu.json +22 -22
  23. audiosr/clap/open_clip/model_configs/RN50.json +20 -20
  24. audiosr/clap/open_clip/model_configs/RN50x16.json +20 -20
  25. audiosr/clap/open_clip/model_configs/RN50x4.json +20 -20
  26. audiosr/clap/open_clip/model_configs/ViT-B-16.json +15 -15
  27. audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +16 -16
  28. audiosr/clap/open_clip/model_configs/ViT-B-32.json +15 -15
  29. audiosr/clap/open_clip/model_configs/ViT-L-14.json +15 -15
  30. audiosr/clap/open_clip/openai.py +156 -156
  31. audiosr/clap/open_clip/pann_model.py +697 -697
  32. audiosr/clap/open_clip/pretrained.py +167 -167
  33. audiosr/clap/open_clip/timm_model.py +112 -112
  34. audiosr/clap/open_clip/tokenizer.py +197 -197
  35. audiosr/clap/open_clip/transform.py +45 -45
  36. audiosr/clap/open_clip/utils.py +355 -355
  37. audiosr/clap/training/data.py +865 -865
  38. audiosr/clap/training/params.py +563 -563
  39. audiosr/hifigan/LICENSE +20 -20
  40. audiosr/hifigan/__init__.py +8 -8
  41. audiosr/hifigan/models.py +174 -174
  42. audiosr/hifigan/models_v2.py +395 -395
  43. audiosr/latent_diffusion/models/ddim.py +492 -492
  44. audiosr/latent_diffusion/models/ddpm.py +0 -0
  45. audiosr/latent_diffusion/models/plms.py +360 -360
  46. audiosr/latent_diffusion/modules/attention.py +467 -467
  47. audiosr/latent_diffusion/modules/audiomae/AudioMAE.py +149 -149
  48. audiosr/latent_diffusion/modules/audiomae/models_mae.py +613 -613
  49. audiosr/latent_diffusion/modules/audiomae/models_vit.py +243 -243
  50. audiosr/latent_diffusion/modules/audiomae/util/crop.py +43 -43
audiosr/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from .utils import seed_everything, save_wave, get_time, get_duration, read_list
2
- from .pipeline import *
 
1
+ from .utils import seed_everything, save_wave, get_time, get_duration, read_list
2
+ from .pipeline import *
audiosr/__main__.py CHANGED
@@ -1,123 +1,123 @@
1
- #!/usr/bin/python3
2
- import os
3
- import torch
4
- import logging
5
- from audiosr import super_resolution, build_model, save_wave, get_time, read_list
6
- import argparse
7
-
8
- os.environ["TOKENIZERS_PARALLELISM"] = "true"
9
- matplotlib_logger = logging.getLogger('matplotlib')
10
- matplotlib_logger.setLevel(logging.WARNING)
11
-
12
- parser = argparse.ArgumentParser()
13
-
14
- parser.add_argument(
15
- "-i",
16
- "--input_audio_file",
17
- type=str,
18
- required=False,
19
- help="Input audio file for audio super resolution",
20
- )
21
-
22
- parser.add_argument(
23
- "-il",
24
- "--input_file_list",
25
- type=str,
26
- required=False,
27
- default="",
28
- help="A file that contains all audio files that need to perform audio super resolution",
29
- )
30
-
31
- parser.add_argument(
32
- "-s",
33
- "--save_path",
34
- type=str,
35
- required=False,
36
- help="The path to save model output",
37
- default="./output",
38
- )
39
-
40
- parser.add_argument(
41
- "--model_name",
42
- type=str,
43
- required=False,
44
- help="The checkpoint you gonna use",
45
- default="basic",
46
- choices=["basic","speech"]
47
- )
48
-
49
- parser.add_argument(
50
- "-d",
51
- "--device",
52
- type=str,
53
- required=False,
54
- help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
55
- default="auto",
56
- )
57
-
58
- parser.add_argument(
59
- "--ddim_steps",
60
- type=int,
61
- required=False,
62
- default=50,
63
- help="The sampling step for DDIM",
64
- )
65
-
66
- parser.add_argument(
67
- "-gs",
68
- "--guidance_scale",
69
- type=float,
70
- required=False,
71
- default=3.5,
72
- help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
73
- )
74
-
75
- parser.add_argument(
76
- "--seed",
77
- type=int,
78
- required=False,
79
- default=42,
80
- help="Changing this value (any integer number) will lead to a different generation result.",
81
- )
82
-
83
- parser.add_argument(
84
- "--suffix",
85
- type=str,
86
- required=False,
87
- help="Suffix for the output file",
88
- default="_AudioSR_Processed_48K",
89
- )
90
-
91
- args = parser.parse_args()
92
- torch.set_float32_matmul_precision("high")
93
- save_path = os.path.join(args.save_path, get_time())
94
-
95
- assert args.input_file_list is not None or args.input_audio_file is not None,"Please provide either a list of audio files or a single audio file"
96
-
97
- input_file = args.input_audio_file
98
- random_seed = args.seed
99
- sample_rate=48000
100
- latent_t_per_second=12.8
101
- guidance_scale = args.guidance_scale
102
-
103
- os.makedirs(save_path, exist_ok=True)
104
- audiosr = build_model(model_name=args.model_name, device=args.device)
105
-
106
- if(args.input_file_list):
107
- print("Generate audio based on the text prompts in %s" % args.input_file_list)
108
- files_todo = read_list(args.input_file_list)
109
- else:
110
- files_todo = [input_file]
111
-
112
- for input_file in files_todo:
113
- name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
114
-
115
- waveform = super_resolution(
116
- audiosr,
117
- input_file,
118
- seed=random_seed,
119
- guidance_scale=guidance_scale,
120
- ddim_steps=args.ddim_steps,
121
- latent_t_per_second=latent_t_per_second
122
- )
123
- save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
 
1
+ #!/usr/bin/python3
2
+ import os
3
+ import torch
4
+ import logging
5
+ from audiosr import super_resolution, build_model, save_wave, get_time, read_list
6
+ import argparse
7
+
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
9
+ matplotlib_logger = logging.getLogger('matplotlib')
10
+ matplotlib_logger.setLevel(logging.WARNING)
11
+
12
+ parser = argparse.ArgumentParser()
13
+
14
+ parser.add_argument(
15
+ "-i",
16
+ "--input_audio_file",
17
+ type=str,
18
+ required=False,
19
+ help="Input audio file for audio super resolution",
20
+ )
21
+
22
+ parser.add_argument(
23
+ "-il",
24
+ "--input_file_list",
25
+ type=str,
26
+ required=False,
27
+ default="",
28
+ help="A file that contains all audio files that need to perform audio super resolution",
29
+ )
30
+
31
+ parser.add_argument(
32
+ "-s",
33
+ "--save_path",
34
+ type=str,
35
+ required=False,
36
+ help="The path to save model output",
37
+ default="./output",
38
+ )
39
+
40
+ parser.add_argument(
41
+ "--model_name",
42
+ type=str,
43
+ required=False,
44
+ help="The checkpoint you gonna use",
45
+ default="basic",
46
+ choices=["basic","speech"]
47
+ )
48
+
49
+ parser.add_argument(
50
+ "-d",
51
+ "--device",
52
+ type=str,
53
+ required=False,
54
+ help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
55
+ default="auto",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--ddim_steps",
60
+ type=int,
61
+ required=False,
62
+ default=50,
63
+ help="The sampling step for DDIM",
64
+ )
65
+
66
+ parser.add_argument(
67
+ "-gs",
68
+ "--guidance_scale",
69
+ type=float,
70
+ required=False,
71
+ default=3.5,
72
+ help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--seed",
77
+ type=int,
78
+ required=False,
79
+ default=42,
80
+ help="Change this value (any integer number) will lead to a different generation result.",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--suffix",
85
+ type=str,
86
+ required=False,
87
+ help="Suffix for the output file",
88
+ default="_AudioSR_Processed_48K",
89
+ )
90
+
91
+ args = parser.parse_args()
92
+ torch.set_float32_matmul_precision("high")
93
+ save_path = os.path.join(args.save_path, get_time())
94
+
95
+ assert args.input_file_list is not None or args.input_audio_file is not None,"Please provide either a list of audio files or a single audio file"
96
+
97
+ input_file = args.input_audio_file
98
+ random_seed = args.seed
99
+ sample_rate=48000
100
+ latent_t_per_second=12.8
101
+ guidance_scale = args.guidance_scale
102
+
103
+ os.makedirs(save_path, exist_ok=True)
104
+ audiosr = build_model(model_name=args.model_name, device=args.device)
105
+
106
+ if(args.input_file_list):
107
+ print("Generate audio based on the text prompts in %s" % args.input_file_list)
108
+ files_todo = read_list(args.input_file_list)
109
+ else:
110
+ files_todo = [input_file]
111
+
112
+ for input_file in files_todo:
113
+ name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
114
+
115
+ waveform = super_resolution(
116
+ audiosr,
117
+ input_file,
118
+ seed=random_seed,
119
+ guidance_scale=guidance_scale,
120
+ ddim_steps=args.ddim_steps,
121
+ latent_t_per_second=latent_t_per_second
122
+ )
123
+ save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
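
The script above is essentially a thin wrapper around the package API. A minimal sketch of the equivalent direct calls (the input path is a placeholder; defaults match the CLI above):

import os
from audiosr import build_model, super_resolution, save_wave, get_time

audiosr = build_model(model_name="basic", device="auto")   # or model_name="speech"
waveform = super_resolution(
    audiosr,
    "example_low_res.wav",          # placeholder input file
    seed=42,
    guidance_scale=3.5,
    ddim_steps=50,
    latent_t_per_second=12.8,
)
save_path = os.path.join("./output", get_time())
os.makedirs(save_path, exist_ok=True)
save_wave(waveform, inputpath="example_low_res.wav", savepath=save_path,
          name="example_low_res_AudioSR_Processed_48K", samplerate=48000)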
audiosr/clap/open_clip/__init__.py CHANGED
@@ -1,25 +1,25 @@
1
- from .factory import (
2
- list_models,
3
- create_model,
4
- create_model_and_transforms,
5
- add_model_config,
6
- )
7
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
- from .model import (
9
- CLAP,
10
- CLAPTextCfg,
11
- CLAPVisionCfg,
12
- CLAPAudioCfp,
13
- convert_weights_to_fp16,
14
- trace_model,
15
- )
16
- from .openai import load_openai_model, list_openai_models
17
- from .pretrained import (
18
- list_pretrained,
19
- list_pretrained_tag_models,
20
- list_pretrained_model_tags,
21
- get_pretrained_url,
22
- download_pretrained,
23
- )
24
- from .tokenizer import SimpleTokenizer, tokenize
25
- from .transform import image_transform
 
1
+ from .factory import (
2
+ list_models,
3
+ create_model,
4
+ create_model_and_transforms,
5
+ add_model_config,
6
+ )
7
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
+ from .model import (
9
+ CLAP,
10
+ CLAPTextCfg,
11
+ CLAPVisionCfg,
12
+ CLAPAudioCfp,
13
+ convert_weights_to_fp16,
14
+ trace_model,
15
+ )
16
+ from .openai import load_openai_model, list_openai_models
17
+ from .pretrained import (
18
+ list_pretrained,
19
+ list_pretrained_tag_models,
20
+ list_pretrained_model_tags,
21
+ get_pretrained_url,
22
+ download_pretrained,
23
+ )
24
+ from .tokenizer import SimpleTokenizer, tokenize
25
+ from .transform import image_transform
audiosr/clap/open_clip/factory.py CHANGED
@@ -1,276 +1,276 @@
1
- import json
2
- import logging
3
- import os
4
- import re
5
- from copy import deepcopy
6
- from pathlib import Path
7
-
8
- import torch
9
-
10
- from .model import CLAP, convert_weights_to_fp16
11
- from .openai import load_openai_model
12
- from .pretrained import get_pretrained_url, download_pretrained
13
- from .transform import image_transform
14
-
15
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
16
- _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
17
-
18
-
19
- def _natural_key(string_):
20
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
21
-
22
-
23
- def _rescan_model_configs():
24
- global _MODEL_CONFIGS
25
-
26
- config_ext = (".json",)
27
- config_files = []
28
- for config_path in _MODEL_CONFIG_PATHS:
29
- if config_path.is_file() and config_path.suffix in config_ext:
30
- config_files.append(config_path)
31
- elif config_path.is_dir():
32
- for ext in config_ext:
33
- config_files.extend(config_path.glob(f"*{ext}"))
34
-
35
- for cf in config_files:
36
- if os.path.basename(cf)[0] == ".":
37
- continue # Ignore hidden files
38
-
39
- with open(cf, "r") as f:
40
- model_cfg = json.load(f)
41
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
- _MODEL_CONFIGS[cf.stem] = model_cfg
43
-
44
- _MODEL_CONFIGS = {
45
- k: v
46
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
- }
48
-
49
-
50
- _rescan_model_configs() # initial populate of model config registry
51
-
52
-
53
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
- state_dict = checkpoint["state_dict"]
57
- else:
58
- state_dict = checkpoint
59
- if skip_params:
60
- if next(iter(state_dict.items()))[0].startswith("module"):
61
- state_dict = {k[7:]: v for k, v in state_dict.items()}
62
- # for k in state_dict:
63
- # if k.startswith('transformer'):
64
- # v = state_dict.pop(k)
65
- # state_dict['text_branch.' + k[12:]] = v
66
- return state_dict
67
-
68
-
69
- def create_model(
70
- amodel_name: str,
71
- tmodel_name: str,
72
- pretrained: str = "",
73
- precision: str = "fp32",
74
- device: torch.device = torch.device("cpu"),
75
- jit: bool = False,
76
- force_quick_gelu: bool = False,
77
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
78
- skip_params=True,
79
- pretrained_audio: str = "",
80
- pretrained_text: str = "",
81
- enable_fusion: bool = False,
82
- fusion_type: str = "None"
83
- # pretrained_image: bool = False,
84
- ):
85
- amodel_name = amodel_name.replace(
86
- "/", "-"
87
- ) # for callers using old naming with / in ViT names
88
- pretrained_orig = pretrained
89
- pretrained = pretrained.lower()
90
- if pretrained == "openai":
91
- if amodel_name in _MODEL_CONFIGS:
92
- logging.info(f"Loading {amodel_name} model config.")
93
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
94
- else:
95
- logging.error(
96
- f"Model config for {amodel_name} not found; available models {list_models()}."
97
- )
98
- raise RuntimeError(f"Model config for {amodel_name} not found.")
99
-
100
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
101
- # Hard Code in model name
102
- model_cfg["text_cfg"]["model_type"] = tmodel_name
103
- model = load_openai_model(
104
- "ViT-B-16",
105
- model_cfg,
106
- device=device,
107
- jit=jit,
108
- cache_dir=openai_model_cache_dir,
109
- enable_fusion=enable_fusion,
110
- fusion_type=fusion_type,
111
- )
112
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
113
- if precision == "amp" or precision == "fp32":
114
- model = model.float()
115
- else:
116
- if amodel_name in _MODEL_CONFIGS:
117
- logging.info(f"Loading {amodel_name} model config.")
118
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
119
- else:
120
- logging.error(
121
- f"Model config for {amodel_name} not found; available models {list_models()}."
122
- )
123
- raise RuntimeError(f"Model config for {amodel_name} not found.")
124
-
125
- if force_quick_gelu:
126
- # override for use of QuickGELU on non-OpenAI transformer models
127
- model_cfg["quick_gelu"] = True
128
-
129
- # if pretrained_image:
130
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
131
- # # pretrained weight loading for timm models set via vision_cfg
132
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
133
- # else:
134
- # assert False, 'pretrained image towers currently only supported for timm models'
135
- model_cfg["text_cfg"]["model_type"] = tmodel_name
136
- model_cfg["enable_fusion"] = enable_fusion
137
- model_cfg["fusion_type"] = fusion_type
138
- model = CLAP(**model_cfg)
139
-
140
- if pretrained:
141
- checkpoint_path = ""
142
- url = get_pretrained_url(amodel_name, pretrained)
143
- if url:
144
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
145
- elif os.path.exists(pretrained_orig):
146
- checkpoint_path = pretrained_orig
147
- if checkpoint_path:
148
- logging.info(
149
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
150
- )
151
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
152
- model.load_state_dict(ckpt)
153
- param_names = [n for n, p in model.named_parameters()]
154
- # for n in param_names:
155
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
156
- else:
157
- logging.warning(
158
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
159
- )
160
- raise RuntimeError(
161
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162
- )
163
-
164
- if pretrained_audio:
165
- if amodel_name.startswith("PANN"):
166
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
167
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
168
- audio_ckpt = audio_ckpt["model"]
169
- keys = list(audio_ckpt.keys())
170
- for key in keys:
171
- if (
172
- "spectrogram_extractor" not in key
173
- and "logmel_extractor" not in key
174
- ):
175
- v = audio_ckpt.pop(key)
176
- audio_ckpt["audio_branch." + key] = v
177
- elif os.path.basename(pretrained_audio).startswith(
178
- "PANN"
179
- ): # checkpoint trained via HTSAT codebase
180
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
181
- audio_ckpt = audio_ckpt["state_dict"]
182
- keys = list(audio_ckpt.keys())
183
- for key in keys:
184
- if key.startswith("sed_model"):
185
- v = audio_ckpt.pop(key)
186
- audio_ckpt["audio_branch." + key[10:]] = v
187
- elif os.path.basename(pretrained_audio).startswith(
188
- "finetuned"
189
- ): # checkpoint trained via linear probe codebase
190
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
191
- else:
192
- raise ValueError("Unknown audio checkpoint")
193
- elif amodel_name.startswith("HTSAT"):
194
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
195
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
196
- audio_ckpt = audio_ckpt["state_dict"]
197
- keys = list(audio_ckpt.keys())
198
- for key in keys:
199
- if key.startswith("sed_model") and (
200
- "spectrogram_extractor" not in key
201
- and "logmel_extractor" not in key
202
- ):
203
- v = audio_ckpt.pop(key)
204
- audio_ckpt["audio_branch." + key[10:]] = v
205
- elif os.path.basename(pretrained_audio).startswith(
206
- "HTSAT"
207
- ): # checkpoint trained via HTSAT codebase
208
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
209
- audio_ckpt = audio_ckpt["state_dict"]
210
- keys = list(audio_ckpt.keys())
211
- for key in keys:
212
- if key.startswith("sed_model"):
213
- v = audio_ckpt.pop(key)
214
- audio_ckpt["audio_branch." + key[10:]] = v
215
- elif os.path.basename(pretrained_audio).startswith(
216
- "finetuned"
217
- ): # checkpoint trained via linear probe codebase
218
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
219
- else:
220
- raise ValueError("Unknown audio checkpoint")
221
- else:
222
- raise f"this audio encoder pretrained checkpoint is not support"
223
-
224
- model.load_state_dict(audio_ckpt, strict=False)
225
- logging.info(
226
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
227
- )
228
- param_names = [n for n, p in model.named_parameters()]
229
- for n in param_names:
230
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
231
-
232
- model.to(device=device)
233
- if precision == "fp16":
234
- assert device.type != "cpu"
235
- convert_weights_to_fp16(model)
236
-
237
- if jit:
238
- model = torch.jit.script(model)
239
-
240
- return model, model_cfg
241
-
242
-
243
- def create_model_and_transforms(
244
- model_name: str,
245
- pretrained: str = "",
246
- precision: str = "fp32",
247
- device: torch.device = torch.device("cpu"),
248
- jit: bool = False,
249
- force_quick_gelu: bool = False,
250
- # pretrained_image: bool = False,
251
- ):
252
- model = create_model(
253
- model_name,
254
- pretrained,
255
- precision,
256
- device,
257
- jit,
258
- force_quick_gelu=force_quick_gelu,
259
- # pretrained_image=pretrained_image
260
- )
261
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
262
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
263
- return model, preprocess_train, preprocess_val
264
-
265
-
266
- def list_models():
267
- """enumerate available model architectures based on config files"""
268
- return list(_MODEL_CONFIGS.keys())
269
-
270
-
271
- def add_model_config(path):
272
- """add model config path or file and update registry"""
273
- if not isinstance(path, Path):
274
- path = Path(path)
275
- _MODEL_CONFIG_PATHS.append(path)
276
- _rescan_model_configs()
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from .model import CLAP, convert_weights_to_fp16
11
+ from .openai import load_openai_model
12
+ from .pretrained import get_pretrained_url, download_pretrained
13
+ from .transform import image_transform
14
+
15
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
16
+ _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
17
+
18
+
19
+ def _natural_key(string_):
20
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
21
+
22
+
23
+ def _rescan_model_configs():
24
+ global _MODEL_CONFIGS
25
+
26
+ config_ext = (".json",)
27
+ config_files = []
28
+ for config_path in _MODEL_CONFIG_PATHS:
29
+ if config_path.is_file() and config_path.suffix in config_ext:
30
+ config_files.append(config_path)
31
+ elif config_path.is_dir():
32
+ for ext in config_ext:
33
+ config_files.extend(config_path.glob(f"*{ext}"))
34
+
35
+ for cf in config_files:
36
+ if os.path.basename(cf)[0] == ".":
37
+ continue # Ignore hidden files
38
+
39
+ with open(cf, "r") as f:
40
+ model_cfg = json.load(f)
41
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
+ _MODEL_CONFIGS[cf.stem] = model_cfg
43
+
44
+ _MODEL_CONFIGS = {
45
+ k: v
46
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
+ }
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
+ state_dict = checkpoint["state_dict"]
57
+ else:
58
+ state_dict = checkpoint
59
+ if skip_params:
60
+ if next(iter(state_dict.items()))[0].startswith("module"):
61
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
62
+ # for k in state_dict:
63
+ # if k.startswith('transformer'):
64
+ # v = state_dict.pop(k)
65
+ # state_dict['text_branch.' + k[12:]] = v
66
+ return state_dict
67
+
68
+
69
+ def create_model(
70
+ amodel_name: str,
71
+ tmodel_name: str,
72
+ pretrained: str = "",
73
+ precision: str = "fp32",
74
+ device: torch.device = torch.device("cpu"),
75
+ jit: bool = False,
76
+ force_quick_gelu: bool = False,
77
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
78
+ skip_params=True,
79
+ pretrained_audio: str = "",
80
+ pretrained_text: str = "",
81
+ enable_fusion: bool = False,
82
+ fusion_type: str = "None"
83
+ # pretrained_image: bool = False,
84
+ ):
85
+ amodel_name = amodel_name.replace(
86
+ "/", "-"
87
+ ) # for callers using old naming with / in ViT names
88
+ pretrained_orig = pretrained
89
+ pretrained = pretrained.lower()
90
+ if pretrained == "openai":
91
+ if amodel_name in _MODEL_CONFIGS:
92
+ logging.info(f"Loading {amodel_name} model config.")
93
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
94
+ else:
95
+ logging.error(
96
+ f"Model config for {amodel_name} not found; available models {list_models()}."
97
+ )
98
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
99
+
100
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
101
+ # Hard Code in model name
102
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
103
+ model = load_openai_model(
104
+ "ViT-B-16",
105
+ model_cfg,
106
+ device=device,
107
+ jit=jit,
108
+ cache_dir=openai_model_cache_dir,
109
+ enable_fusion=enable_fusion,
110
+ fusion_type=fusion_type,
111
+ )
112
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
113
+ if precision == "amp" or precision == "fp32":
114
+ model = model.float()
115
+ else:
116
+ if amodel_name in _MODEL_CONFIGS:
117
+ logging.info(f"Loading {amodel_name} model config.")
118
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
119
+ else:
120
+ logging.error(
121
+ f"Model config for {amodel_name} not found; available models {list_models()}."
122
+ )
123
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
124
+
125
+ if force_quick_gelu:
126
+ # override for use of QuickGELU on non-OpenAI transformer models
127
+ model_cfg["quick_gelu"] = True
128
+
129
+ # if pretrained_image:
130
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
131
+ # # pretrained weight loading for timm models set via vision_cfg
132
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
133
+ # else:
134
+ # assert False, 'pretrained image towers currently only supported for timm models'
135
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
136
+ model_cfg["enable_fusion"] = enable_fusion
137
+ model_cfg["fusion_type"] = fusion_type
138
+ model = CLAP(**model_cfg)
139
+
140
+ if pretrained:
141
+ checkpoint_path = ""
142
+ url = get_pretrained_url(amodel_name, pretrained)
143
+ if url:
144
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
145
+ elif os.path.exists(pretrained_orig):
146
+ checkpoint_path = pretrained_orig
147
+ if checkpoint_path:
148
+ logging.info(
149
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
150
+ )
151
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
152
+ model.load_state_dict(ckpt)
153
+ param_names = [n for n, p in model.named_parameters()]
154
+ # for n in param_names:
155
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
156
+ else:
157
+ logging.warning(
158
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
159
+ )
160
+ raise RuntimeError(
161
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162
+ )
163
+
164
+ if pretrained_audio:
165
+ if amodel_name.startswith("PANN"):
166
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
167
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
168
+ audio_ckpt = audio_ckpt["model"]
169
+ keys = list(audio_ckpt.keys())
170
+ for key in keys:
171
+ if (
172
+ "spectrogram_extractor" not in key
173
+ and "logmel_extractor" not in key
174
+ ):
175
+ v = audio_ckpt.pop(key)
176
+ audio_ckpt["audio_branch." + key] = v
177
+ elif os.path.basename(pretrained_audio).startswith(
178
+ "PANN"
179
+ ): # checkpoint trained via HTSAT codebase
180
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
181
+ audio_ckpt = audio_ckpt["state_dict"]
182
+ keys = list(audio_ckpt.keys())
183
+ for key in keys:
184
+ if key.startswith("sed_model"):
185
+ v = audio_ckpt.pop(key)
186
+ audio_ckpt["audio_branch." + key[10:]] = v
187
+ elif os.path.basename(pretrained_audio).startswith(
188
+ "finetuned"
189
+ ): # checkpoint trained via linear probe codebase
190
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
191
+ else:
192
+ raise ValueError("Unknown audio checkpoint")
193
+ elif amodel_name.startswith("HTSAT"):
194
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
195
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
196
+ audio_ckpt = audio_ckpt["state_dict"]
197
+ keys = list(audio_ckpt.keys())
198
+ for key in keys:
199
+ if key.startswith("sed_model") and (
200
+ "spectrogram_extractor" not in key
201
+ and "logmel_extractor" not in key
202
+ ):
203
+ v = audio_ckpt.pop(key)
204
+ audio_ckpt["audio_branch." + key[10:]] = v
205
+ elif os.path.basename(pretrained_audio).startswith(
206
+ "HTSAT"
207
+ ): # checkpoint trained via HTSAT codebase
208
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
209
+ audio_ckpt = audio_ckpt["state_dict"]
210
+ keys = list(audio_ckpt.keys())
211
+ for key in keys:
212
+ if key.startswith("sed_model"):
213
+ v = audio_ckpt.pop(key)
214
+ audio_ckpt["audio_branch." + key[10:]] = v
215
+ elif os.path.basename(pretrained_audio).startswith(
216
+ "finetuned"
217
+ ): # checkpoint trained via linear probe codebase
218
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
219
+ else:
220
+ raise ValueError("Unknown audio checkpoint")
221
+ else:
222
+ raise f"this audio encoder pretrained checkpoint is not support"
223
+
224
+ model.load_state_dict(audio_ckpt, strict=False)
225
+ logging.info(
226
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
227
+ )
228
+ param_names = [n for n, p in model.named_parameters()]
229
+ for n in param_names:
230
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
231
+
232
+ model.to(device=device)
233
+ if precision == "fp16":
234
+ assert device.type != "cpu"
235
+ convert_weights_to_fp16(model)
236
+
237
+ if jit:
238
+ model = torch.jit.script(model)
239
+
240
+ return model, model_cfg
241
+
242
+
243
+ def create_model_and_transforms(
244
+ model_name: str,
245
+ pretrained: str = "",
246
+ precision: str = "fp32",
247
+ device: torch.device = torch.device("cpu"),
248
+ jit: bool = False,
249
+ force_quick_gelu: bool = False,
250
+ # pretrained_image: bool = False,
251
+ ):
252
+ model = create_model(
253
+ model_name,
254
+ pretrained,
255
+ precision,
256
+ device,
257
+ jit,
258
+ force_quick_gelu=force_quick_gelu,
259
+ # pretrained_image=pretrained_image
260
+ )
261
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
262
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
263
+ return model, preprocess_train, preprocess_val
264
+
265
+
266
+ def list_models():
267
+ """enumerate available model architectures based on config files"""
268
+ return list(_MODEL_CONFIGS.keys())
269
+
270
+
271
+ def add_model_config(path):
272
+ """add model config path or file and update registry"""
273
+ if not isinstance(path, Path):
274
+ path = Path(path)
275
+ _MODEL_CONFIG_PATHS.append(path)
276
+ _rescan_model_configs()
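
For reference, a hedged sketch of driving this factory directly. "HTSAT-tiny" is one of the bundled model_configs above; the text-tower name and the choice to build an untrained model (no pretrained checkpoint) are illustrative assumptions, not something this commit prescribes:

import torch
from audiosr.clap.open_clip.factory import list_models, create_model, add_model_config

print(list_models())                      # architectures discovered in model_configs/*.json

model, model_cfg = create_model(
    amodel_name="HTSAT-tiny",             # audio tower config bundled above
    tmodel_name="roberta",                # assumed text tower name
    pretrained="",                        # "", a local checkpoint path, or "openai"
    precision="fp32",
    device=torch.device("cpu"),
    enable_fusion=False,
)
print(model_cfg["embed_dim"])             # every registered config defines embed_dim

# Additional config directories can be registered before building:
# add_model_config("/path/to/extra/model_configs")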
audiosr/clap/open_clip/feature_fusion.py CHANGED
@@ -1,192 +1,192 @@
1
- """
2
- Feature Fusion for Variable-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
- class DAF(nn.Module):
12
- """
13
- Direct addition (DirectAddFuse)
14
- """
15
-
16
- def __init__(self):
17
- super(DAF, self).__init__()
18
-
19
- def forward(self, x, residual):
20
- return x + residual
21
-
22
-
23
- class iAFF(nn.Module):
24
- """
25
- Multi-feature fusion (iAFF)
26
- """
27
-
28
- def __init__(self, channels=64, r=4, type="2D"):
29
- super(iAFF, self).__init__()
30
- inter_channels = int(channels // r)
31
-
32
- if type == "1D":
33
- # Local attention
34
- self.local_att = nn.Sequential(
35
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
- nn.BatchNorm1d(inter_channels),
37
- nn.ReLU(inplace=True),
38
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
- nn.BatchNorm1d(channels),
40
- )
41
-
42
- # Global attention
43
- self.global_att = nn.Sequential(
44
- nn.AdaptiveAvgPool1d(1),
45
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
- nn.BatchNorm1d(inter_channels),
47
- nn.ReLU(inplace=True),
48
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
- nn.BatchNorm1d(channels),
50
- )
51
-
52
- # Second local attention
53
- self.local_att2 = nn.Sequential(
54
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
- nn.BatchNorm1d(inter_channels),
56
- nn.ReLU(inplace=True),
57
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
- nn.BatchNorm1d(channels),
59
- )
60
- # Second global attention
61
- self.global_att2 = nn.Sequential(
62
- nn.AdaptiveAvgPool1d(1),
63
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
- nn.BatchNorm1d(inter_channels),
65
- nn.ReLU(inplace=True),
66
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
- nn.BatchNorm1d(channels),
68
- )
69
- elif type == "2D":
70
- # Local attention
71
- self.local_att = nn.Sequential(
72
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
- nn.BatchNorm2d(inter_channels),
74
- nn.ReLU(inplace=True),
75
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
- nn.BatchNorm2d(channels),
77
- )
78
-
79
- # Global attention
80
- self.global_att = nn.Sequential(
81
- nn.AdaptiveAvgPool2d(1),
82
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
- nn.BatchNorm2d(inter_channels),
84
- nn.ReLU(inplace=True),
85
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
- nn.BatchNorm2d(channels),
87
- )
88
-
89
- # Second local attention
90
- self.local_att2 = nn.Sequential(
91
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
- nn.BatchNorm2d(inter_channels),
93
- nn.ReLU(inplace=True),
94
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
- nn.BatchNorm2d(channels),
96
- )
97
- # Second global attention
98
- self.global_att2 = nn.Sequential(
99
- nn.AdaptiveAvgPool2d(1),
100
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
- nn.BatchNorm2d(inter_channels),
102
- nn.ReLU(inplace=True),
103
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
- nn.BatchNorm2d(channels),
105
- )
106
- else:
107
- raise f"the type is not supported"
108
-
109
- self.sigmoid = nn.Sigmoid()
110
-
111
- def forward(self, x, residual):
112
- flag = False
113
- xa = x + residual
114
- if xa.size(0) == 1:
115
- xa = torch.cat([xa, xa], dim=0)
116
- flag = True
117
- xl = self.local_att(xa)
118
- xg = self.global_att(xa)
119
- xlg = xl + xg
120
- wei = self.sigmoid(xlg)
121
- xi = x * wei + residual * (1 - wei)
122
-
123
- xl2 = self.local_att2(xi)
124
- xg2 = self.global_att(xi)
125
- xlg2 = xl2 + xg2
126
- wei2 = self.sigmoid(xlg2)
127
- xo = x * wei2 + residual * (1 - wei2)
128
- if flag:
129
- xo = xo[0].unsqueeze(0)
130
- return xo
131
-
132
-
133
- class AFF(nn.Module):
134
- """
135
- Multi-feature fusion (AFF)
136
- """
137
-
138
- def __init__(self, channels=64, r=4, type="2D"):
139
- super(AFF, self).__init__()
140
- inter_channels = int(channels // r)
141
-
142
- if type == "1D":
143
- self.local_att = nn.Sequential(
144
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
- nn.BatchNorm1d(inter_channels),
146
- nn.ReLU(inplace=True),
147
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
- nn.BatchNorm1d(channels),
149
- )
150
- self.global_att = nn.Sequential(
151
- nn.AdaptiveAvgPool1d(1),
152
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
- nn.BatchNorm1d(inter_channels),
154
- nn.ReLU(inplace=True),
155
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
- nn.BatchNorm1d(channels),
157
- )
158
- elif type == "2D":
159
- self.local_att = nn.Sequential(
160
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
- nn.BatchNorm2d(inter_channels),
162
- nn.ReLU(inplace=True),
163
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
- nn.BatchNorm2d(channels),
165
- )
166
- self.global_att = nn.Sequential(
167
- nn.AdaptiveAvgPool2d(1),
168
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
- nn.BatchNorm2d(inter_channels),
170
- nn.ReLU(inplace=True),
171
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
- nn.BatchNorm2d(channels),
173
- )
174
- else:
175
- raise f"the type is not supported."
176
-
177
- self.sigmoid = nn.Sigmoid()
178
-
179
- def forward(self, x, residual):
180
- flag = False
181
- xa = x + residual
182
- if xa.size(0) == 1:
183
- xa = torch.cat([xa, xa], dim=0)
184
- flag = True
185
- xl = self.local_att(xa)
186
- xg = self.global_att(xa)
187
- xlg = xl + xg
188
- wei = self.sigmoid(xlg)
189
- xo = 2 * x * wei + 2 * residual * (1 - wei)
190
- if flag:
191
- xo = xo[0].unsqueeze(0)
192
- return xo
 
1
+ """
2
+ Feature Fusion for Variable-Length Data Processing
3
+ AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ """
13
+ Direct addition (DirectAddFuse)
14
+ """
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ """
25
+ Multi-feature fusion (iAFF)
26
+ """
27
+
28
+ def __init__(self, channels=64, r=4, type="2D"):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == "1D":
33
+ # Local attention
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # Global attention
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # Second local attention
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # Second global attention
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == "2D":
70
+ # Local attention
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # Global attention
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # Second local attention
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # Second global attention
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise f"the type is not supported"
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa, xa], dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)  # note: reuses global_att here rather than global_att2
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ """
135
+ Multi-feature fusion (AFF)
136
+ """
137
+
138
+ def __init__(self, channels=64, r=4, type="2D"):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == "1D":
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == "2D":
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise f"the type is not supported."
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa, xa], dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
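
A small sketch of how these fusion blocks are used: AFF/iAFF take two feature maps of the same shape and return an attention-weighted blend (the shapes below are chosen purely for illustration):

import torch
from audiosr.clap.open_clip.feature_fusion import AFF, iAFF

# The "2D" variant operates on (batch, channels, H, W) feature maps
x = torch.randn(2, 64, 32, 32)
residual = torch.randn(2, 64, 32, 32)
fuse = AFF(channels=64, r=4, type="2D")
out = fuse(x, residual)                  # same shape as the inputs: (2, 64, 32, 32)

# The "1D" variant operates on (batch, channels, length) sequences
fuse_1d = iAFF(channels=64, r=4, type="1D")
out_1d = fuse_1d(torch.randn(2, 64, 128), torch.randn(2, 64, 128))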
audiosr/clap/open_clip/htsat.py CHANGED
The diff for this file is too large to render. See raw diff
 
audiosr/clap/open_clip/loss.py CHANGED
@@ -1,397 +1,397 @@
1
- import torch
2
- import torch.distributed.nn
3
- from torch import distributed as dist, nn as nn
4
- from torch.nn import functional as F
5
- import numpy as np
6
- from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
7
-
8
- try:
9
- import horovod.torch as hvd
10
- except ImportError:
11
- hvd = None
12
-
13
-
14
- def gather_features(
15
- audio_features,
16
- text_features,
17
- audio_features_mlp=None,
18
- text_features_mlp=None,
19
- local_loss=False,
20
- gather_with_grad=False,
21
- rank=0,
22
- world_size=1,
23
- use_horovod=False,
24
- mlp_loss=False,
25
- ):
26
- if use_horovod:
27
- assert hvd is not None, "Please install horovod"
28
- if gather_with_grad:
29
- all_audio_features = hvd.allgather(audio_features)
30
- all_text_features = hvd.allgather(text_features)
31
- if mlp_loss:
32
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
33
- all_text_features_mlp = hvd.allgather(text_features_mlp)
34
- else:
35
- with torch.no_grad():
36
- all_audio_features = hvd.allgather(audio_features)
37
- all_text_features = hvd.allgather(text_features)
38
- if mlp_loss:
39
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
40
- all_text_features_mlp = hvd.allgather(text_features_mlp)
41
- if not local_loss:
42
- # ensure grads for local rank when all_* features don't have a gradient
43
- gathered_audio_features = list(
44
- all_audio_features.chunk(world_size, dim=0)
45
- )
46
- gathered_text_features = list(
47
- all_text_features.chunk(world_size, dim=0)
48
- )
49
- gathered_audio_features[rank] = audio_features
50
- gathered_text_features[rank] = text_features
51
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
52
- all_text_features = torch.cat(gathered_text_features, dim=0)
53
- if mlp_loss:
54
- gathered_audio_features_mlp = list(
55
- all_audio_features_mlp.chunk(world_size, dim=0)
56
- )
57
- gathered_text_features_mlp = list(
58
- all_text_features_mlp.chunk(world_size, dim=0)
59
- )
60
- gathered_audio_features_mlp[rank] = audio_features_mlp
61
- gathered_text_features_mlp[rank] = text_features_mlp
62
- all_audio_features_mlp = torch.cat(
63
- gathered_audio_features_mlp, dim=0
64
- )
65
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
66
- else:
67
- # We gather tensors from all gpus
68
- if gather_with_grad:
69
- all_audio_features = torch.cat(
70
- torch.distributed.nn.all_gather(audio_features), dim=0
71
- )
72
- all_text_features = torch.cat(
73
- torch.distributed.nn.all_gather(text_features), dim=0
74
- )
75
- if mlp_loss:
76
- all_audio_features_mlp = torch.cat(
77
- torch.distributed.nn.all_gather(audio_features_mlp), dim=0
78
- )
79
- all_text_features_mlp = torch.cat(
80
- torch.distributed.nn.all_gather(text_features_mlp), dim=0
81
- )
82
- else:
83
- gathered_audio_features = [
84
- torch.zeros_like(audio_features) for _ in range(world_size)
85
- ]
86
- gathered_text_features = [
87
- torch.zeros_like(text_features) for _ in range(world_size)
88
- ]
89
- dist.all_gather(gathered_audio_features, audio_features)
90
- dist.all_gather(gathered_text_features, text_features)
91
- if mlp_loss:
92
- gathered_audio_features_mlp = [
93
- torch.zeros_like(audio_features_mlp) for _ in range(world_size)
94
- ]
95
- gathered_text_features_mlp = [
96
- torch.zeros_like(text_features_mlp) for _ in range(world_size)
97
- ]
98
- dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
99
- dist.all_gather(gathered_text_features_mlp, text_features_mlp)
100
- if not local_loss:
101
- # ensure grads for local rank when all_* features don't have a gradient
102
- gathered_audio_features[rank] = audio_features
103
- gathered_text_features[rank] = text_features
104
- if mlp_loss:
105
- gathered_audio_features_mlp[rank] = audio_features_mlp
106
- gathered_text_features_mlp[rank] = text_features_mlp
107
-
108
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
109
- all_text_features = torch.cat(gathered_text_features, dim=0)
110
- if mlp_loss:
111
- all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
112
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
113
- if mlp_loss:
114
- return (
115
- all_audio_features,
116
- all_text_features,
117
- all_audio_features_mlp,
118
- all_text_features_mlp,
119
- )
120
- else:
121
- return all_audio_features, all_text_features
122
-
123
-
124
- class ClipLoss(nn.Module):
125
- def __init__(
126
- self,
127
- local_loss=False,
128
- gather_with_grad=False,
129
- cache_labels=False,
130
- rank=0,
131
- world_size=1,
132
- use_horovod=False,
133
- mlp_loss=False,
134
- weight_loss_kappa=0,
135
- ):
136
- super().__init__()
137
- self.local_loss = local_loss
138
- self.gather_with_grad = gather_with_grad
139
- self.cache_labels = cache_labels
140
- self.rank = rank
141
- self.world_size = world_size
142
- self.use_horovod = use_horovod
143
- self.mlp_loss = mlp_loss
144
- self.weighted_loss = bool(weight_loss_kappa != 0)
145
- self.weight_loss_kappa = weight_loss_kappa
146
- # cache state
147
- self.prev_num_logits = 0
148
- self.labels = {}
149
-
150
- def forward(
151
- self,
152
- audio_features,
153
- text_features,
154
- logit_scale_a,
155
- logit_scale_t=None,
156
- audio_features_mlp=None,
157
- text_features_mlp=None,
158
- ):
159
- device = audio_features.device
160
- if self.mlp_loss:
161
- if self.world_size > 1:
162
- (
163
- all_audio_features,
164
- all_text_features,
165
- all_audio_features_mlp,
166
- all_text_features_mlp,
167
- ) = gather_features(
168
- audio_features=audio_features,
169
- text_features=text_features,
170
- audio_features_mlp=audio_features_mlp,
171
- text_features_mlp=text_features_mlp,
172
- local_loss=self.local_loss,
173
- gather_with_grad=self.gather_with_grad,
174
- rank=self.rank,
175
- world_size=self.world_size,
176
- use_horovod=self.use_horovod,
177
- mlp_loss=self.mlp_loss,
178
- )
179
- if self.local_loss:
180
- a_logits_per_audio = (
181
- logit_scale_a * audio_features @ all_text_features_mlp.T
182
- )
183
- a_logits_per_text = (
184
- logit_scale_a * text_features_mlp @ all_audio_features.T
185
- )
186
- t_logits_per_audio = (
187
- logit_scale_t * audio_features_mlp @ all_text_features.T
188
- )
189
- t_logits_per_text = (
190
- logit_scale_t * text_features @ all_audio_features_mlp.T
191
- )
192
- else:
193
- a_logits_per_audio = (
194
- logit_scale_a * all_audio_features @ all_text_features_mlp.T
195
- )
196
- a_logits_per_text = a_logits_per_audio.T
197
- t_logits_per_audio = (
198
- logit_scale_t * all_audio_features_mlp @ all_text_features.T
199
- )
200
- t_logits_per_text = t_logits_per_audio.T
201
- else:
202
- a_logits_per_audio = (
203
- logit_scale_a * audio_features @ text_features_mlp.T
204
- )
205
- a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
206
- t_logits_per_audio = (
207
- logit_scale_t * audio_features_mlp @ text_features.T
208
- )
209
- t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
210
-
211
- # calculated ground-truth and cache if enabled
212
- num_logits = a_logits_per_audio.shape[0]
213
- if self.prev_num_logits != num_logits or device not in self.labels:
214
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
215
- if self.world_size > 1 and self.local_loss:
216
- labels = labels + num_logits * self.rank
217
- if self.cache_labels:
218
- self.labels[device] = labels
219
- self.prev_num_logits = num_logits
220
- else:
221
- labels = self.labels[device]
222
-
223
- if not self.weighted_loss:
224
- total_loss = (
225
- F.cross_entropy(a_logits_per_audio, labels)
226
- + F.cross_entropy(a_logits_per_text, labels)
227
- + F.cross_entropy(t_logits_per_audio, labels)
228
- + F.cross_entropy(t_logits_per_text, labels)
229
- ) / 4
230
- else:
231
- audio_weight = (audio_features @ audio_features.T).detach()
232
- audio_weight = (
233
- torch.exp(
234
- torch.sum(audio_weight, axis=1)
235
- / (self.weight_loss_kappa * len(audio_weight))
236
- )
237
- ).detach()
238
- text_weight = (text_features @ text_features.T).detach()
239
- text_weight = (
240
- torch.exp(
241
- torch.sum(text_weight, axis=1)
242
- / (self.weight_loss_kappa * len(text_features))
243
- )
244
- ).detach()
245
- total_loss = (
246
- F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
247
- + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
248
- + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
249
- + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
250
- ) / 4
251
- else:
252
- if self.world_size > 1:
253
- all_audio_features, all_text_features = gather_features(
254
- audio_features=audio_features,
255
- text_features=text_features,
256
- local_loss=self.local_loss,
257
- gather_with_grad=self.gather_with_grad,
258
- rank=self.rank,
259
- world_size=self.world_size,
260
- use_horovod=self.use_horovod,
261
- mlp_loss=self.mlp_loss,
262
- )
263
-
264
- if self.local_loss:
265
- logits_per_audio = (
266
- logit_scale_a * audio_features @ all_text_features.T
267
- )
268
- logits_per_text = (
269
- logit_scale_a * text_features @ all_audio_features.T
270
- )
271
- else:
272
- logits_per_audio = (
273
- logit_scale_a * all_audio_features @ all_text_features.T
274
- )
275
- logits_per_text = logits_per_audio.T
276
- else:
277
- logits_per_audio = logit_scale_a * audio_features @ text_features.T
278
- logits_per_text = logit_scale_a * text_features @ audio_features.T
279
-
280
- # calculated ground-truth and cache if enabled
281
- num_logits = logits_per_audio.shape[0]
282
- if self.prev_num_logits != num_logits or device not in self.labels:
283
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
284
- if self.world_size > 1 and self.local_loss:
285
- labels = labels + num_logits * self.rank
286
- if self.cache_labels:
287
- self.labels[device] = labels
288
- self.prev_num_logits = num_logits
289
- else:
290
- labels = self.labels[device]
291
- if not self.weighted_loss:
292
- total_loss = (
293
- F.cross_entropy(logits_per_audio, labels)
294
- + F.cross_entropy(logits_per_text, labels)
295
- ) / 2
296
- else:
297
- audio_weight = (all_audio_features @ all_audio_features.T).detach()
298
- audio_weight = (
299
- torch.exp(
300
- torch.sum(audio_weight, axis=1)
301
- / (self.weight_loss_kappa * len(all_audio_features))
302
- )
303
- ).detach()
304
- text_weight = (all_text_features @ all_text_features.T).detach()
305
- text_weight = (
306
- torch.exp(
307
- torch.sum(text_weight, axis=1)
308
- / (self.weight_loss_kappa * len(all_text_features))
309
- )
310
- ).detach()
311
- total_loss = (
312
- F.cross_entropy(logits_per_audio, labels, weight=text_weight)
313
- + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
314
- ) / 2
315
- return total_loss
316
-
317
-
318
- def lp_gather_features(pred, target, world_size=1, use_horovod=False):
319
- if use_horovod:
320
- assert hvd is not None, "Please install horovod"
321
- with torch.no_grad():
322
- all_preds = hvd.allgather(pred)
323
- all_targets = hvd.allgather(target)
324
- else:
325
- gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
326
- gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
327
-
328
- dist.all_gather(gathered_preds, pred)
329
- dist.all_gather(gathered_targets, target)
330
- all_preds = torch.cat(gathered_preds, dim=0)
331
- all_targets = torch.cat(gathered_targets, dim=0)
332
-
333
- return all_preds, all_targets
334
-
335
-
336
- def get_map(pred, target):
337
- pred = torch.sigmoid(pred).numpy()
338
- target = target.numpy()
339
- return np.mean(average_precision_score(target, pred, average=None))
340
-
341
-
342
- def get_acc(pred, target):
343
- pred = torch.argmax(pred, 1).numpy()
344
- target = torch.argmax(target, 1).numpy()
345
- return accuracy_score(target, pred)
346
-
347
-
348
- def get_mauc(pred, target):
349
- pred = torch.sigmoid(pred).numpy()
350
- target = target.numpy()
351
- return np.mean(roc_auc_score(target, pred, average=None))
352
-
353
-
354
- class LPMetrics(object):
355
- def __init__(self, metric_names=["map", "acc", "mauc"]):
356
- self.metrics = []
357
- for name in metric_names:
358
- self.metrics.append(self.get_metric(name))
359
- self.metric_names = metric_names
360
-
361
- def get_metric(self, name):
362
- if name == "map":
363
- return get_map
364
- elif name == "acc":
365
- return get_acc
366
- elif name == "mauc":
367
- return get_mauc
368
- else:
369
- raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
370
-
371
- def evaluate_mertics(self, pred, target):
372
- metric_dict = {}
373
- for i in range(len(self.metric_names)):
374
- metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
375
- return metric_dict
376
-
377
-
378
- def calc_celoss(pred, target):
379
- target = torch.argmax(target, 1).long()
380
- return nn.CrossEntropyLoss()(pred, target)
381
-
382
-
383
- class LPLoss(nn.Module):
384
- def __init__(self, loss_name):
385
- super().__init__()
386
- if loss_name == "bce":
387
- self.loss_func = nn.BCEWithLogitsLoss()
388
- elif loss_name == "ce":
389
- self.loss_func = calc_celoss
390
- elif loss_name == "mse":
391
- self.loss_func = nn.MSELoss()
392
- else:
393
- raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
394
-
395
- def forward(self, pred, target):
396
- loss = self.loss_func(pred, target)
397
- return loss
 
1
+ import torch
2
+ import torch.distributed.nn
3
+ from torch import distributed as dist, nn as nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
7
+
8
+ try:
9
+ import horovod.torch as hvd
10
+ except ImportError:
11
+ hvd = None
12
+
13
+
14
+ def gather_features(
15
+ audio_features,
16
+ text_features,
17
+ audio_features_mlp=None,
18
+ text_features_mlp=None,
19
+ local_loss=False,
20
+ gather_with_grad=False,
21
+ rank=0,
22
+ world_size=1,
23
+ use_horovod=False,
24
+ mlp_loss=False,
25
+ ):
26
+ if use_horovod:
27
+ assert hvd is not None, "Please install horovod"
28
+ if gather_with_grad:
29
+ all_audio_features = hvd.allgather(audio_features)
30
+ all_text_features = hvd.allgather(text_features)
31
+ if mlp_loss:
32
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
33
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
34
+ else:
35
+ with torch.no_grad():
36
+ all_audio_features = hvd.allgather(audio_features)
37
+ all_text_features = hvd.allgather(text_features)
38
+ if mlp_loss:
39
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
40
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
41
+ if not local_loss:
42
+ # ensure grads for local rank when all_* features don't have a gradient
43
+ gathered_audio_features = list(
44
+ all_audio_features.chunk(world_size, dim=0)
45
+ )
46
+ gathered_text_features = list(
47
+ all_text_features.chunk(world_size, dim=0)
48
+ )
49
+ gathered_audio_features[rank] = audio_features
50
+ gathered_text_features[rank] = text_features
51
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
52
+ all_text_features = torch.cat(gathered_text_features, dim=0)
53
+ if mlp_loss:
54
+ gathered_audio_features_mlp = list(
55
+ all_audio_features_mlp.chunk(world_size, dim=0)
56
+ )
57
+ gathered_text_features_mlp = list(
58
+ all_text_features_mlp.chunk(world_size, dim=0)
59
+ )
60
+ gathered_audio_features_mlp[rank] = audio_features_mlp
61
+ gathered_text_features_mlp[rank] = text_features_mlp
62
+ all_audio_features_mlp = torch.cat(
63
+ gathered_audio_features_mlp, dim=0
64
+ )
65
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
66
+ else:
67
+ # We gather tensors from all gpus
68
+ if gather_with_grad:
69
+ all_audio_features = torch.cat(
70
+ torch.distributed.nn.all_gather(audio_features), dim=0
71
+ )
72
+ all_text_features = torch.cat(
73
+ torch.distributed.nn.all_gather(text_features), dim=0
74
+ )
75
+ if mlp_loss:
76
+ all_audio_features_mlp = torch.cat(
77
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
78
+ )
79
+ all_text_features_mlp = torch.cat(
80
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
81
+ )
82
+ else:
83
+ gathered_audio_features = [
84
+ torch.zeros_like(audio_features) for _ in range(world_size)
85
+ ]
86
+ gathered_text_features = [
87
+ torch.zeros_like(text_features) for _ in range(world_size)
88
+ ]
89
+ dist.all_gather(gathered_audio_features, audio_features)
90
+ dist.all_gather(gathered_text_features, text_features)
91
+ if mlp_loss:
92
+ gathered_audio_features_mlp = [
93
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
94
+ ]
95
+ gathered_text_features_mlp = [
96
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
97
+ ]
98
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
99
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
100
+ if not local_loss:
101
+ # ensure grads for local rank when all_* features don't have a gradient
102
+ gathered_audio_features[rank] = audio_features
103
+ gathered_text_features[rank] = text_features
104
+ if mlp_loss:
105
+ gathered_audio_features_mlp[rank] = audio_features_mlp
106
+ gathered_text_features_mlp[rank] = text_features_mlp
107
+
108
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
109
+ all_text_features = torch.cat(gathered_text_features, dim=0)
110
+ if mlp_loss:
111
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
112
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
113
+ if mlp_loss:
114
+ return (
115
+ all_audio_features,
116
+ all_text_features,
117
+ all_audio_features_mlp,
118
+ all_text_features_mlp,
119
+ )
120
+ else:
121
+ return all_audio_features, all_text_features
122
+
123
+
124
+ class ClipLoss(nn.Module):
125
+ def __init__(
126
+ self,
127
+ local_loss=False,
128
+ gather_with_grad=False,
129
+ cache_labels=False,
130
+ rank=0,
131
+ world_size=1,
132
+ use_horovod=False,
133
+ mlp_loss=False,
134
+ weight_loss_kappa=0,
135
+ ):
136
+ super().__init__()
137
+ self.local_loss = local_loss
138
+ self.gather_with_grad = gather_with_grad
139
+ self.cache_labels = cache_labels
140
+ self.rank = rank
141
+ self.world_size = world_size
142
+ self.use_horovod = use_horovod
143
+ self.mlp_loss = mlp_loss
144
+ self.weighted_loss = bool(weight_loss_kappa != 0)
145
+ self.weight_loss_kappa = weight_loss_kappa
146
+ # cache state
147
+ self.prev_num_logits = 0
148
+ self.labels = {}
149
+
150
+ def forward(
151
+ self,
152
+ audio_features,
153
+ text_features,
154
+ logit_scale_a,
155
+ logit_scale_t=None,
156
+ audio_features_mlp=None,
157
+ text_features_mlp=None,
158
+ ):
159
+ device = audio_features.device
160
+ if self.mlp_loss:
161
+ if self.world_size > 1:
162
+ (
163
+ all_audio_features,
164
+ all_text_features,
165
+ all_audio_features_mlp,
166
+ all_text_features_mlp,
167
+ ) = gather_features(
168
+ audio_features=audio_features,
169
+ text_features=text_features,
170
+ audio_features_mlp=audio_features_mlp,
171
+ text_features_mlp=text_features_mlp,
172
+ local_loss=self.local_loss,
173
+ gather_with_grad=self.gather_with_grad,
174
+ rank=self.rank,
175
+ world_size=self.world_size,
176
+ use_horovod=self.use_horovod,
177
+ mlp_loss=self.mlp_loss,
178
+ )
179
+ if self.local_loss:
180
+ a_logits_per_audio = (
181
+ logit_scale_a * audio_features @ all_text_features_mlp.T
182
+ )
183
+ a_logits_per_text = (
184
+ logit_scale_a * text_features_mlp @ all_audio_features.T
185
+ )
186
+ t_logits_per_audio = (
187
+ logit_scale_t * audio_features_mlp @ all_text_features.T
188
+ )
189
+ t_logits_per_text = (
190
+ logit_scale_t * text_features @ all_audio_features_mlp.T
191
+ )
192
+ else:
193
+ a_logits_per_audio = (
194
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
195
+ )
196
+ a_logits_per_text = a_logits_per_audio.T
197
+ t_logits_per_audio = (
198
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
199
+ )
200
+ t_logits_per_text = t_logits_per_audio.T
201
+ else:
202
+ a_logits_per_audio = (
203
+ logit_scale_a * audio_features @ text_features_mlp.T
204
+ )
205
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
206
+ t_logits_per_audio = (
207
+ logit_scale_t * audio_features_mlp @ text_features.T
208
+ )
209
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
210
+
211
+ # calculate ground-truth labels and cache them if enabled
212
+ num_logits = a_logits_per_audio.shape[0]
213
+ if self.prev_num_logits != num_logits or device not in self.labels:
214
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
215
+ if self.world_size > 1 and self.local_loss:
216
+ labels = labels + num_logits * self.rank
217
+ if self.cache_labels:
218
+ self.labels[device] = labels
219
+ self.prev_num_logits = num_logits
220
+ else:
221
+ labels = self.labels[device]
222
+
223
+ if not self.weighted_loss:
224
+ total_loss = (
225
+ F.cross_entropy(a_logits_per_audio, labels)
226
+ + F.cross_entropy(a_logits_per_text, labels)
227
+ + F.cross_entropy(t_logits_per_audio, labels)
228
+ + F.cross_entropy(t_logits_per_text, labels)
229
+ ) / 4
230
+ else:
231
+ audio_weight = (audio_features @ audio_features.T).detach()
232
+ audio_weight = (
233
+ torch.exp(
234
+ torch.sum(audio_weight, axis=1)
235
+ / (self.weight_loss_kappa * len(audio_weight))
236
+ )
237
+ ).detach()
238
+ text_weight = (text_features @ text_features.T).detach()
239
+ text_weight = (
240
+ torch.exp(
241
+ torch.sum(text_weight, axis=1)
242
+ / (self.weight_loss_kappa * len(text_features))
243
+ )
244
+ ).detach()
245
+ total_loss = (
246
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
247
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
248
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
249
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
250
+ ) / 4
251
+ else:
252
+ if self.world_size > 1:
253
+ all_audio_features, all_text_features = gather_features(
254
+ audio_features=audio_features,
255
+ text_features=text_features,
256
+ local_loss=self.local_loss,
257
+ gather_with_grad=self.gather_with_grad,
258
+ rank=self.rank,
259
+ world_size=self.world_size,
260
+ use_horovod=self.use_horovod,
261
+ mlp_loss=self.mlp_loss,
262
+ )
263
+
264
+ if self.local_loss:
265
+ logits_per_audio = (
266
+ logit_scale_a * audio_features @ all_text_features.T
267
+ )
268
+ logits_per_text = (
269
+ logit_scale_a * text_features @ all_audio_features.T
270
+ )
271
+ else:
272
+ logits_per_audio = (
273
+ logit_scale_a * all_audio_features @ all_text_features.T
274
+ )
275
+ logits_per_text = logits_per_audio.T
276
+ else:
277
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
278
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
279
+
280
+ # calculate ground-truth labels and cache them if enabled
281
+ num_logits = logits_per_audio.shape[0]
282
+ if self.prev_num_logits != num_logits or device not in self.labels:
283
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
284
+ if self.world_size > 1 and self.local_loss:
285
+ labels = labels + num_logits * self.rank
286
+ if self.cache_labels:
287
+ self.labels[device] = labels
288
+ self.prev_num_logits = num_logits
289
+ else:
290
+ labels = self.labels[device]
291
+ if not self.weighted_loss:
292
+ total_loss = (
293
+ F.cross_entropy(logits_per_audio, labels)
294
+ + F.cross_entropy(logits_per_text, labels)
295
+ ) / 2
296
+ else:
297
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
298
+ audio_weight = (
299
+ torch.exp(
300
+ torch.sum(audio_weight, axis=1)
301
+ / (self.weight_loss_kappa * len(all_audio_features))
302
+ )
303
+ ).detach()
304
+ text_weight = (all_text_features @ all_text_features.T).detach()
305
+ text_weight = (
306
+ torch.exp(
307
+ torch.sum(text_weight, axis=1)
308
+ / (self.weight_loss_kappa * len(all_text_features))
309
+ )
310
+ ).detach()
311
+ total_loss = (
312
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
313
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
314
+ ) / 2
315
+ return total_loss
316
+
317
+
318
+ def lp_gather_features(pred, target, world_size=1, use_horovod=False):
319
+ if use_horovod:
320
+ assert hvd is not None, "Please install horovod"
321
+ with torch.no_grad():
322
+ all_preds = hvd.allgather(pred)
323
+ all_targets = hvd.allgather(target)
324
+ else:
325
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
326
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
327
+
328
+ dist.all_gather(gathered_preds, pred)
329
+ dist.all_gather(gathered_targets, target)
330
+ all_preds = torch.cat(gathered_preds, dim=0)
331
+ all_targets = torch.cat(gathered_targets, dim=0)
332
+
333
+ return all_preds, all_targets
334
+
335
+
336
+ def get_map(pred, target):
337
+ pred = torch.sigmoid(pred).numpy()
338
+ target = target.numpy()
339
+ return np.mean(average_precision_score(target, pred, average=None))
340
+
341
+
342
+ def get_acc(pred, target):
343
+ pred = torch.argmax(pred, 1).numpy()
344
+ target = torch.argmax(target, 1).numpy()
345
+ return accuracy_score(target, pred)
346
+
347
+
348
+ def get_mauc(pred, target):
349
+ pred = torch.sigmoid(pred).numpy()
350
+ target = target.numpy()
351
+ return np.mean(roc_auc_score(target, pred, average=None))
352
+
353
+
354
+ class LPMetrics(object):
355
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
356
+ self.metrics = []
357
+ for name in metric_names:
358
+ self.metrics.append(self.get_metric(name))
359
+ self.metric_names = metric_names
360
+
361
+ def get_metric(self, name):
362
+ if name == "map":
363
+ return get_map
364
+ elif name == "acc":
365
+ return get_acc
366
+ elif name == "mauc":
367
+ return get_mauc
368
+ else:
369
+ raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
370
+
371
+ def evaluate_mertics(self, pred, target):
372
+ metric_dict = {}
373
+ for i in range(len(self.metric_names)):
374
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
375
+ return metric_dict
376
+
377
+
378
+ def calc_celoss(pred, target):
379
+ target = torch.argmax(target, 1).long()
380
+ return nn.CrossEntropyLoss()(pred, target)
381
+
382
+
383
+ class LPLoss(nn.Module):
384
+ def __init__(self, loss_name):
385
+ super().__init__()
386
+ if loss_name == "bce":
387
+ self.loss_func = nn.BCEWithLogitsLoss()
388
+ elif loss_name == "ce":
389
+ self.loss_func = calc_celoss
390
+ elif loss_name == "mse":
391
+ self.loss_func = nn.MSELoss()
392
+ else:
393
+ raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
394
+
395
+ def forward(self, pred, target):
396
+ loss = self.loss_func(pred, target)
397
+ return loss
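
Before the next file diff: a minimal, single-process sketch of how the ClipLoss defined in the new loss.py above can be driven. The batch size, embedding dimension, and random features are illustrative assumptions, not taken from this repository's configs; only the class, its default arguments, and the 1/0.07 temperature (mirroring the logit_scale_a initialization in model.py below) come from this commit.

import torch
import torch.nn.functional as F

from audiosr.clap.open_clip.loss import ClipLoss

# Illustrative sizes; not taken from this repository's configs.
batch_size, embed_dim = 8, 512
audio_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)

# CLAP stores its logit scales as log(1 / 0.07) parameters and passes their exp() to the loss.
logit_scale_a = torch.tensor(1.0 / 0.07)

# world_size=1 keeps everything on one process, so gather_features is never called.
loss_fn = ClipLoss(local_loss=False, gather_with_grad=False, world_size=1)
loss = loss_fn(audio_features, text_features, logit_scale_a)
print(float(loss))
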
audiosr/clap/open_clip/model.py CHANGED
@@ -1,931 +1,931 @@
1
- """ CLAP Model
2
-
3
- Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- Adapted to the Audio Task.
5
- """
6
-
7
- from collections import OrderedDict
8
- from dataclasses import dataclass
9
- from typing import Tuple, Union, Callable, Optional
10
-
11
- import numpy as np
12
- import torch
13
- import torch.nn.functional as F
14
- from torch import nn
15
-
16
- import logging
17
- from .utils import freeze_batch_norm_2d
18
-
19
- from .pann_model import create_pann_model
20
- from .htsat import create_htsat_model
21
- from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
22
-
23
-
24
- class MLPLayers(nn.Module):
25
- def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
26
- super(MLPLayers, self).__init__()
27
- self.nonlin = nonlin
28
- self.dropout = dropout
29
-
30
- sequence = []
31
- for u0, u1 in zip(units[:-1], units[1:]):
32
- sequence.append(nn.Linear(u0, u1))
33
- sequence.append(self.nonlin)
34
- sequence.append(nn.Dropout(self.dropout))
35
- sequence = sequence[:-2]
36
-
37
- self.sequential = nn.Sequential(*sequence)
38
-
39
- def forward(self, X):
40
- X = self.sequential(X)
41
- return X
42
-
43
-
44
- class Bottleneck(nn.Module):
45
- expansion = 4
46
-
47
- def __init__(self, inplanes, planes, stride=1):
48
- super().__init__()
49
-
50
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
51
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
52
- self.bn1 = nn.BatchNorm2d(planes)
53
-
54
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
55
- self.bn2 = nn.BatchNorm2d(planes)
56
-
57
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
58
-
59
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
60
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61
-
62
- self.relu = nn.ReLU(inplace=True)
63
- self.downsample = None
64
- self.stride = stride
65
-
66
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
67
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
68
- self.downsample = nn.Sequential(
69
- OrderedDict(
70
- [
71
- ("-1", nn.AvgPool2d(stride)),
72
- (
73
- "0",
74
- nn.Conv2d(
75
- inplanes,
76
- planes * self.expansion,
77
- 1,
78
- stride=1,
79
- bias=False,
80
- ),
81
- ),
82
- ("1", nn.BatchNorm2d(planes * self.expansion)),
83
- ]
84
- )
85
- )
86
-
87
- def forward(self, x: torch.Tensor):
88
- identity = x
89
-
90
- out = self.relu(self.bn1(self.conv1(x)))
91
- out = self.relu(self.bn2(self.conv2(out)))
92
- out = self.avgpool(out)
93
- out = self.bn3(self.conv3(out))
94
-
95
- if self.downsample is not None:
96
- identity = self.downsample(x)
97
-
98
- out += identity
99
- out = self.relu(out)
100
- return out
101
-
102
-
103
- class AttentionPool2d(nn.Module):
104
- def __init__(
105
- self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
106
- ):
107
- super().__init__()
108
- self.positional_embedding = nn.Parameter(
109
- torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
110
- )
111
- self.k_proj = nn.Linear(embed_dim, embed_dim)
112
- self.q_proj = nn.Linear(embed_dim, embed_dim)
113
- self.v_proj = nn.Linear(embed_dim, embed_dim)
114
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
115
- self.num_heads = num_heads
116
-
117
- def forward(self, x):
118
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
119
- 2, 0, 1
120
- ) # NCHW -> (HW)NC
121
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
122
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
123
- x, _ = F.multi_head_attention_forward(
124
- query=x,
125
- key=x,
126
- value=x,
127
- embed_dim_to_check=x.shape[-1],
128
- num_heads=self.num_heads,
129
- q_proj_weight=self.q_proj.weight,
130
- k_proj_weight=self.k_proj.weight,
131
- v_proj_weight=self.v_proj.weight,
132
- in_proj_weight=None,
133
- in_proj_bias=torch.cat(
134
- [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
135
- ),
136
- bias_k=None,
137
- bias_v=None,
138
- add_zero_attn=False,
139
- dropout_p=0,
140
- out_proj_weight=self.c_proj.weight,
141
- out_proj_bias=self.c_proj.bias,
142
- use_separate_proj_weight=True,
143
- training=self.training,
144
- need_weights=False,
145
- )
146
-
147
- return x[0]
148
-
149
-
150
- class ModifiedResNet(nn.Module):
151
- """
152
- A ResNet class that is similar to torchvision's but contains the following changes:
153
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
154
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
155
- - The final pooling layer is a QKV attention instead of an average pool
156
- """
157
-
158
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
159
- super().__init__()
160
- self.output_dim = output_dim
161
- self.image_size = image_size
162
-
163
- # the 3-layer stem
164
- self.conv1 = nn.Conv2d(
165
- 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
166
- )
167
- self.bn1 = nn.BatchNorm2d(width // 2)
168
- self.conv2 = nn.Conv2d(
169
- width // 2, width // 2, kernel_size=3, padding=1, bias=False
170
- )
171
- self.bn2 = nn.BatchNorm2d(width // 2)
172
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
173
- self.bn3 = nn.BatchNorm2d(width)
174
- self.avgpool = nn.AvgPool2d(2)
175
- self.relu = nn.ReLU(inplace=True)
176
-
177
- # residual layers
178
- self._inplanes = width # this is a *mutable* variable used during construction
179
- self.layer1 = self._make_layer(width, layers[0])
180
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
181
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
182
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
183
-
184
- embed_dim = width * 32 # the ResNet feature dimension
185
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
186
-
187
- self.init_parameters()
188
-
189
- def _make_layer(self, planes, blocks, stride=1):
190
- layers = [Bottleneck(self._inplanes, planes, stride)]
191
-
192
- self._inplanes = planes * Bottleneck.expansion
193
- for _ in range(1, blocks):
194
- layers.append(Bottleneck(self._inplanes, planes))
195
-
196
- return nn.Sequential(*layers)
197
-
198
- def init_parameters(self):
199
- if self.attnpool is not None:
200
- std = self.attnpool.c_proj.in_features**-0.5
201
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
202
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
203
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
204
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
205
-
206
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
207
- for name, param in resnet_block.named_parameters():
208
- if name.endswith("bn3.weight"):
209
- nn.init.zeros_(param)
210
-
211
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
212
- assert (
213
- unlocked_groups == 0
214
- ), "partial locking not currently supported for this model"
215
- for param in self.parameters():
216
- param.requires_grad = False
217
- if freeze_bn_stats:
218
- freeze_batch_norm_2d(self)
219
-
220
- def stem(self, x):
221
- for conv, bn in [
222
- (self.conv1, self.bn1),
223
- (self.conv2, self.bn2),
224
- (self.conv3, self.bn3),
225
- ]:
226
- x = self.relu(bn(conv(x)))
227
- x = self.avgpool(x)
228
- return x
229
-
230
- def forward(self, x):
231
- x = self.stem(x)
232
- x = self.layer1(x)
233
- x = self.layer2(x)
234
- x = self.layer3(x)
235
- x = self.layer4(x)
236
- x = self.attnpool(x)
237
-
238
- return x
239
-
240
-
241
- class LayerNorm(nn.LayerNorm):
242
- """Subclass torch's LayerNorm to handle fp16."""
243
-
244
- def forward(self, x: torch.Tensor):
245
- orig_type = x.dtype
246
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
247
- return x.to(orig_type)
248
-
249
-
250
- class QuickGELU(nn.Module):
251
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
252
- def forward(self, x: torch.Tensor):
253
- return x * torch.sigmoid(1.702 * x)
254
-
255
-
256
- class ResidualAttentionBlock(nn.Module):
257
- def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
258
- super().__init__()
259
-
260
- self.attn = nn.MultiheadAttention(d_model, n_head)
261
- self.ln_1 = LayerNorm(d_model)
262
- self.mlp = nn.Sequential(
263
- OrderedDict(
264
- [
265
- ("c_fc", nn.Linear(d_model, d_model * 4)),
266
- ("gelu", act_layer()),
267
- ("c_proj", nn.Linear(d_model * 4, d_model)),
268
- ]
269
- )
270
- )
271
- self.ln_2 = LayerNorm(d_model)
272
-
273
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
274
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
275
-
276
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
- x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
278
- x = x + self.mlp(self.ln_2(x))
279
- return x
280
-
281
-
282
- class Transformer(nn.Module):
283
- def __init__(
284
- self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
285
- ):
286
- super().__init__()
287
- self.width = width
288
- self.layers = layers
289
- self.resblocks = nn.ModuleList(
290
- [
291
- ResidualAttentionBlock(width, heads, act_layer=act_layer)
292
- for _ in range(layers)
293
- ]
294
- )
295
-
296
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
297
- for r in self.resblocks:
298
- x = r(x, attn_mask=attn_mask)
299
- return x
300
-
301
-
302
- class VisualTransformer(nn.Module):
303
- def __init__(
304
- self,
305
- image_size: int,
306
- patch_size: int,
307
- width: int,
308
- layers: int,
309
- heads: int,
310
- output_dim: int,
311
- act_layer: Callable = nn.GELU,
312
- ):
313
- super().__init__()
314
- self.image_size = image_size
315
- self.output_dim = output_dim
316
- self.conv1 = nn.Conv2d(
317
- in_channels=3,
318
- out_channels=width,
319
- kernel_size=patch_size,
320
- stride=patch_size,
321
- bias=False,
322
- )
323
-
324
- scale = width**-0.5
325
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
326
- self.positional_embedding = nn.Parameter(
327
- scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
328
- )
329
- self.ln_pre = LayerNorm(width)
330
-
331
- self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
332
-
333
- self.ln_post = LayerNorm(width)
334
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
335
-
336
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
337
- assert (
338
- unlocked_groups == 0
339
- ), "partial locking not currently supported for this model"
340
- for param in self.parameters():
341
- param.requires_grad = False
342
-
343
- def forward(self, x: torch.Tensor):
344
- x = self.conv1(x) # shape = [*, width, grid, grid]
345
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
346
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
347
- x = torch.cat(
348
- [
349
- self.class_embedding.to(x.dtype)
350
- + torch.zeros(
351
- x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
352
- ),
353
- x,
354
- ],
355
- dim=1,
356
- ) # shape = [*, grid ** 2 + 1, width]
357
- x = x + self.positional_embedding.to(x.dtype)
358
- x = self.ln_pre(x)
359
-
360
- x = x.permute(1, 0, 2) # NLD -> LND
361
- x = self.text_branch(x)
362
- x = x.permute(1, 0, 2) # LND -> NLD
363
-
364
- x = self.ln_post(x[:, 0, :])
365
-
366
- if self.proj is not None:
367
- x = x @ self.proj
368
-
369
- return x
370
-
371
-
372
- @dataclass
373
- class CLAPVisionCfg:
374
- layers: Union[Tuple[int, int, int, int], int] = 12
375
- width: int = 768
376
- patch_size: int = 16
377
- image_size: Union[Tuple[int, int], int] = 224
378
- timm_model_name: str = (
379
- None # a valid model name overrides layers, width, patch_size
380
- )
381
- timm_model_pretrained: bool = (
382
- False # use (imagenet) pretrained weights for named model
383
- )
384
- timm_pool: str = (
385
- "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
386
- )
387
- timm_proj: str = (
388
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
389
- )
390
-
391
-
392
- # Audio Config Class
393
- @dataclass
394
- class CLAPAudioCfp:
395
- model_type: str = "PANN"
396
- model_name: str = "Cnn14"
397
- sample_rate: int = 48000
398
- # Param
399
- audio_length: int = 1024
400
- window_size: int = 1024
401
- hop_size: int = 1024
402
- fmin: int = 50
403
- fmax: int = 14000
404
- class_num: int = 527
405
- mel_bins: int = 64
406
- clip_samples: int = 480000
407
-
408
-
409
- @dataclass
410
- class CLAPTextCfg:
411
- context_length: int
412
- vocab_size: int
413
- width: int
414
- heads: int
415
- layers: int
416
- model_type: str
417
-
418
-
419
- class CLAP(nn.Module):
420
- def __init__(
421
- self,
422
- embed_dim: int,
423
- audio_cfg: CLAPAudioCfp,
424
- text_cfg: CLAPTextCfg,
425
- quick_gelu: bool = False,
426
- enable_fusion: bool = False,
427
- fusion_type: str = "None",
428
- joint_embed_shape: int = 512,
429
- mlp_act: str = "relu",
430
- ):
431
- super().__init__()
432
- if isinstance(audio_cfg, dict):
433
- audio_cfg = CLAPAudioCfp(**audio_cfg)
434
- if isinstance(text_cfg, dict):
435
- text_cfg = CLAPTextCfg(**text_cfg)
436
-
437
- self.audio_cfg = audio_cfg
438
- self.text_cfg = text_cfg
439
- self.enable_fusion = enable_fusion
440
- self.fusion_type = fusion_type
441
- self.joint_embed_shape = joint_embed_shape
442
- self.mlp_act = mlp_act
443
-
444
- self.context_length = text_cfg.context_length
445
-
446
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
447
- # memory efficient in recent PyTorch releases (>= 1.10).
448
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
449
- act_layer = QuickGELU if quick_gelu else nn.GELU
450
-
451
- if mlp_act == "relu":
452
- mlp_act_layer = nn.ReLU()
453
- elif mlp_act == "gelu":
454
- mlp_act_layer = nn.GELU()
455
- else:
456
- raise NotImplementedError
457
-
458
- # audio branch
459
- # audio branch parameters
460
- if audio_cfg.model_type == "PANN":
461
- self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
462
- elif audio_cfg.model_type == "HTSAT":
463
- self.audio_branch = create_htsat_model(
464
- audio_cfg, enable_fusion, fusion_type
465
- )
466
- else:
467
- logging.error(f"Model config for {audio_cfg.model_type} not found")
468
- raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
469
-
470
- # text branch
471
- # text branch parameters
472
- if text_cfg.model_type == "transformer":
473
- self.text_branch = Transformer(
474
- width=text_cfg.width,
475
- layers=text_cfg.layers,
476
- heads=text_cfg.heads,
477
- act_layer=act_layer,
478
- )
479
- self.vocab_size = text_cfg.vocab_size
480
- self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
481
- self.positional_embedding = nn.Parameter(
482
- torch.empty(self.context_length, text_cfg.width)
483
- )
484
- self.ln_final = LayerNorm(text_cfg.width)
485
- self.text_transform = MLPLayers(
486
- units=[
487
- self.joint_embed_shape,
488
- self.joint_embed_shape,
489
- self.joint_embed_shape,
490
- ],
491
- dropout=0.1,
492
- )
493
- self.text_projection = nn.Sequential(
494
- nn.Linear(text_cfg.width, self.joint_embed_shape),
495
- mlp_act_layer,
496
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
497
- )
498
- elif text_cfg.model_type == "bert":
499
- self.text_branch = BertModel.from_pretrained("bert-base-uncased")
500
- self.text_transform = MLPLayers(
501
- units=[
502
- self.joint_embed_shape,
503
- self.joint_embed_shape,
504
- self.joint_embed_shape,
505
- ],
506
- dropout=0.1,
507
- )
508
- self.text_projection = nn.Sequential(
509
- nn.Linear(768, self.joint_embed_shape),
510
- mlp_act_layer,
511
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
512
- )
513
- elif text_cfg.model_type == "roberta":
514
- self.text_branch = RobertaModel(
515
- RobertaConfig.from_pretrained("roberta-base")
516
- )
517
- self.text_transform = MLPLayers(
518
- units=[
519
- self.joint_embed_shape,
520
- self.joint_embed_shape,
521
- self.joint_embed_shape,
522
- ],
523
- dropout=0.1,
524
- )
525
- self.text_projection = nn.Sequential(
526
- nn.Linear(768, self.joint_embed_shape),
527
- mlp_act_layer,
528
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
529
- )
530
- elif text_cfg.model_type == "bart":
531
- self.text_branch = BartModel.from_pretrained("facebook/bart-base")
532
- self.text_transform = MLPLayers(
533
- units=[
534
- self.joint_embed_shape,
535
- self.joint_embed_shape,
536
- self.joint_embed_shape,
537
- ],
538
- dropout=0.1,
539
- )
540
- self.text_projection = nn.Sequential(
541
- nn.Linear(768, self.joint_embed_shape),
542
- mlp_act_layer,
543
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
544
- )
545
- else:
546
- logging.error(f"Model config for {text_cfg.model_type} not found")
547
- raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
548
- self.text_branch_type = text_cfg.model_type
549
- # text branch parameters
550
-
551
- # audio branch parameters
552
- self.audio_transform = MLPLayers(
553
- units=[
554
- self.joint_embed_shape,
555
- self.joint_embed_shape,
556
- self.joint_embed_shape,
557
- ],
558
- dropout=0.1,
559
- )
560
-
561
- # below here is text branch parameters
562
-
563
- # ============================================================================================================
564
- self.audio_projection = nn.Sequential(
565
- nn.Linear(embed_dim, self.joint_embed_shape),
566
- mlp_act_layer,
567
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
568
- )
569
-
570
- self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
571
- self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
- self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
573
-
574
- self.init_text_branch_parameters()
575
-
576
- def init_text_branch_parameters(self):
577
- if self.text_branch_type == "transformer":
578
- nn.init.normal_(self.token_embedding.weight, std=0.02)
579
- nn.init.normal_(self.positional_embedding, std=0.01)
580
- proj_std = (self.text_branch.width**-0.5) * (
581
- (2 * self.text_branch.layers) ** -0.5
582
- )
583
- attn_std = self.text_branch.width**-0.5
584
- fc_std = (2 * self.text_branch.width) ** -0.5
585
- for block in self.text_branch.resblocks:
586
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
587
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
588
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
589
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
590
- if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
591
- self.text_branch.embeddings.word_embeddings.weight.shape[-1]
592
- elif self.text_branch_type == "bart":
593
- self.text_branch.shared.weight.shape[-1]
594
- else:
595
- self.text_branch.width
596
- nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
597
- nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
598
-
599
- # deprecated
600
- # if hasattr(self.visual, 'init_parameters'):
601
- # self.visual.init_parameters()
602
-
603
- # if self.text_projection is not None:
604
- # nn.init.normal_(self.text_projection, std=width**-0.5)
605
-
606
- def build_attention_mask(self):
607
- # lazily create causal attention mask, with full attention between the vision tokens
608
- # pytorch uses additive attention mask; fill with -inf
609
- mask = torch.empty(self.context_length, self.context_length)
610
- mask.fill_(float("-inf"))
611
- mask.triu_(1) # zero out the lower diagonal
612
- return mask
613
-
614
- def encode_audio(self, audio, device):
615
- return self.audio_branch(
616
- audio, mixup_lambda=None, device=device
617
- ) # mix lambda needs to add
618
-
619
- # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
620
- # tmp = {}
621
- # for k in x[0].keys():
622
- # tmp[k] = []
623
- # for i in range(len(x)):
624
- # tmp[k].append(x[i][k][:77])
625
- # for k in x[0].keys():
626
- # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
627
- # return tmp
628
-
629
- def encode_text(self, text, device):
630
- if self.text_branch_type == "transformer":
631
- text = text.to(device=device, non_blocking=True)
632
- x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
633
-
634
- x = x + self.positional_embedding
635
- x = x.permute(1, 0, 2) # NLD -> LND
636
- x = self.text_branch(x, attn_mask=self.attn_mask)
637
- x = x.permute(1, 0, 2) # LND -> NLD
638
- x = self.ln_final(x)
639
-
640
- # x.shape = [batch_size, n_ctx, transformer.width]
641
- # take features from the eot embedding (eot_token is the highest number in each sequence)
642
- x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
643
- elif self.text_branch_type == "bert":
644
- # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
645
- # text = BatchEncoding(text)
646
- x = self.text_branch(
647
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
648
- attention_mask=text["attention_mask"].to(
649
- device=device, non_blocking=True
650
- ),
651
- token_type_ids=text["token_type_ids"].to(
652
- device=device, non_blocking=True
653
- ),
654
- )["pooler_output"]
655
- x = self.text_projection(x)
656
- elif self.text_branch_type == "roberta":
657
- x = self.text_branch(
658
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
659
- attention_mask=text["attention_mask"].to(
660
- device=device, non_blocking=True
661
- ),
662
- )["pooler_output"]
663
- x = self.text_projection(x)
664
- elif self.text_branch_type == "bart":
665
- x = torch.mean(
666
- self.text_branch(
667
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
668
- attention_mask=text["attention_mask"].to(
669
- device=device, non_blocking=True
670
- ),
671
- )["encoder_last_hidden_state"],
672
- axis=1,
673
- )
674
- x = self.text_projection(x)
675
- else:
676
- logging.error(f"Model type {self.text_branch_type} not found")
677
- raise RuntimeError(f"Model type {self.text_branch_type} not found.")
678
- return x
679
-
680
- def forward(self, audio, text, device=None):
681
- """Forward audio and text into the CLAP
682
-
683
- Parameters
684
- ----------
685
- audio: torch.Tensor (batch_size, audio_length)
686
- the time-domain audio input / the batch of mel_spec and longer list.
687
- text: torch.Tensor () // need to add
688
- the text token input
689
- """
690
- if device is None:
691
- if audio is not None:
692
- device = audio.device
693
- elif text is not None:
694
- device = text.device
695
- if audio is None and text is None:
696
- # a hack to get the logit scale
697
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
698
- elif audio is None:
699
- return self.encode_text(text, device=device)
700
- elif text is None:
701
- return self.audio_projection(
702
- self.encode_audio(audio, device=device)["embedding"]
703
- )
704
- audio_features = self.audio_projection(
705
- self.encode_audio(audio, device=device)["embedding"]
706
- )
707
- audio_features = F.normalize(audio_features, dim=-1)
708
-
709
- text_features = self.encode_text(text, device=device)
710
- # print("text_features", text_features)
711
- # print("text_features.shape", text_features.shape)
712
- # print("text_features.type", type(text_features))
713
- text_features = F.normalize(text_features, dim=-1)
714
-
715
- audio_features_mlp = self.audio_transform(audio_features)
716
- text_features_mlp = self.text_transform(text_features)
717
- # Four outputs: audio features (basic & MLP), text features (basic & MLP)
718
- return (
719
- audio_features,
720
- text_features,
721
- audio_features_mlp,
722
- text_features_mlp,
723
- self.logit_scale_a.exp(),
724
- self.logit_scale_t.exp(),
725
- )
726
-
727
- def get_logit_scale(self):
728
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
729
-
730
- def get_text_embedding(self, data):
731
- """Get the text embedding from the model
732
-
733
- Parameters
734
- ----------
735
- data: torch.Tensor
736
- a tensor of text embedding
737
-
738
- Returns
739
- ----------
740
- text_embed: torch.Tensor
741
- a tensor of text_embeds (N, D)
742
-
743
- """
744
- device = next(self.parameters()).device
745
- for k in data:
746
- data[k] = data[k].to(device)
747
- text_embeds = self.encode_text(data, device=device)
748
- text_embeds = F.normalize(text_embeds, dim=-1)
749
-
750
- return text_embeds
751
-
752
- def get_audio_embedding(self, data):
753
- """Get the audio embedding from the model
754
-
755
- Parameters
756
- ----------
757
- data: a list of dict
758
- the audio input dict list from 'get_audio_feature' method
759
-
760
- Returns
761
- ----------
762
- audio_embed: torch.Tensor
763
- a tensor of audio_embeds (N, D)
764
-
765
- """
766
- device = next(self.parameters()).device
767
- # input_dict = {}
768
- # keys = data[0].keys()
769
- # for k in keys:
770
- # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
771
- # device
772
- # )
773
- audio_embeds = self.audio_projection(
774
- self.encode_audio(data, device=device)["embedding"]
775
- )
776
- audio_embeds = F.normalize(audio_embeds, dim=-1)
777
-
778
- return audio_embeds
779
-
780
- def audio_infer(self, audio, hopsize=None, device=None):
781
- """Forward one audio and produce the audio embedding
782
-
783
- Parameters
784
- ----------
785
- audio: (audio_length)
786
- the time-domain audio input, notice that it must be only one input
787
- hopsize: int
788
- the overlap hopsize as the sliding window
789
-
790
- Returns
791
- ----------
792
- output_dict: {
793
- key: [n, (embedding_shape)] if "HTS-AT"
794
- or
795
- key: [(embedding_shape)] if "PANN"
796
- }
797
- the list of key values of the audio branch
798
-
799
- """
800
-
801
- assert not self.training, "the inference mode must be run at eval stage"
802
- output_dict = {}
803
- # PANN
804
- if self.audio_cfg.model_type == "PANN":
805
- audio_input = audio.unsqueeze(dim=0)
806
- output_dict[key] = self.encode_audio(audio_input, device=device)[
807
- key
808
- ].squeeze(dim=0)
809
- elif self.audio_cfg.model_type == "HTSAT":
810
- # repeat
811
- audio_len = len(audio)
812
- k = self.audio_cfg.clip_samples // audio_len
813
- if k > 1:
814
- audio = audio.repeat(k)
815
- audio_len = len(audio)
816
-
817
- if hopsize is None:
818
- hopsize = min(hopsize, audio_len)
819
-
820
- if audio_len > self.audio_cfg.clip_samples:
821
- audio_input = [
822
- audio[pos : pos + self.audio_cfg.clip_samples].clone()
823
- for pos in range(
824
- 0, audio_len - self.audio_cfg.clip_samples, hopsize
825
- )
826
- ]
827
- audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
828
- audio_input = torch.stack(audio_input)
829
- output_dict[key] = self.encode_audio(audio_input, device=device)[key]
830
- else:
831
- audio_input = audio.unsqueeze(dim=0)
832
- output_dict[key] = self.encode_audio(audio_input, device=device)[
833
- key
834
- ].squeeze(dim=0)
835
-
836
- return output_dict
837
-
838
-
839
- def convert_weights_to_fp16(model: nn.Module):
840
- """Convert applicable model parameters to fp16"""
841
-
842
- def _convert_weights_to_fp16(l):
843
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
844
- l.weight.data = l.weight.data.half()
845
- if l.bias is not None:
846
- l.bias.data = l.bias.data.half()
847
-
848
- if isinstance(l, nn.MultiheadAttention):
849
- for attr in [
850
- *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
851
- "in_proj_bias",
852
- "bias_k",
853
- "bias_v",
854
- ]:
855
- tensor = getattr(l, attr)
856
- if tensor is not None:
857
- tensor.data = tensor.data.half()
858
-
859
- for name in ["text_projection", "proj"]:
860
- if hasattr(l, name):
861
- attr = getattr(l, name)
862
- if attr is not None:
863
- attr.data = attr.data.half()
864
-
865
- model.apply(_convert_weights_to_fp16)
866
-
867
-
868
- # Ignore the state dict of the vision part
869
- def build_model_from_openai_state_dict(
870
- state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
871
- ):
872
- embed_dim = model_cfg["embed_dim"]
873
- audio_cfg = model_cfg["audio_cfg"]
874
- text_cfg = model_cfg["text_cfg"]
875
- state_dict["positional_embedding"].shape[0]
876
- state_dict["token_embedding.weight"].shape[0]
877
- transformer_width = state_dict["ln_final.weight"].shape[0]
878
- transformer_width // 64
879
- transformer_layers = len(
880
- set(
881
- k.split(".")[2]
882
- for k in state_dict
883
- if k.startswith(f"transformer.resblocks")
884
- )
885
- )
886
-
887
- audio_cfg = CLAPAudioCfp(**audio_cfg)
888
- text_cfg = CLAPTextCfg(**text_cfg)
889
-
890
- model = CLAP(
891
- embed_dim,
892
- audio_cfg=audio_cfg,
893
- text_cfg=text_cfg,
894
- quick_gelu=True, # OpenAI models were trained with QuickGELU
895
- enable_fusion=enable_fusion,
896
- fusion_type=fusion_type,
897
- )
898
- state_dict["logit_scale_a"] = state_dict["logit_scale"]
899
- state_dict["logit_scale_t"] = state_dict["logit_scale"]
900
- pop_keys = list(state_dict.keys())[::]
901
- # pop the visual branch saved weights
902
- for key in pop_keys:
903
- if key.startswith("visual."):
904
- state_dict.pop(key, None)
905
-
906
- for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
907
- state_dict.pop(key, None)
908
-
909
- # not use fp16
910
- # convert_weights_to_fp16(model)
911
- model.load_state_dict(state_dict, strict=False)
912
- return model.eval()
913
-
914
-
915
- def trace_model(model, batch_size=256, device=torch.device("cpu")):
916
- model.eval()
917
- audio_length = model.audio_cfg.audio_length
918
- example_audio = torch.ones((batch_size, audio_length), device=device)
919
- example_text = torch.zeros(
920
- (batch_size, model.context_length), dtype=torch.int, device=device
921
- )
922
- model = torch.jit.trace_module(
923
- model,
924
- inputs=dict(
925
- forward=(example_audio, example_text),
926
- encode_text=(example_text,),
927
- encode_image=(example_audio,),
928
- ),
929
- )
930
- model.audio_cfg.audio_length = audio_length # Question: what does this do?
931
- return model
 
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from typing import Tuple, Union, Callable, Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ import logging
17
+ from .utils import freeze_batch_norm_2d
18
+
19
+ from .pann_model import create_pann_model
20
+ from .htsat import create_htsat_model
21
+ from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
22
+
23
+
24
+ class MLPLayers(nn.Module):
25
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
26
+ super(MLPLayers, self).__init__()
27
+ self.nonlin = nonlin
28
+ self.dropout = dropout
29
+
30
+ sequence = []
31
+ for u0, u1 in zip(units[:-1], units[1:]):
32
+ sequence.append(nn.Linear(u0, u1))
33
+ sequence.append(self.nonlin)
34
+ sequence.append(nn.Dropout(self.dropout))
35
+ sequence = sequence[:-2]
36
+
37
+ self.sequential = nn.Sequential(*sequence)
38
+
39
+ def forward(self, X):
40
+ X = self.sequential(X)
41
+ return X
42
+
43
+
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, inplanes, planes, stride=1):
48
+ super().__init__()
49
+
50
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
51
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
52
+ self.bn1 = nn.BatchNorm2d(planes)
53
+
54
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
55
+ self.bn2 = nn.BatchNorm2d(planes)
56
+
57
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
58
+
59
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
60
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61
+
62
+ self.relu = nn.ReLU(inplace=True)
63
+ self.downsample = None
64
+ self.stride = stride
65
+
66
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
67
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
68
+ self.downsample = nn.Sequential(
69
+ OrderedDict(
70
+ [
71
+ ("-1", nn.AvgPool2d(stride)),
72
+ (
73
+ "0",
74
+ nn.Conv2d(
75
+ inplanes,
76
+ planes * self.expansion,
77
+ 1,
78
+ stride=1,
79
+ bias=False,
80
+ ),
81
+ ),
82
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
83
+ ]
84
+ )
85
+ )
86
+
87
+ def forward(self, x: torch.Tensor):
88
+ identity = x
89
+
90
+ out = self.relu(self.bn1(self.conv1(x)))
91
+ out = self.relu(self.bn2(self.conv2(out)))
92
+ out = self.avgpool(out)
93
+ out = self.bn3(self.conv3(out))
94
+
95
+ if self.downsample is not None:
96
+ identity = self.downsample(x)
97
+
98
+ out += identity
99
+ out = self.relu(out)
100
+ return out
101
+
102
+
103
+ class AttentionPool2d(nn.Module):
104
+ def __init__(
105
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
106
+ ):
107
+ super().__init__()
108
+ self.positional_embedding = nn.Parameter(
109
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
110
+ )
111
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
112
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
113
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
114
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
115
+ self.num_heads = num_heads
116
+
117
+ def forward(self, x):
118
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
119
+ 2, 0, 1
120
+ ) # NCHW -> (HW)NC
121
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
122
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
123
+ x, _ = F.multi_head_attention_forward(
124
+ query=x,
125
+ key=x,
126
+ value=x,
127
+ embed_dim_to_check=x.shape[-1],
128
+ num_heads=self.num_heads,
129
+ q_proj_weight=self.q_proj.weight,
130
+ k_proj_weight=self.k_proj.weight,
131
+ v_proj_weight=self.v_proj.weight,
132
+ in_proj_weight=None,
133
+ in_proj_bias=torch.cat(
134
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
135
+ ),
136
+ bias_k=None,
137
+ bias_v=None,
138
+ add_zero_attn=False,
139
+ dropout_p=0,
140
+ out_proj_weight=self.c_proj.weight,
141
+ out_proj_bias=self.c_proj.bias,
142
+ use_separate_proj_weight=True,
143
+ training=self.training,
144
+ need_weights=False,
145
+ )
146
+
147
+ return x[0]
148
+
149
+
150
+ class ModifiedResNet(nn.Module):
151
+ """
152
+ A ResNet class that is similar to torchvision's but contains the following changes:
153
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
154
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
155
+ - The final pooling layer is a QKV attention instead of an average pool
156
+ """
157
+
158
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
159
+ super().__init__()
160
+ self.output_dim = output_dim
161
+ self.image_size = image_size
162
+
163
+ # the 3-layer stem
164
+ self.conv1 = nn.Conv2d(
165
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
166
+ )
167
+ self.bn1 = nn.BatchNorm2d(width // 2)
168
+ self.conv2 = nn.Conv2d(
169
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
170
+ )
171
+ self.bn2 = nn.BatchNorm2d(width // 2)
172
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
173
+ self.bn3 = nn.BatchNorm2d(width)
174
+ self.avgpool = nn.AvgPool2d(2)
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ # residual layers
178
+ self._inplanes = width # this is a *mutable* variable used during construction
179
+ self.layer1 = self._make_layer(width, layers[0])
180
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
181
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
182
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
183
+
184
+ embed_dim = width * 32 # the ResNet feature dimension
185
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
186
+
187
+ self.init_parameters()
188
+
189
+ def _make_layer(self, planes, blocks, stride=1):
190
+ layers = [Bottleneck(self._inplanes, planes, stride)]
191
+
192
+ self._inplanes = planes * Bottleneck.expansion
193
+ for _ in range(1, blocks):
194
+ layers.append(Bottleneck(self._inplanes, planes))
195
+
196
+ return nn.Sequential(*layers)
197
+
198
+ def init_parameters(self):
199
+ if self.attnpool is not None:
200
+ std = self.attnpool.c_proj.in_features**-0.5
201
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
202
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
203
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
204
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
205
+
206
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
207
+ for name, param in resnet_block.named_parameters():
208
+ if name.endswith("bn3.weight"):
209
+ nn.init.zeros_(param)
210
+
211
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
212
+ assert (
213
+ unlocked_groups == 0
214
+ ), "partial locking not currently supported for this model"
215
+ for param in self.parameters():
216
+ param.requires_grad = False
217
+ if freeze_bn_stats:
218
+ freeze_batch_norm_2d(self)
219
+
220
+ def stem(self, x):
221
+ for conv, bn in [
222
+ (self.conv1, self.bn1),
223
+ (self.conv2, self.bn2),
224
+ (self.conv3, self.bn3),
225
+ ]:
226
+ x = self.relu(bn(conv(x)))
227
+ x = self.avgpool(x)
228
+ return x
229
+
230
+ def forward(self, x):
231
+ x = self.stem(x)
232
+ x = self.layer1(x)
233
+ x = self.layer2(x)
234
+ x = self.layer3(x)
235
+ x = self.layer4(x)
236
+ x = self.attnpool(x)
237
+
238
+ return x
239
+
240
+
241
+ class LayerNorm(nn.LayerNorm):
242
+ """Subclass torch's LayerNorm to handle fp16."""
243
+
244
+ def forward(self, x: torch.Tensor):
245
+ orig_type = x.dtype
246
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
247
+ return x.to(orig_type)
248
+
249
+
250
+ class QuickGELU(nn.Module):
251
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
252
+ def forward(self, x: torch.Tensor):
253
+ return x * torch.sigmoid(1.702 * x)
254
+
255
+
256
+ class ResidualAttentionBlock(nn.Module):
257
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
258
+ super().__init__()
259
+
260
+ self.attn = nn.MultiheadAttention(d_model, n_head)
261
+ self.ln_1 = LayerNorm(d_model)
262
+ self.mlp = nn.Sequential(
263
+ OrderedDict(
264
+ [
265
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
266
+ ("gelu", act_layer()),
267
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
268
+ ]
269
+ )
270
+ )
271
+ self.ln_2 = LayerNorm(d_model)
272
+
273
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
274
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
275
+
276
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
278
+ x = x + self.mlp(self.ln_2(x))
279
+ return x
280
+
281
+
282
+ class Transformer(nn.Module):
283
+ def __init__(
284
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
285
+ ):
286
+ super().__init__()
287
+ self.width = width
288
+ self.layers = layers
289
+ self.resblocks = nn.ModuleList(
290
+ [
291
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
292
+ for _ in range(layers)
293
+ ]
294
+ )
295
+
296
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
297
+ for r in self.resblocks:
298
+ x = r(x, attn_mask=attn_mask)
299
+ return x
300
+
301
+
302
+ class VisualTransformer(nn.Module):
303
+ def __init__(
304
+ self,
305
+ image_size: int,
306
+ patch_size: int,
307
+ width: int,
308
+ layers: int,
309
+ heads: int,
310
+ output_dim: int,
311
+ act_layer: Callable = nn.GELU,
312
+ ):
313
+ super().__init__()
314
+ self.image_size = image_size
315
+ self.output_dim = output_dim
316
+ self.conv1 = nn.Conv2d(
317
+ in_channels=3,
318
+ out_channels=width,
319
+ kernel_size=patch_size,
320
+ stride=patch_size,
321
+ bias=False,
322
+ )
323
+
324
+ scale = width**-0.5
325
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
326
+ self.positional_embedding = nn.Parameter(
327
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
328
+ )
329
+ self.ln_pre = LayerNorm(width)
330
+
331
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
332
+
333
+ self.ln_post = LayerNorm(width)
334
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
335
+
336
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
337
+ assert (
338
+ unlocked_groups == 0
339
+ ), "partial locking not currently supported for this model"
340
+ for param in self.parameters():
341
+ param.requires_grad = False
342
+
343
+ def forward(self, x: torch.Tensor):
344
+ x = self.conv1(x) # shape = [*, width, grid, grid]
345
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
346
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
347
+ x = torch.cat(
348
+ [
349
+ self.class_embedding.to(x.dtype)
350
+ + torch.zeros(
351
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
352
+ ),
353
+ x,
354
+ ],
355
+ dim=1,
356
+ ) # shape = [*, grid ** 2 + 1, width]
357
+ x = x + self.positional_embedding.to(x.dtype)
358
+ x = self.ln_pre(x)
359
+
360
+ x = x.permute(1, 0, 2) # NLD -> LND
361
+ x = self.text_branch(x)
362
+ x = x.permute(1, 0, 2) # LND -> NLD
363
+
364
+ x = self.ln_post(x[:, 0, :])
365
+
366
+ if self.proj is not None:
367
+ x = x @ self.proj
368
+
369
+ return x
370
+
371
+
372
+ @dataclass
373
+ class CLAPVisionCfg:
374
+ layers: Union[Tuple[int, int, int, int], int] = 12
375
+ width: int = 768
376
+ patch_size: int = 16
377
+ image_size: Union[Tuple[int, int], int] = 224
378
+ timm_model_name: str = (
379
+ None # a valid model name overrides layers, width, patch_size
380
+ )
381
+ timm_model_pretrained: bool = (
382
+ False # use (imagenet) pretrained weights for named model
383
+ )
384
+ timm_pool: str = (
385
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
386
+ )
387
+ timm_proj: str = (
388
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
389
+ )
390
+
391
+
392
+ # Audio Config Class
393
+ @dataclass
394
+ class CLAPAudioCfp:
395
+ model_type: str = "PANN"
396
+ model_name: str = "Cnn14"
397
+ sample_rate: int = 48000
398
+ # Param
399
+ audio_length: int = 1024
400
+ window_size: int = 1024
401
+ hop_size: int = 1024
402
+ fmin: int = 50
403
+ fmax: int = 14000
404
+ class_num: int = 527
405
+ mel_bins: int = 64
406
+ clip_samples: int = 480000
407
+
408
+
409
+ @dataclass
410
+ class CLAPTextCfg:
411
+ context_length: int
412
+ vocab_size: int
413
+ width: int
414
+ heads: int
415
+ layers: int
416
+ model_type: str
417
+
418
+
419
+ class CLAP(nn.Module):
420
+ def __init__(
421
+ self,
422
+ embed_dim: int,
423
+ audio_cfg: CLAPAudioCfp,
424
+ text_cfg: CLAPTextCfg,
425
+ quick_gelu: bool = False,
426
+ enable_fusion: bool = False,
427
+ fusion_type: str = "None",
428
+ joint_embed_shape: int = 512,
429
+ mlp_act: str = "relu",
430
+ ):
431
+ super().__init__()
432
+ if isinstance(audio_cfg, dict):
433
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
434
+ if isinstance(text_cfg, dict):
435
+ text_cfg = CLAPTextCfg(**text_cfg)
436
+
437
+ self.audio_cfg = audio_cfg
438
+ self.text_cfg = text_cfg
439
+ self.enable_fusion = enable_fusion
440
+ self.fusion_type = fusion_type
441
+ self.joint_embed_shape = joint_embed_shape
442
+ self.mlp_act = mlp_act
443
+
444
+ self.context_length = text_cfg.context_length
445
+
446
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
447
+ # memory efficient in recent PyTorch releases (>= 1.10).
448
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
449
+ act_layer = QuickGELU if quick_gelu else nn.GELU
450
+
451
+ if mlp_act == "relu":
452
+ mlp_act_layer = nn.ReLU()
453
+ elif mlp_act == "gelu":
454
+ mlp_act_layer = nn.GELU()
455
+ else:
456
+ raise NotImplementedError
457
+
458
+ # audio branch
459
+ # audio branch parameters
460
+ if audio_cfg.model_type == "PANN":
461
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
462
+ elif audio_cfg.model_type == "HTSAT":
463
+ self.audio_branch = create_htsat_model(
464
+ audio_cfg, enable_fusion, fusion_type
465
+ )
466
+ else:
467
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
468
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
469
+
470
+ # text branch
471
+ # text branch parameters
472
+ if text_cfg.model_type == "transformer":
473
+ self.text_branch = Transformer(
474
+ width=text_cfg.width,
475
+ layers=text_cfg.layers,
476
+ heads=text_cfg.heads,
477
+ act_layer=act_layer,
478
+ )
479
+ self.vocab_size = text_cfg.vocab_size
480
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
481
+ self.positional_embedding = nn.Parameter(
482
+ torch.empty(self.context_length, text_cfg.width)
483
+ )
484
+ self.ln_final = LayerNorm(text_cfg.width)
485
+ self.text_transform = MLPLayers(
486
+ units=[
487
+ self.joint_embed_shape,
488
+ self.joint_embed_shape,
489
+ self.joint_embed_shape,
490
+ ],
491
+ dropout=0.1,
492
+ )
493
+ self.text_projection = nn.Sequential(
494
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
495
+ mlp_act_layer,
496
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
497
+ )
498
+ elif text_cfg.model_type == "bert":
499
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
500
+ self.text_transform = MLPLayers(
501
+ units=[
502
+ self.joint_embed_shape,
503
+ self.joint_embed_shape,
504
+ self.joint_embed_shape,
505
+ ],
506
+ dropout=0.1,
507
+ )
508
+ self.text_projection = nn.Sequential(
509
+ nn.Linear(768, self.joint_embed_shape),
510
+ mlp_act_layer,
511
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
512
+ )
513
+ elif text_cfg.model_type == "roberta":
514
+ self.text_branch = RobertaModel(
515
+ RobertaConfig.from_pretrained("roberta-base")
516
+ )
517
+ self.text_transform = MLPLayers(
518
+ units=[
519
+ self.joint_embed_shape,
520
+ self.joint_embed_shape,
521
+ self.joint_embed_shape,
522
+ ],
523
+ dropout=0.1,
524
+ )
525
+ self.text_projection = nn.Sequential(
526
+ nn.Linear(768, self.joint_embed_shape),
527
+ mlp_act_layer,
528
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
529
+ )
530
+ elif text_cfg.model_type == "bart":
531
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
532
+ self.text_transform = MLPLayers(
533
+ units=[
534
+ self.joint_embed_shape,
535
+ self.joint_embed_shape,
536
+ self.joint_embed_shape,
537
+ ],
538
+ dropout=0.1,
539
+ )
540
+ self.text_projection = nn.Sequential(
541
+ nn.Linear(768, self.joint_embed_shape),
542
+ mlp_act_layer,
543
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
544
+ )
545
+ else:
546
+ logging.error(f"Model config for {text_cfg.model_type} not found")
547
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
548
+ self.text_branch_type = text_cfg.model_type
549
+ # text branch parameters
550
+
551
+ # audio branch parameters
552
+ self.audio_transform = MLPLayers(
553
+ units=[
554
+ self.joint_embed_shape,
555
+ self.joint_embed_shape,
556
+ self.joint_embed_shape,
557
+ ],
558
+ dropout=0.1,
559
+ )
560
+
561
+ # audio projection head, logit scales and the text attention mask
562
+
563
+ # ============================================================================================================
564
+ self.audio_projection = nn.Sequential(
565
+ nn.Linear(embed_dim, self.joint_embed_shape),
566
+ mlp_act_layer,
567
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
568
+ )
569
+
570
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
571
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
573
+
574
+ self.init_text_branch_parameters()
575
+
576
+ def init_text_branch_parameters(self):
577
+ if self.text_branch_type == "transformer":
578
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
579
+ nn.init.normal_(self.positional_embedding, std=0.01)
580
+ proj_std = (self.text_branch.width**-0.5) * (
581
+ (2 * self.text_branch.layers) ** -0.5
582
+ )
583
+ attn_std = self.text_branch.width**-0.5
584
+ fc_std = (2 * self.text_branch.width) ** -0.5
585
+ for block in self.text_branch.resblocks:
586
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
587
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
588
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
589
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
590
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
591
+ self.text_branch.embeddings.word_embeddings.weight.shape[-1]
592
+ elif self.text_branch_type == "bart":
593
+ self.text_branch.shared.weight.shape[-1]
594
+ else:
595
+ self.text_branch.width
596
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
597
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
598
+
599
+ # deprecated
600
+ # if hasattr(self.visual, 'init_parameters'):
601
+ # self.visual.init_parameters()
602
+
603
+ # if self.text_projection is not None:
604
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
605
+
606
+ def build_attention_mask(self):
607
+ # lazily create the causal attention mask for the text transformer
608
+ # pytorch uses additive attention mask; fill with -inf
609
+ mask = torch.empty(self.context_length, self.context_length)
610
+ mask.fill_(float("-inf"))
611
+ mask.triu_(1) # zero out the lower diagonal
612
+ return mask
613
+
614
+ def encode_audio(self, audio, device):
615
+ return self.audio_branch(
616
+ audio, mixup_lambda=None, device=device
617
+ ) # TODO: mixup_lambda support still needs to be added here
618
+
619
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
620
+ # tmp = {}
621
+ # for k in x[0].keys():
622
+ # tmp[k] = []
623
+ # for i in range(len(x)):
624
+ # tmp[k].append(x[i][k][:77])
625
+ # for k in x[0].keys():
626
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
627
+ # return tmp
628
+
629
+ def encode_text(self, text, device):
630
+ if self.text_branch_type == "transformer":
631
+ text = text.to(device=device, non_blocking=True)
632
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
633
+
634
+ x = x + self.positional_embedding
635
+ x = x.permute(1, 0, 2) # NLD -> LND
636
+ x = self.text_branch(x, attn_mask=self.attn_mask)
637
+ x = x.permute(1, 0, 2) # LND -> NLD
638
+ x = self.ln_final(x)
639
+
640
+ # x.shape = [batch_size, n_ctx, transformer.width]
641
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
642
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
643
+ elif self.text_branch_type == "bert":
644
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
645
+ # text = BatchEncoding(text)
646
+ x = self.text_branch(
647
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
648
+ attention_mask=text["attention_mask"].to(
649
+ device=device, non_blocking=True
650
+ ),
651
+ token_type_ids=text["token_type_ids"].to(
652
+ device=device, non_blocking=True
653
+ ),
654
+ )["pooler_output"]
655
+ x = self.text_projection(x)
656
+ elif self.text_branch_type == "roberta":
657
+ x = self.text_branch(
658
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
659
+ attention_mask=text["attention_mask"].to(
660
+ device=device, non_blocking=True
661
+ ),
662
+ )["pooler_output"]
663
+ x = self.text_projection(x)
664
+ elif self.text_branch_type == "bart":
665
+ x = torch.mean(
666
+ self.text_branch(
667
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
668
+ attention_mask=text["attention_mask"].to(
669
+ device=device, non_blocking=True
670
+ ),
671
+ )["encoder_last_hidden_state"],
672
+ axis=1,
673
+ )
674
+ x = self.text_projection(x)
675
+ else:
676
+ logging.error(f"Model type {self.text_branch_type} not found")
677
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
678
+ return x
679
+
680
+ def forward(self, audio, text, device=None):
681
+ """Forward audio and text into the CLAP
682
+
683
+ Parameters
684
+ ----------
685
+ audio: torch.Tensor (batch_size, audio_length)
686
+ the time-domain audio input, or the batch of mel_fusion spectrograms with their "longer" flags when fusion is enabled.
687
+ text: torch.Tensor or dict of tensors
688
+ the tokenized text input: a token tensor for the transformer branch, or a tokenizer output dict for the BERT/RoBERTa/BART branches
689
+ """
690
+ if device is None:
691
+ if audio is not None:
692
+ device = audio.device
693
+ elif text is not None:
694
+ device = text.device
695
+ if audio is None and text is None:
696
+ # a hack to get the logit scale
697
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
698
+ elif audio is None:
699
+ return self.encode_text(text, device=device)
700
+ elif text is None:
701
+ return self.audio_projection(
702
+ self.encode_audio(audio, device=device)["embedding"]
703
+ )
704
+ audio_features = self.audio_projection(
705
+ self.encode_audio(audio, device=device)["embedding"]
706
+ )
707
+ audio_features = F.normalize(audio_features, dim=-1)
708
+
709
+ text_features = self.encode_text(text, device=device)
710
+ # print("text_features", text_features)
711
+ # print("text_features.shape", text_features.shape)
712
+ # print("text_features.type", type(text_features))
713
+ text_features = F.normalize(text_features, dim=-1)
714
+
715
+ audio_features_mlp = self.audio_transform(audio_features)
716
+ text_features_mlp = self.text_transform(text_features)
717
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
718
+ return (
719
+ audio_features,
720
+ text_features,
721
+ audio_features_mlp,
722
+ text_features_mlp,
723
+ self.logit_scale_a.exp(),
724
+ self.logit_scale_t.exp(),
725
+ )
726
+
727
+ def get_logit_scale(self):
728
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
729
+
730
+ def get_text_embedding(self, data):
731
+ """Get the text embedding from the model
732
+
733
+ Parameters
734
+ ----------
735
+ data: torch.Tensor
736
+ a tensor of text embedding
737
+
738
+ Returns
739
+ ----------
740
+ text_embed: torch.Tensor
741
+ a tensor of text_embeds (N, D)
742
+
743
+ """
744
+ device = next(self.parameters()).device
745
+ for k in data:
746
+ data[k] = data[k].to(device)
747
+ text_embeds = self.encode_text(data, device=device)
748
+ text_embeds = F.normalize(text_embeds, dim=-1)
749
+
750
+ return text_embeds
751
+
752
+ def get_audio_embedding(self, data):
753
+ """Get the audio embedding from the model
754
+
755
+ Parameters
756
+ ----------
757
+ data: a list of dict
758
+ the audio input dict list from 'get_audio_feature' method
759
+
760
+ Returns
761
+ ----------
762
+ audio_embed: torch.Tensor
763
+ a tensor of audio_embeds (N, D)
764
+
765
+ """
766
+ device = next(self.parameters()).device
767
+ # input_dict = {}
768
+ # keys = data[0].keys()
769
+ # for k in keys:
770
+ # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
771
+ # device
772
+ # )
773
+ audio_embeds = self.audio_projection(
774
+ self.encode_audio(data, device=device)["embedding"]
775
+ )
776
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
777
+
778
+ return audio_embeds
779
+
780
+ def audio_infer(self, audio, hopsize=None, device=None):
781
+ """Forward one audio and produce the audio embedding
782
+
783
+ Parameters
784
+ ----------
785
+ audio: (audio_length)
786
+ the time-domain audio input, notice that it must be only one input
787
+ hopsize: int
788
+ the hop size, in samples, of the sliding window used when the audio is longer than one clip
789
+
790
+ Returns
791
+ ----------
792
+ output_dict: {
793
+ key: [n, (embedding_shape)] if "HTS-AT"
794
+ or
795
+ key: [(embedding_shape)] if "PANN"
796
+ }
797
+ the outputs of the audio branch, keyed by output name
798
+
799
+ """
800
+
801
+ assert not self.training, "the inference mode must be run at eval stage"
802
+ output_dict = {}
803
+ # PANN
804
+ if self.audio_cfg.model_type == "PANN":
805
+ audio_input = audio.unsqueeze(dim=0)
806
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
807
+ key
808
+ ].squeeze(dim=0)
809
+ elif self.audio_cfg.model_type == "HTSAT":
810
+ # repeat
811
+ audio_len = len(audio)
812
+ k = self.audio_cfg.clip_samples // audio_len
813
+ if k > 1:
814
+ audio = audio.repeat(k)
815
+ audio_len = len(audio)
816
+
817
+ if hopsize is None:
818
+ hopsize = min(hopsize, audio_len)
819
+
820
+ if audio_len > self.audio_cfg.clip_samples:
821
+ audio_input = [
822
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
823
+ for pos in range(
824
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
825
+ )
826
+ ]
827
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
828
+ audio_input = torch.stack(audio_input)
829
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
830
+ else:
831
+ audio_input = audio.unsqueeze(dim=0)
832
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
833
+ key
834
+ ].squeeze(dim=0)
835
+
836
+ return output_dict
837
+
838
+
839
+ def convert_weights_to_fp16(model: nn.Module):
840
+ """Convert applicable model parameters to fp16"""
841
+
842
+ def _convert_weights_to_fp16(l):
843
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
844
+ l.weight.data = l.weight.data.half()
845
+ if l.bias is not None:
846
+ l.bias.data = l.bias.data.half()
847
+
848
+ if isinstance(l, nn.MultiheadAttention):
849
+ for attr in [
850
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
851
+ "in_proj_bias",
852
+ "bias_k",
853
+ "bias_v",
854
+ ]:
855
+ tensor = getattr(l, attr)
856
+ if tensor is not None:
857
+ tensor.data = tensor.data.half()
858
+
859
+ for name in ["text_projection", "proj"]:
860
+ if hasattr(l, name):
861
+ attr = getattr(l, name)
862
+ if attr is not None:
863
+ attr.data = attr.data.half()
864
+
865
+ model.apply(_convert_weights_to_fp16)
866
+
867
+
868
+ # Ignore the state dict of the vision part
869
+ def build_model_from_openai_state_dict(
870
+ state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
871
+ ):
872
+ embed_dim = model_cfg["embed_dim"]
873
+ audio_cfg = model_cfg["audio_cfg"]
874
+ text_cfg = model_cfg["text_cfg"]
875
+ state_dict["positional_embedding"].shape[0]
876
+ state_dict["token_embedding.weight"].shape[0]
877
+ transformer_width = state_dict["ln_final.weight"].shape[0]
878
+ transformer_width // 64
879
+ transformer_layers = len(
880
+ set(
881
+ k.split(".")[2]
882
+ for k in state_dict
883
+ if k.startswith(f"transformer.resblocks")
884
+ )
885
+ )
886
+
887
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
888
+ text_cfg = CLAPTextCfg(**text_cfg)
889
+
890
+ model = CLAP(
891
+ embed_dim,
892
+ audio_cfg=audio_cfg,
893
+ text_cfg=text_cfg,
894
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
895
+ enable_fusion=enable_fusion,
896
+ fusion_type=fusion_type,
897
+ )
898
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
899
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
900
+ pop_keys = list(state_dict.keys())[::]
901
+ # pop the visual branch saved weights
902
+ for key in pop_keys:
903
+ if key.startswith("visual."):
904
+ state_dict.pop(key, None)
905
+
906
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
907
+ state_dict.pop(key, None)
908
+
909
+ # not use fp16
910
+ # convert_weights_to_fp16(model)
911
+ model.load_state_dict(state_dict, strict=False)
912
+ return model.eval()
913
+
914
+
915
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
916
+ model.eval()
917
+ audio_length = model.audio_cfg.audio_length
918
+ example_audio = torch.ones((batch_size, audio_length), device=device)
919
+ example_text = torch.zeros(
920
+ (batch_size, model.context_length), dtype=torch.int, device=device
921
+ )
922
+ model = torch.jit.trace_module(
923
+ model,
924
+ inputs=dict(
925
+ forward=(example_audio, example_text),
926
+ encode_text=(example_text,),
927
+ encode_image=(example_audio,),
928
+ ),
929
+ )
930
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
931
+ return model
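
For orientation, a minimal usage sketch (not part of the upload) of how the CLAP class above fits together with the JSON configs that follow. It assumes CLAP is imported from this module; the config path, the hand-set "transformer" text model_type (the bundled configs do not store one, and it is presumably filled in by the factory code), and the random waveform are likewise assumptions, while the class and method signatures are the ones defined above.

    import json
    import torch

    cfg_path = "audiosr/clap/open_clip/model_configs/PANN-14.json"  # assumed location
    with open(cfg_path) as f:
        cfg = json.load(f)
    cfg["text_cfg"]["model_type"] = "transformer"  # assumed: not stored in the JSON config

    model = CLAP(
        embed_dim=cfg["embed_dim"],
        audio_cfg=cfg["audio_cfg"],
        text_cfg=cfg["text_cfg"],
        enable_fusion=False,
    ).eval()

    with torch.no_grad():
        # two clips of 48 kHz audio, clip_samples = 480000 samples each
        waveform = torch.randn(2, cfg["audio_cfg"]["clip_samples"])
        audio_embed = model.get_audio_embedding({"waveform": waveform})
    print(audio_embed.shape)  # (2, joint_embed_shape), L2-normalized
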
audiosr/clap/open_clip/model_configs/HTSAT-base.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "base"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/HTSAT-large.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "large"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "large"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/HTSAT-tiny.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-10.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn10"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn10"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 18000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 18000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 960000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 360,
10
- "fmin": 50,
11
- "fmax": 8000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 960000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 360,
10
+ "fmin": 50,
11
+ "fmax": 8000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 4
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 4
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-14.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/PANN-6.json CHANGED
@@ -1,23 +1,23 @@
1
- {
2
- "embed_dim": 512,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn6"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn6"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
  }
audiosr/clap/open_clip/model_configs/RN101-quickgelu.json CHANGED
@@ -1,22 +1,22 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 23,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
  }
audiosr/clap/open_clip/model_configs/RN101.json CHANGED
@@ -1,21 +1,21 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 23,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
  }
audiosr/clap/open_clip/model_configs/RN50-quickgelu.json CHANGED
@@ -1,22 +1,22 @@
1
- {
2
- "embed_dim": 1024,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 6,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
audiosr/clap/open_clip/model_configs/RN50.json CHANGED
@@ -1,21 +1,21 @@
1
- {
2
- "embed_dim": 1024,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 6,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
  }
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
  }
audiosr/clap/open_clip/model_configs/RN50x16.json CHANGED
@@ -1,21 +1,21 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 384,
5
- "layers": [
6
- 6,
7
- 8,
8
- 18,
9
- 8
10
- ],
11
- "width": 96,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 768,
18
- "heads": 12,
19
- "layers": 12
20
- }
21
  }
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
  }
audiosr/clap/open_clip/model_configs/RN50x4.json CHANGED
@@ -1,21 +1,21 @@
1
- {
2
- "embed_dim": 640,
3
- "vision_cfg": {
4
- "image_size": 288,
5
- "layers": [
6
- 4,
7
- 6,
8
- 10,
9
- 6
10
- ],
11
- "width": 80,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 640,
18
- "heads": 10,
19
- "layers": 12
20
- }
21
  }
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
  }
audiosr/clap/open_clip/model_configs/ViT-B-16.json CHANGED
@@ -1,16 +1,16 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 16
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
  }
audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json CHANGED
@@ -1,17 +1,17 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": 12,
7
- "width": 768,
8
- "patch_size": 32
9
- },
10
- "text_cfg": {
11
- "context_length": 77,
12
- "vocab_size": 49408,
13
- "width": 512,
14
- "heads": 8,
15
- "layers": 12
16
- }
17
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
  }
audiosr/clap/open_clip/model_configs/ViT-B-32.json CHANGED
@@ -1,16 +1,16 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 32
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
  }
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
  }
audiosr/clap/open_clip/model_configs/ViT-L-14.json CHANGED
@@ -1,16 +1,16 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 24,
6
- "width": 1024,
7
- "patch_size": 14
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 768,
13
- "heads": 12,
14
- "layers": 12
15
- }
16
  }
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
  }
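
The configs in model_configs/ come in two flavors: the PANN-* and HTSAT-* files carry an audio_cfg/text_cfg pair for CLAP, while the RN* and ViT-* files keep the original CLIP vision_cfg. A small sketch of mapping them onto the dataclasses defined in model.py, assuming those dataclasses are imported and that the directory sits at the path used in this upload:

    import json
    from pathlib import Path

    cfg_dir = Path("audiosr/clap/open_clip/model_configs")  # assumed location
    for path in sorted(cfg_dir.glob("*.json")):
        cfg = json.loads(path.read_text())
        if "audio_cfg" in cfg:
            # CLAP audio/text configs (PANN-*, HTSAT-*)
            audio = CLAPAudioCfp(**cfg["audio_cfg"])
            print(f"{path.stem}: {audio.model_type}/{audio.model_name}, embed_dim={cfg['embed_dim']}")
        else:
            # original CLIP vision configs (RN*, ViT-*)
            vision = CLAPVisionCfg(**cfg["vision_cfg"])
            print(f"{path.stem}: vision layers={vision.layers}, width={vision.width}, embed_dim={cfg['embed_dim']}")
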
audiosr/clap/open_clip/openai.py CHANGED
@@ -1,156 +1,156 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import Union, List
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict
13
- from .pretrained import (
14
- get_pretrained_url,
15
- list_pretrained_tag_models,
16
- download_pretrained,
17
- )
18
-
19
- __all__ = ["list_openai_models", "load_openai_model"]
20
-
21
-
22
- def list_openai_models() -> List[str]:
23
- """Returns the names of available CLIP models"""
24
- return list_pretrained_tag_models("openai")
25
-
26
-
27
- def load_openai_model(
28
- name: str,
29
- model_cfg,
30
- device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31
- jit=True,
32
- cache_dir=os.path.expanduser("~/.cache/clip"),
33
- enable_fusion: bool = False,
34
- fusion_type: str = "None",
35
- ):
36
- """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37
-
38
- Parameters
39
- ----------
40
- name : str
41
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42
- device : Union[str, torch.device]
43
- The device to put the loaded model
44
- jit : bool
45
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
46
-
47
- Returns
48
- -------
49
- model : torch.nn.Module
50
- The CLAP model
51
- preprocess : Callable[[PIL.Image], torch.Tensor]
52
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53
- """
54
- if get_pretrained_url(name, "openai"):
55
- model_path = download_pretrained(
56
- get_pretrained_url(name, "openai"), root=cache_dir
57
- )
58
- elif os.path.isfile(name):
59
- model_path = name
60
- else:
61
- raise RuntimeError(
62
- f"Model {name} not found; available models = {list_openai_models()}"
63
- )
64
-
65
- try:
66
- # loading JIT archive
67
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68
- state_dict = None
69
- except RuntimeError:
70
- # loading saved state dict
71
- if jit:
72
- warnings.warn(
73
- f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74
- )
75
- jit = False
76
- state_dict = torch.load(model_path, map_location="cpu")
77
-
78
- if not jit:
79
- try:
80
- model = build_model_from_openai_state_dict(
81
- state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82
- ).to(device)
83
- except KeyError:
84
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85
- model = build_model_from_openai_state_dict(
86
- sd, model_cfg, enable_fusion, fusion_type
87
- ).to(device)
88
-
89
- if str(device) == "cpu":
90
- model.float()
91
- return model
92
-
93
- # patch the device names
94
- device_holder = torch.jit.trace(
95
- lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96
- )
97
- device_node = [
98
- n
99
- for n in device_holder.graph.findAllNodes("prim::Constant")
100
- if "Device" in repr(n)
101
- ][-1]
102
-
103
- def patch_device(module):
104
- try:
105
- graphs = [module.graph] if hasattr(module, "graph") else []
106
- except RuntimeError:
107
- graphs = []
108
-
109
- if hasattr(module, "forward1"):
110
- graphs.append(module.forward1.graph)
111
-
112
- for graph in graphs:
113
- for node in graph.findAllNodes("prim::Constant"):
114
- if "value" in node.attributeNames() and str(node["value"]).startswith(
115
- "cuda"
116
- ):
117
- node.copyAttributes(device_node)
118
-
119
- model.apply(patch_device)
120
- patch_device(model.encode_audio)
121
- patch_device(model.encode_text)
122
-
123
- # patch dtype to float32 on CPU
124
- if str(device) == "cpu":
125
- float_holder = torch.jit.trace(
126
- lambda: torch.ones([]).float(), example_inputs=[]
127
- )
128
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129
- float_node = float_input.node()
130
-
131
- def patch_float(module):
132
- try:
133
- graphs = [module.graph] if hasattr(module, "graph") else []
134
- except RuntimeError:
135
- graphs = []
136
-
137
- if hasattr(module, "forward1"):
138
- graphs.append(module.forward1.graph)
139
-
140
- for graph in graphs:
141
- for node in graph.findAllNodes("aten::to"):
142
- inputs = list(node.inputs())
143
- for i in [
144
- 1,
145
- 2,
146
- ]: # dtype can be the second or third argument to aten::to()
147
- if inputs[i].node()["value"] == 5:
148
- inputs[i].node().copyAttributes(float_node)
149
-
150
- model.apply(patch_float)
151
- patch_float(model.encode_audio)
152
- patch_float(model.encode_text)
153
- model.float()
154
-
155
- model.audio_branch.audio_length = model.audio_cfg.audio_length
156
- return model
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import Union, List
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict
13
+ from .pretrained import (
14
+ get_pretrained_url,
15
+ list_pretrained_tag_models,
16
+ download_pretrained,
17
+ )
18
+
19
+ __all__ = ["list_openai_models", "load_openai_model"]
20
+
21
+
22
+ def list_openai_models() -> List[str]:
23
+ """Returns the names of available CLIP models"""
24
+ return list_pretrained_tag_models("openai")
25
+
26
+
27
+ def load_openai_model(
28
+ name: str,
29
+ model_cfg,
30
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31
+ jit=True,
32
+ cache_dir=os.path.expanduser("~/.cache/clip"),
33
+ enable_fusion: bool = False,
34
+ fusion_type: str = "None",
35
+ ):
36
+ """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37
+
38
+ Parameters
39
+ ----------
40
+ name : str
41
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42
+ device : Union[str, torch.device]
43
+ The device to put the loaded model
44
+ jit : bool
45
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
46
+
47
+ Returns
48
+ -------
49
+ model : torch.nn.Module
50
+ The CLAP model
51
+ preprocess : Callable[[PIL.Image], torch.Tensor]
52
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53
+ """
54
+ if get_pretrained_url(name, "openai"):
55
+ model_path = download_pretrained(
56
+ get_pretrained_url(name, "openai"), root=cache_dir
57
+ )
58
+ elif os.path.isfile(name):
59
+ model_path = name
60
+ else:
61
+ raise RuntimeError(
62
+ f"Model {name} not found; available models = {list_openai_models()}"
63
+ )
64
+
65
+ try:
66
+ # loading JIT archive
67
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68
+ state_dict = None
69
+ except RuntimeError:
70
+ # loading saved state dict
71
+ if jit:
72
+ warnings.warn(
73
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74
+ )
75
+ jit = False
76
+ state_dict = torch.load(model_path, map_location="cpu")
77
+
78
+ if not jit:
79
+ try:
80
+ model = build_model_from_openai_state_dict(
81
+ state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82
+ ).to(device)
83
+ except KeyError:
84
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85
+ model = build_model_from_openai_state_dict(
86
+ sd, model_cfg, enable_fusion, fusion_type
87
+ ).to(device)
88
+
89
+ if str(device) == "cpu":
90
+ model.float()
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(
95
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96
+ )
97
+ device_node = [
98
+ n
99
+ for n in device_holder.graph.findAllNodes("prim::Constant")
100
+ if "Device" in repr(n)
101
+ ][-1]
102
+
103
+ def patch_device(module):
104
+ try:
105
+ graphs = [module.graph] if hasattr(module, "graph") else []
106
+ except RuntimeError:
107
+ graphs = []
108
+
109
+ if hasattr(module, "forward1"):
110
+ graphs.append(module.forward1.graph)
111
+
112
+ for graph in graphs:
113
+ for node in graph.findAllNodes("prim::Constant"):
114
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
115
+ "cuda"
116
+ ):
117
+ node.copyAttributes(device_node)
118
+
119
+ model.apply(patch_device)
120
+ patch_device(model.encode_audio)
121
+ patch_device(model.encode_text)
122
+
123
+ # patch dtype to float32 on CPU
124
+ if str(device) == "cpu":
125
+ float_holder = torch.jit.trace(
126
+ lambda: torch.ones([]).float(), example_inputs=[]
127
+ )
128
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129
+ float_node = float_input.node()
130
+
131
+ def patch_float(module):
132
+ try:
133
+ graphs = [module.graph] if hasattr(module, "graph") else []
134
+ except RuntimeError:
135
+ graphs = []
136
+
137
+ if hasattr(module, "forward1"):
138
+ graphs.append(module.forward1.graph)
139
+
140
+ for graph in graphs:
141
+ for node in graph.findAllNodes("aten::to"):
142
+ inputs = list(node.inputs())
143
+ for i in [
144
+ 1,
145
+ 2,
146
+ ]: # dtype can be the second or third argument to aten::to()
147
+ if inputs[i].node()["value"] == 5:
148
+ inputs[i].node().copyAttributes(float_node)
149
+
150
+ model.apply(patch_float)
151
+ patch_float(model.encode_audio)
152
+ patch_float(model.encode_text)
153
+ model.float()
154
+
155
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
156
+ return model
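
A minimal sketch of calling load_openai_model above to seed the CLAP text branch from an OpenAI CLIP checkpoint. The checkpoint tag, the config path and the hand-set text model_type are assumptions; only the function signature and the expected model_cfg keys (embed_dim / audio_cfg / text_cfg) come from the code above.

    import json

    with open("audiosr/clap/open_clip/model_configs/HTSAT-tiny.json") as f:  # assumed path
        model_cfg = json.load(f)
    model_cfg["text_cfg"]["model_type"] = "transformer"  # assumed: not stored in the JSON

    model = load_openai_model(
        "ViT-B-16",   # assumed tag; any name from list_openai_models(), or a local checkpoint path
        model_cfg,
        device="cpu",
        jit=False,    # build the plain (non-JIT) CLAP model from the CLIP state dict
    )
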
audiosr/clap/open_clip/pann_model.py CHANGED
@@ -1,697 +1,697 @@
1
- # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
- # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
- # Some layers are re-designed for CLAP
4
- import os
5
-
6
- os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
- from torchlibrosa.augmentation import SpecAugmentation
13
-
14
- from .utils import do_mixup, interpolate
15
- from .feature_fusion import iAFF, AFF, DAF
16
-
17
-
18
- def init_layer(layer):
19
- """Initialize a Linear or Convolutional layer."""
20
- nn.init.xavier_uniform_(layer.weight)
21
-
22
- if hasattr(layer, "bias"):
23
- if layer.bias is not None:
24
- layer.bias.data.fill_(0.0)
25
-
26
-
27
- def init_bn(bn):
28
- """Initialize a Batchnorm layer."""
29
- bn.bias.data.fill_(0.0)
30
- bn.weight.data.fill_(1.0)
31
-
32
-
33
- class ConvBlock(nn.Module):
34
- def __init__(self, in_channels, out_channels):
35
- super(ConvBlock, self).__init__()
36
-
37
- self.conv1 = nn.Conv2d(
38
- in_channels=in_channels,
39
- out_channels=out_channels,
40
- kernel_size=(3, 3),
41
- stride=(1, 1),
42
- padding=(1, 1),
43
- bias=False,
44
- )
45
-
46
- self.conv2 = nn.Conv2d(
47
- in_channels=out_channels,
48
- out_channels=out_channels,
49
- kernel_size=(3, 3),
50
- stride=(1, 1),
51
- padding=(1, 1),
52
- bias=False,
53
- )
54
-
55
- self.bn1 = nn.BatchNorm2d(out_channels)
56
- self.bn2 = nn.BatchNorm2d(out_channels)
57
-
58
- self.init_weight()
59
-
60
- def init_weight(self):
61
- init_layer(self.conv1)
62
- init_layer(self.conv2)
63
- init_bn(self.bn1)
64
- init_bn(self.bn2)
65
-
66
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
67
- x = input
68
- x = F.relu_(self.bn1(self.conv1(x)))
69
- x = F.relu_(self.bn2(self.conv2(x)))
70
- if pool_type == "max":
71
- x = F.max_pool2d(x, kernel_size=pool_size)
72
- elif pool_type == "avg":
73
- x = F.avg_pool2d(x, kernel_size=pool_size)
74
- elif pool_type == "avg+max":
75
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
76
- x2 = F.max_pool2d(x, kernel_size=pool_size)
77
- x = x1 + x2
78
- else:
79
- raise Exception("Incorrect argument!")
80
-
81
- return x
82
-
83
-
84
- class ConvBlock5x5(nn.Module):
85
- def __init__(self, in_channels, out_channels):
86
- super(ConvBlock5x5, self).__init__()
87
-
88
- self.conv1 = nn.Conv2d(
89
- in_channels=in_channels,
90
- out_channels=out_channels,
91
- kernel_size=(5, 5),
92
- stride=(1, 1),
93
- padding=(2, 2),
94
- bias=False,
95
- )
96
-
97
- self.bn1 = nn.BatchNorm2d(out_channels)
98
-
99
- self.init_weight()
100
-
101
- def init_weight(self):
102
- init_layer(self.conv1)
103
- init_bn(self.bn1)
104
-
105
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
106
- x = input
107
- x = F.relu_(self.bn1(self.conv1(x)))
108
- if pool_type == "max":
109
- x = F.max_pool2d(x, kernel_size=pool_size)
110
- elif pool_type == "avg":
111
- x = F.avg_pool2d(x, kernel_size=pool_size)
112
- elif pool_type == "avg+max":
113
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
114
- x2 = F.max_pool2d(x, kernel_size=pool_size)
115
- x = x1 + x2
116
- else:
117
- raise Exception("Incorrect argument!")
118
-
119
- return x
120
-
121
-
122
- class AttBlock(nn.Module):
123
- def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
124
- super(AttBlock, self).__init__()
125
-
126
- self.activation = activation
127
- self.temperature = temperature
128
- self.att = nn.Conv1d(
129
- in_channels=n_in,
130
- out_channels=n_out,
131
- kernel_size=1,
132
- stride=1,
133
- padding=0,
134
- bias=True,
135
- )
136
- self.cla = nn.Conv1d(
137
- in_channels=n_in,
138
- out_channels=n_out,
139
- kernel_size=1,
140
- stride=1,
141
- padding=0,
142
- bias=True,
143
- )
144
-
145
- self.bn_att = nn.BatchNorm1d(n_out)
146
- self.init_weights()
147
-
148
- def init_weights(self):
149
- init_layer(self.att)
150
- init_layer(self.cla)
151
- init_bn(self.bn_att)
152
-
153
- def forward(self, x):
154
- # x: (n_samples, n_in, n_time)
155
- norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
156
- cla = self.nonlinear_transform(self.cla(x))
157
- x = torch.sum(norm_att * cla, dim=2)
158
- return x, norm_att, cla
159
-
160
- def nonlinear_transform(self, x):
161
- if self.activation == "linear":
162
- return x
163
- elif self.activation == "sigmoid":
164
- return torch.sigmoid(x)
165
-
166
-
167
- class Cnn14(nn.Module):
168
- def __init__(
169
- self,
170
- sample_rate,
171
- window_size,
172
- hop_size,
173
- mel_bins,
174
- fmin,
175
- fmax,
176
- classes_num,
177
- enable_fusion=False,
178
- fusion_type="None",
179
- ):
180
- super(Cnn14, self).__init__()
181
-
182
- window = "hann"
183
- center = True
184
- pad_mode = "reflect"
185
- ref = 1.0
186
- amin = 1e-10
187
- top_db = None
188
-
189
- self.enable_fusion = enable_fusion
190
- self.fusion_type = fusion_type
191
-
192
- # Spectrogram extractor
193
- self.spectrogram_extractor = Spectrogram(
194
- n_fft=window_size,
195
- hop_length=hop_size,
196
- win_length=window_size,
197
- window=window,
198
- center=center,
199
- pad_mode=pad_mode,
200
- freeze_parameters=True,
201
- )
202
-
203
- # Logmel feature extractor
204
- self.logmel_extractor = LogmelFilterBank(
205
- sr=sample_rate,
206
- n_fft=window_size,
207
- n_mels=mel_bins,
208
- fmin=fmin,
209
- fmax=fmax,
210
- ref=ref,
211
- amin=amin,
212
- top_db=top_db,
213
- freeze_parameters=True,
214
- )
215
-
216
- # Spec augmenter
217
- self.spec_augmenter = SpecAugmentation(
218
- time_drop_width=64,
219
- time_stripes_num=2,
220
- freq_drop_width=8,
221
- freq_stripes_num=2,
222
- )
223
-
224
- self.bn0 = nn.BatchNorm2d(64)
225
-
226
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
227
- self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
228
- else:
229
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
230
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
231
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
232
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
233
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
234
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
235
-
236
- self.fc1 = nn.Linear(2048, 2048, bias=True)
237
- self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
238
-
239
- if (self.enable_fusion) and (
240
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
241
- ):
242
- self.mel_conv1d = nn.Sequential(
243
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
244
- nn.BatchNorm1d(64), # No Relu
245
- )
246
- if self.fusion_type == "daf_1d":
247
- self.fusion_model = DAF()
248
- elif self.fusion_type == "aff_1d":
249
- self.fusion_model = AFF(channels=64, type="1D")
250
- elif self.fusion_type == "iaff_1d":
251
- self.fusion_model = iAFF(channels=64, type="1D")
252
-
253
- if (self.enable_fusion) and (
254
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
255
- ):
256
- self.mel_conv2d = nn.Sequential(
257
- nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
258
- nn.BatchNorm2d(64),
259
- nn.ReLU(inplace=True),
260
- )
261
-
262
- if self.fusion_type == "daf_2d":
263
- self.fusion_model = DAF()
264
- elif self.fusion_type == "aff_2d":
265
- self.fusion_model = AFF(channels=64, type="2D")
266
- elif self.fusion_type == "iaff_2d":
267
- self.fusion_model = iAFF(channels=64, type="2D")
268
- self.init_weight()
269
-
270
- def init_weight(self):
271
- init_bn(self.bn0)
272
- init_layer(self.fc1)
273
- init_layer(self.fc_audioset)
274
-
275
- def forward(self, input, mixup_lambda=None, device=None):
276
- """
277
- Input: (batch_size, data_length)"""
278
-
279
- if self.enable_fusion and input["longer"].sum() == 0:
280
- # if no audio is longer than 10s, then randomly select one audio to be longer
281
- input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
282
-
283
- if not self.enable_fusion:
284
- x = self.spectrogram_extractor(
285
- input["waveform"].to(device=device, non_blocking=True)
286
- ) # (batch_size, 1, time_steps, freq_bins)
287
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
288
-
289
- x = x.transpose(1, 3)
290
- x = self.bn0(x)
291
- x = x.transpose(1, 3)
292
- else:
293
- longer_list = input["longer"].to(device=device, non_blocking=True)
294
- x = input["mel_fusion"].to(device=device, non_blocking=True)
295
- longer_list_idx = torch.where(longer_list)[0]
296
- x = x.transpose(1, 3)
297
- x = self.bn0(x)
298
- x = x.transpose(1, 3)
299
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
300
- new_x = x[:, 0:1, :, :].clone().contiguous()
301
- # local processing
302
- if len(longer_list_idx) > 0:
303
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
304
- FB, FC, FT, FF = fusion_x_local.size()
305
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
306
- fusion_x_local = torch.permute(
307
- fusion_x_local, (0, 2, 1)
308
- ).contiguous()
309
- fusion_x_local = self.mel_conv1d(fusion_x_local)
310
- fusion_x_local = fusion_x_local.view(
311
- FB, FC, FF, fusion_x_local.size(-1)
312
- )
313
- fusion_x_local = (
314
- torch.permute(fusion_x_local, (0, 2, 1, 3))
315
- .contiguous()
316
- .flatten(2)
317
- )
318
- if fusion_x_local.size(-1) < FT:
319
- fusion_x_local = torch.cat(
320
- [
321
- fusion_x_local,
322
- torch.zeros(
323
- (FB, FF, FT - fusion_x_local.size(-1)),
324
- device=device,
325
- ),
326
- ],
327
- dim=-1,
328
- )
329
- else:
330
- fusion_x_local = fusion_x_local[:, :, :FT]
331
- # 1D fusion
332
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
333
- new_x[longer_list_idx] = self.fusion_model(
334
- new_x[longer_list_idx], fusion_x_local
335
- )
336
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
337
- else:
338
- x = new_x
339
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
340
- x = x # no change
341
-
342
- if self.training:
343
- x = self.spec_augmenter(x)
344
- # Mixup on spectrogram
345
- if self.training and mixup_lambda is not None:
346
- x = do_mixup(x, mixup_lambda)
347
- if (self.enable_fusion) and (
348
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
349
- ):
350
- global_x = x[:, 0:1, :, :]
351
-
352
- # global processing
353
- B, C, H, W = global_x.shape
354
- global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
355
- if len(longer_list_idx) > 0:
356
- local_x = x[longer_list_idx, 1:, :, :].contiguous()
357
- TH = global_x.size(-2)
358
- # local processing
359
- B, C, H, W = local_x.shape
360
- local_x = local_x.view(B * C, 1, H, W)
361
- local_x = self.mel_conv2d(local_x)
362
- local_x = local_x.view(
363
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
364
- )
365
- local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
366
- TB, TC, _, TW = local_x.size()
367
- if local_x.size(-2) < TH:
368
- local_x = torch.cat(
369
- [
370
- local_x,
371
- torch.zeros(
372
- (TB, TC, TH - local_x.size(-2), TW),
373
- device=global_x.device,
374
- ),
375
- ],
376
- dim=-2,
377
- )
378
- else:
379
- local_x = local_x[:, :, :TH, :]
380
-
381
- global_x[longer_list_idx] = self.fusion_model(
382
- global_x[longer_list_idx], local_x
383
- )
384
- x = global_x
385
- else:
386
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
387
-
388
- x = F.dropout(x, p=0.2, training=self.training)
389
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
390
- x = F.dropout(x, p=0.2, training=self.training)
391
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
392
- x = F.dropout(x, p=0.2, training=self.training)
393
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
394
- x = F.dropout(x, p=0.2, training=self.training)
395
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
396
- x = F.dropout(x, p=0.2, training=self.training)
397
- x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
398
- x = F.dropout(x, p=0.2, training=self.training)
399
- x = torch.mean(x, dim=3)
400
-
401
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
402
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
403
- latent_x = latent_x1 + latent_x2
404
- latent_x = latent_x.transpose(1, 2)
405
- latent_x = F.relu_(self.fc1(latent_x))
406
- latent_output = interpolate(latent_x, 32)
407
-
408
- (x1, _) = torch.max(x, dim=2)
409
- x2 = torch.mean(x, dim=2)
410
- x = x1 + x2
411
- x = F.dropout(x, p=0.5, training=self.training)
412
- x = F.relu_(self.fc1(x))
413
- embedding = F.dropout(x, p=0.5, training=self.training)
414
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
415
-
416
- output_dict = {
417
- "clipwise_output": clipwise_output,
418
- "embedding": embedding,
419
- "fine_grained_embedding": latent_output,
420
- }
421
- return output_dict
422
-
423
-
424
- class Cnn6(nn.Module):
425
- def __init__(
426
- self,
427
- sample_rate,
428
- window_size,
429
- hop_size,
430
- mel_bins,
431
- fmin,
432
- fmax,
433
- classes_num,
434
- enable_fusion=False,
435
- fusion_type="None",
436
- ):
437
- super(Cnn6, self).__init__()
438
-
439
- window = "hann"
440
- center = True
441
- pad_mode = "reflect"
442
- ref = 1.0
443
- amin = 1e-10
444
- top_db = None
445
-
446
- self.enable_fusion = enable_fusion
447
- self.fusion_type = fusion_type
448
-
449
- # Spectrogram extractor
450
- self.spectrogram_extractor = Spectrogram(
451
- n_fft=window_size,
452
- hop_length=hop_size,
453
- win_length=window_size,
454
- window=window,
455
- center=center,
456
- pad_mode=pad_mode,
457
- freeze_parameters=True,
458
- )
459
-
460
- # Logmel feature extractor
461
- self.logmel_extractor = LogmelFilterBank(
462
- sr=sample_rate,
463
- n_fft=window_size,
464
- n_mels=mel_bins,
465
- fmin=fmin,
466
- fmax=fmax,
467
- ref=ref,
468
- amin=amin,
469
- top_db=top_db,
470
- freeze_parameters=True,
471
- )
472
-
473
- # Spec augmenter
474
- self.spec_augmenter = SpecAugmentation(
475
- time_drop_width=64,
476
- time_stripes_num=2,
477
- freq_drop_width=8,
478
- freq_stripes_num=2,
479
- )
480
-
481
- self.bn0 = nn.BatchNorm2d(64)
482
-
483
- self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
484
- self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
485
- self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
486
- self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
487
-
488
- self.fc1 = nn.Linear(512, 512, bias=True)
489
- self.fc_audioset = nn.Linear(512, classes_num, bias=True)
490
-
491
- self.init_weight()
492
-
493
- def init_weight(self):
494
- init_bn(self.bn0)
495
- init_layer(self.fc1)
496
- init_layer(self.fc_audioset)
497
-
498
- def forward(self, input, mixup_lambda=None, device=None):
499
- """
500
- Input: (batch_size, data_length)"""
501
-
502
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
503
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
504
-
505
- x = x.transpose(1, 3)
506
- x = self.bn0(x)
507
- x = x.transpose(1, 3)
508
-
509
- if self.training:
510
- x = self.spec_augmenter(x)
511
-
512
- # Mixup on spectrogram
513
- if self.training and mixup_lambda is not None:
514
- x = do_mixup(x, mixup_lambda)
515
-
516
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
517
- x = F.dropout(x, p=0.2, training=self.training)
518
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
519
- x = F.dropout(x, p=0.2, training=self.training)
520
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
521
- x = F.dropout(x, p=0.2, training=self.training)
522
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
523
- x = F.dropout(x, p=0.2, training=self.training)
524
- x = torch.mean(x, dim=3)
525
-
526
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
527
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
528
- latent_x = latent_x1 + latent_x2
529
- latent_x = latent_x.transpose(1, 2)
530
- latent_x = F.relu_(self.fc1(latent_x))
531
- latent_output = interpolate(latent_x, 16)
532
-
533
- (x1, _) = torch.max(x, dim=2)
534
- x2 = torch.mean(x, dim=2)
535
- x = x1 + x2
536
- x = F.dropout(x, p=0.5, training=self.training)
537
- x = F.relu_(self.fc1(x))
538
- embedding = F.dropout(x, p=0.5, training=self.training)
539
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
540
-
541
- output_dict = {
542
- "clipwise_output": clipwise_output,
543
- "embedding": embedding,
544
- "fine_grained_embedding": latent_output,
545
- }
546
-
547
- return output_dict
548
-
549
-
550
- class Cnn10(nn.Module):
551
- def __init__(
552
- self,
553
- sample_rate,
554
- window_size,
555
- hop_size,
556
- mel_bins,
557
- fmin,
558
- fmax,
559
- classes_num,
560
- enable_fusion=False,
561
- fusion_type="None",
562
- ):
563
- super(Cnn10, self).__init__()
564
-
565
- window = "hann"
566
- center = True
567
- pad_mode = "reflect"
568
- ref = 1.0
569
- amin = 1e-10
570
- top_db = None
571
-
572
- self.enable_fusion = enable_fusion
573
- self.fusion_type = fusion_type
574
-
575
- # Spectrogram extractor
576
- self.spectrogram_extractor = Spectrogram(
577
- n_fft=window_size,
578
- hop_length=hop_size,
579
- win_length=window_size,
580
- window=window,
581
- center=center,
582
- pad_mode=pad_mode,
583
- freeze_parameters=True,
584
- )
585
-
586
- # Logmel feature extractor
587
- self.logmel_extractor = LogmelFilterBank(
588
- sr=sample_rate,
589
- n_fft=window_size,
590
- n_mels=mel_bins,
591
- fmin=fmin,
592
- fmax=fmax,
593
- ref=ref,
594
- amin=amin,
595
- top_db=top_db,
596
- freeze_parameters=True,
597
- )
598
-
599
- # Spec augmenter
600
- self.spec_augmenter = SpecAugmentation(
601
- time_drop_width=64,
602
- time_stripes_num=2,
603
- freq_drop_width=8,
604
- freq_stripes_num=2,
605
- )
606
-
607
- self.bn0 = nn.BatchNorm2d(64)
608
-
609
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
610
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
611
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
612
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
613
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
614
-
615
- self.fc1 = nn.Linear(1024, 1024, bias=True)
616
- self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
617
-
618
- self.init_weight()
619
-
620
- def init_weight(self):
621
- init_bn(self.bn0)
622
- init_layer(self.fc1)
623
- init_layer(self.fc_audioset)
624
-
625
- def forward(self, input, mixup_lambda=None, device=None):
626
- """
627
- Input: (batch_size, data_length)"""
628
-
629
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
630
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
631
-
632
- x = x.transpose(1, 3)
633
- x = self.bn0(x)
634
- x = x.transpose(1, 3)
635
-
636
- if self.training:
637
- x = self.spec_augmenter(x)
638
-
639
- # Mixup on spectrogram
640
- if self.training and mixup_lambda is not None:
641
- x = do_mixup(x, mixup_lambda)
642
-
643
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
644
- x = F.dropout(x, p=0.2, training=self.training)
645
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
646
- x = F.dropout(x, p=0.2, training=self.training)
647
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
648
- x = F.dropout(x, p=0.2, training=self.training)
649
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
650
- x = F.dropout(x, p=0.2, training=self.training)
651
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
652
- x = F.dropout(x, p=0.2, training=self.training)
653
- x = torch.mean(x, dim=3)
654
-
655
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
656
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
657
- latent_x = latent_x1 + latent_x2
658
- latent_x = latent_x.transpose(1, 2)
659
- latent_x = F.relu_(self.fc1(latent_x))
660
- latent_output = interpolate(latent_x, 32)
661
-
662
- (x1, _) = torch.max(x, dim=2)
663
- x2 = torch.mean(x, dim=2)
664
- x = x1 + x2
665
- x = F.dropout(x, p=0.5, training=self.training)
666
- x = F.relu_(self.fc1(x))
667
- embedding = F.dropout(x, p=0.5, training=self.training)
668
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
669
-
670
- output_dict = {
671
- "clipwise_output": clipwise_output,
672
- "embedding": embedding,
673
- "fine_grained_embedding": latent_output,
674
- }
675
-
676
- return output_dict
677
-
678
-
679
- def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
680
- try:
681
- ModelProto = eval(audio_cfg.model_name)
682
- model = ModelProto(
683
- sample_rate=audio_cfg.sample_rate,
684
- window_size=audio_cfg.window_size,
685
- hop_size=audio_cfg.hop_size,
686
- mel_bins=audio_cfg.mel_bins,
687
- fmin=audio_cfg.fmin,
688
- fmax=audio_cfg.fmax,
689
- classes_num=audio_cfg.class_num,
690
- enable_fusion=enable_fusion,
691
- fusion_type=fusion_type,
692
- )
693
- return model
694
- except Exception:
695
- raise RuntimeError(
696
- f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
697
- )
 
1
+ # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
+ # Adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
+ # Some layers are re-designed for CLAP
4
+ import os
5
+
6
+ os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
+ from torchlibrosa.augmentation import SpecAugmentation
13
+
14
+ from .utils import do_mixup, interpolate
15
+ from .feature_fusion import iAFF, AFF, DAF
16
+
17
+
18
+ def init_layer(layer):
19
+ """Initialize a Linear or Convolutional layer."""
20
+ nn.init.xavier_uniform_(layer.weight)
21
+
22
+ if hasattr(layer, "bias"):
23
+ if layer.bias is not None:
24
+ layer.bias.data.fill_(0.0)
25
+
26
+
27
+ def init_bn(bn):
28
+ """Initialize a Batchnorm layer."""
29
+ bn.bias.data.fill_(0.0)
30
+ bn.weight.data.fill_(1.0)
31
+
32
+
33
+ class ConvBlock(nn.Module):
34
+ def __init__(self, in_channels, out_channels):
35
+ super(ConvBlock, self).__init__()
36
+
37
+ self.conv1 = nn.Conv2d(
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ kernel_size=(3, 3),
41
+ stride=(1, 1),
42
+ padding=(1, 1),
43
+ bias=False,
44
+ )
45
+
46
+ self.conv2 = nn.Conv2d(
47
+ in_channels=out_channels,
48
+ out_channels=out_channels,
49
+ kernel_size=(3, 3),
50
+ stride=(1, 1),
51
+ padding=(1, 1),
52
+ bias=False,
53
+ )
54
+
55
+ self.bn1 = nn.BatchNorm2d(out_channels)
56
+ self.bn2 = nn.BatchNorm2d(out_channels)
57
+
58
+ self.init_weight()
59
+
60
+ def init_weight(self):
61
+ init_layer(self.conv1)
62
+ init_layer(self.conv2)
63
+ init_bn(self.bn1)
64
+ init_bn(self.bn2)
65
+
66
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
67
+ x = input
68
+ x = F.relu_(self.bn1(self.conv1(x)))
69
+ x = F.relu_(self.bn2(self.conv2(x)))
70
+ if pool_type == "max":
71
+ x = F.max_pool2d(x, kernel_size=pool_size)
72
+ elif pool_type == "avg":
73
+ x = F.avg_pool2d(x, kernel_size=pool_size)
74
+ elif pool_type == "avg+max":
75
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
76
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
77
+ x = x1 + x2
78
+ else:
79
+ raise Exception("Incorrect argument!")
80
+
81
+ return x
82
+
83
+
84
+ class ConvBlock5x5(nn.Module):
85
+ def __init__(self, in_channels, out_channels):
86
+ super(ConvBlock5x5, self).__init__()
87
+
88
+ self.conv1 = nn.Conv2d(
89
+ in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ kernel_size=(5, 5),
92
+ stride=(1, 1),
93
+ padding=(2, 2),
94
+ bias=False,
95
+ )
96
+
97
+ self.bn1 = nn.BatchNorm2d(out_channels)
98
+
99
+ self.init_weight()
100
+
101
+ def init_weight(self):
102
+ init_layer(self.conv1)
103
+ init_bn(self.bn1)
104
+
105
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
106
+ x = input
107
+ x = F.relu_(self.bn1(self.conv1(x)))
108
+ if pool_type == "max":
109
+ x = F.max_pool2d(x, kernel_size=pool_size)
110
+ elif pool_type == "avg":
111
+ x = F.avg_pool2d(x, kernel_size=pool_size)
112
+ elif pool_type == "avg+max":
113
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
114
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
115
+ x = x1 + x2
116
+ else:
117
+ raise Exception("Incorrect argument!")
118
+
119
+ return x
120
+
121
+
122
+ class AttBlock(nn.Module):
123
+ def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
124
+ super(AttBlock, self).__init__()
125
+
126
+ self.activation = activation
127
+ self.temperature = temperature
128
+ self.att = nn.Conv1d(
129
+ in_channels=n_in,
130
+ out_channels=n_out,
131
+ kernel_size=1,
132
+ stride=1,
133
+ padding=0,
134
+ bias=True,
135
+ )
136
+ self.cla = nn.Conv1d(
137
+ in_channels=n_in,
138
+ out_channels=n_out,
139
+ kernel_size=1,
140
+ stride=1,
141
+ padding=0,
142
+ bias=True,
143
+ )
144
+
145
+ self.bn_att = nn.BatchNorm1d(n_out)
146
+ self.init_weights()
147
+
148
+ def init_weights(self):
149
+ init_layer(self.att)
150
+ init_layer(self.cla)
151
+ init_bn(self.bn_att)
152
+
153
+ def forward(self, x):
154
+ # x: (n_samples, n_in, n_time)
155
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
156
+ cla = self.nonlinear_transform(self.cla(x))
157
+ x = torch.sum(norm_att * cla, dim=2)
158
+ return x, norm_att, cla
159
+
160
+ def nonlinear_transform(self, x):
161
+ if self.activation == "linear":
162
+ return x
163
+ elif self.activation == "sigmoid":
164
+ return torch.sigmoid(x)
165
+
166
+
167
+ class Cnn14(nn.Module):
168
+ def __init__(
169
+ self,
170
+ sample_rate,
171
+ window_size,
172
+ hop_size,
173
+ mel_bins,
174
+ fmin,
175
+ fmax,
176
+ classes_num,
177
+ enable_fusion=False,
178
+ fusion_type="None",
179
+ ):
180
+ super(Cnn14, self).__init__()
181
+
182
+ window = "hann"
183
+ center = True
184
+ pad_mode = "reflect"
185
+ ref = 1.0
186
+ amin = 1e-10
187
+ top_db = None
188
+
189
+ self.enable_fusion = enable_fusion
190
+ self.fusion_type = fusion_type
191
+
192
+ # Spectrogram extractor
193
+ self.spectrogram_extractor = Spectrogram(
194
+ n_fft=window_size,
195
+ hop_length=hop_size,
196
+ win_length=window_size,
197
+ window=window,
198
+ center=center,
199
+ pad_mode=pad_mode,
200
+ freeze_parameters=True,
201
+ )
202
+
203
+ # Logmel feature extractor
204
+ self.logmel_extractor = LogmelFilterBank(
205
+ sr=sample_rate,
206
+ n_fft=window_size,
207
+ n_mels=mel_bins,
208
+ fmin=fmin,
209
+ fmax=fmax,
210
+ ref=ref,
211
+ amin=amin,
212
+ top_db=top_db,
213
+ freeze_parameters=True,
214
+ )
215
+
216
+ # Spec augmenter
217
+ self.spec_augmenter = SpecAugmentation(
218
+ time_drop_width=64,
219
+ time_stripes_num=2,
220
+ freq_drop_width=8,
221
+ freq_stripes_num=2,
222
+ )
223
+
224
+ self.bn0 = nn.BatchNorm2d(64)
225
+
226
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
227
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
228
+ else:
229
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
230
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
231
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
232
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
233
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
234
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
235
+
236
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
237
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
238
+
239
+ if (self.enable_fusion) and (
240
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
241
+ ):
242
+ self.mel_conv1d = nn.Sequential(
243
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
244
+ nn.BatchNorm1d(64), # No Relu
245
+ )
246
+ if self.fusion_type == "daf_1d":
247
+ self.fusion_model = DAF()
248
+ elif self.fusion_type == "aff_1d":
249
+ self.fusion_model = AFF(channels=64, type="1D")
250
+ elif self.fusion_type == "iaff_1d":
251
+ self.fusion_model = iAFF(channels=64, type="1D")
252
+
253
+ if (self.enable_fusion) and (
254
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
255
+ ):
256
+ self.mel_conv2d = nn.Sequential(
257
+ nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
258
+ nn.BatchNorm2d(64),
259
+ nn.ReLU(inplace=True),
260
+ )
261
+
262
+ if self.fusion_type == "daf_2d":
263
+ self.fusion_model = DAF()
264
+ elif self.fusion_type == "aff_2d":
265
+ self.fusion_model = AFF(channels=64, type="2D")
266
+ elif self.fusion_type == "iaff_2d":
267
+ self.fusion_model = iAFF(channels=64, type="2D")
268
+ self.init_weight()
269
+
270
+ def init_weight(self):
271
+ init_bn(self.bn0)
272
+ init_layer(self.fc1)
273
+ init_layer(self.fc_audioset)
274
+
275
+ def forward(self, input, mixup_lambda=None, device=None):
276
+ """
277
+ Input: (batch_size, data_length)"""
278
+
279
+ if self.enable_fusion and input["longer"].sum() == 0:
280
+ # if no audio is longer than 10s, then randomly select one audio to be longer
281
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
282
+
283
+ if not self.enable_fusion:
284
+ x = self.spectrogram_extractor(
285
+ input["waveform"].to(device=device, non_blocking=True)
286
+ ) # (batch_size, 1, time_steps, freq_bins)
287
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
288
+
289
+ x = x.transpose(1, 3)
290
+ x = self.bn0(x)
291
+ x = x.transpose(1, 3)
292
+ else:
293
+ longer_list = input["longer"].to(device=device, non_blocking=True)
294
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
295
+ longer_list_idx = torch.where(longer_list)[0]
296
+ x = x.transpose(1, 3)
297
+ x = self.bn0(x)
298
+ x = x.transpose(1, 3)
299
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
300
+ new_x = x[:, 0:1, :, :].clone().contiguous()
301
+ # local processing
302
+ if len(longer_list_idx) > 0:
303
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
304
+ FB, FC, FT, FF = fusion_x_local.size()
305
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
306
+ fusion_x_local = torch.permute(
307
+ fusion_x_local, (0, 2, 1)
308
+ ).contiguous()
309
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
310
+ fusion_x_local = fusion_x_local.view(
311
+ FB, FC, FF, fusion_x_local.size(-1)
312
+ )
313
+ fusion_x_local = (
314
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
315
+ .contiguous()
316
+ .flatten(2)
317
+ )
318
+ if fusion_x_local.size(-1) < FT:
319
+ fusion_x_local = torch.cat(
320
+ [
321
+ fusion_x_local,
322
+ torch.zeros(
323
+ (FB, FF, FT - fusion_x_local.size(-1)),
324
+ device=device,
325
+ ),
326
+ ],
327
+ dim=-1,
328
+ )
329
+ else:
330
+ fusion_x_local = fusion_x_local[:, :, :FT]
331
+ # 1D fusion
332
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
333
+ new_x[longer_list_idx] = self.fusion_model(
334
+ new_x[longer_list_idx], fusion_x_local
335
+ )
336
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
337
+ else:
338
+ x = new_x
339
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
340
+ x = x # no change
341
+
342
+ if self.training:
343
+ x = self.spec_augmenter(x)
344
+ # Mixup on spectrogram
345
+ if self.training and mixup_lambda is not None:
346
+ x = do_mixup(x, mixup_lambda)
347
+ if (self.enable_fusion) and (
348
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
349
+ ):
350
+ global_x = x[:, 0:1, :, :]
351
+
352
+ # global processing
353
+ B, C, H, W = global_x.shape
354
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
355
+ if len(longer_list_idx) > 0:
356
+ local_x = x[longer_list_idx, 1:, :, :].contiguous()
357
+ TH = global_x.size(-2)
358
+ # local processing
359
+ B, C, H, W = local_x.shape
360
+ local_x = local_x.view(B * C, 1, H, W)
361
+ local_x = self.mel_conv2d(local_x)
362
+ local_x = local_x.view(
363
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
364
+ )
365
+ local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
366
+ TB, TC, _, TW = local_x.size()
367
+ if local_x.size(-2) < TH:
368
+ local_x = torch.cat(
369
+ [
370
+ local_x,
371
+ torch.zeros(
372
+ (TB, TC, TH - local_x.size(-2), TW),
373
+ device=global_x.device,
374
+ ),
375
+ ],
376
+ dim=-2,
377
+ )
378
+ else:
379
+ local_x = local_x[:, :, :TH, :]
380
+
381
+ global_x[longer_list_idx] = self.fusion_model(
382
+ global_x[longer_list_idx], local_x
383
+ )
384
+ x = global_x
385
+ else:
386
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
387
+
388
+ x = F.dropout(x, p=0.2, training=self.training)
389
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
390
+ x = F.dropout(x, p=0.2, training=self.training)
391
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
392
+ x = F.dropout(x, p=0.2, training=self.training)
393
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
394
+ x = F.dropout(x, p=0.2, training=self.training)
395
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
396
+ x = F.dropout(x, p=0.2, training=self.training)
397
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
398
+ x = F.dropout(x, p=0.2, training=self.training)
399
+ x = torch.mean(x, dim=3)
400
+
401
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
402
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
403
+ latent_x = latent_x1 + latent_x2
404
+ latent_x = latent_x.transpose(1, 2)
405
+ latent_x = F.relu_(self.fc1(latent_x))
406
+ latent_output = interpolate(latent_x, 32)
407
+
408
+ (x1, _) = torch.max(x, dim=2)
409
+ x2 = torch.mean(x, dim=2)
410
+ x = x1 + x2
411
+ x = F.dropout(x, p=0.5, training=self.training)
412
+ x = F.relu_(self.fc1(x))
413
+ embedding = F.dropout(x, p=0.5, training=self.training)
414
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
415
+
416
+ output_dict = {
417
+ "clipwise_output": clipwise_output,
418
+ "embedding": embedding,
419
+ "fine_grained_embedding": latent_output,
420
+ }
421
+ return output_dict
422
+
423
+
424
+ class Cnn6(nn.Module):
425
+ def __init__(
426
+ self,
427
+ sample_rate,
428
+ window_size,
429
+ hop_size,
430
+ mel_bins,
431
+ fmin,
432
+ fmax,
433
+ classes_num,
434
+ enable_fusion=False,
435
+ fusion_type="None",
436
+ ):
437
+ super(Cnn6, self).__init__()
438
+
439
+ window = "hann"
440
+ center = True
441
+ pad_mode = "reflect"
442
+ ref = 1.0
443
+ amin = 1e-10
444
+ top_db = None
445
+
446
+ self.enable_fusion = enable_fusion
447
+ self.fusion_type = fusion_type
448
+
449
+ # Spectrogram extractor
450
+ self.spectrogram_extractor = Spectrogram(
451
+ n_fft=window_size,
452
+ hop_length=hop_size,
453
+ win_length=window_size,
454
+ window=window,
455
+ center=center,
456
+ pad_mode=pad_mode,
457
+ freeze_parameters=True,
458
+ )
459
+
460
+ # Logmel feature extractor
461
+ self.logmel_extractor = LogmelFilterBank(
462
+ sr=sample_rate,
463
+ n_fft=window_size,
464
+ n_mels=mel_bins,
465
+ fmin=fmin,
466
+ fmax=fmax,
467
+ ref=ref,
468
+ amin=amin,
469
+ top_db=top_db,
470
+ freeze_parameters=True,
471
+ )
472
+
473
+ # Spec augmenter
474
+ self.spec_augmenter = SpecAugmentation(
475
+ time_drop_width=64,
476
+ time_stripes_num=2,
477
+ freq_drop_width=8,
478
+ freq_stripes_num=2,
479
+ )
480
+
481
+ self.bn0 = nn.BatchNorm2d(64)
482
+
483
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
484
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
485
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
486
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
487
+
488
+ self.fc1 = nn.Linear(512, 512, bias=True)
489
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
490
+
491
+ self.init_weight()
492
+
493
+ def init_weight(self):
494
+ init_bn(self.bn0)
495
+ init_layer(self.fc1)
496
+ init_layer(self.fc_audioset)
497
+
498
+ def forward(self, input, mixup_lambda=None, device=None):
499
+ """
500
+ Input: (batch_size, data_length)"""
501
+
502
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
503
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
504
+
505
+ x = x.transpose(1, 3)
506
+ x = self.bn0(x)
507
+ x = x.transpose(1, 3)
508
+
509
+ if self.training:
510
+ x = self.spec_augmenter(x)
511
+
512
+ # Mixup on spectrogram
513
+ if self.training and mixup_lambda is not None:
514
+ x = do_mixup(x, mixup_lambda)
515
+
516
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
517
+ x = F.dropout(x, p=0.2, training=self.training)
518
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
519
+ x = F.dropout(x, p=0.2, training=self.training)
520
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
521
+ x = F.dropout(x, p=0.2, training=self.training)
522
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
523
+ x = F.dropout(x, p=0.2, training=self.training)
524
+ x = torch.mean(x, dim=3)
525
+
526
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
527
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
528
+ latent_x = latent_x1 + latent_x2
529
+ latent_x = latent_x.transpose(1, 2)
530
+ latent_x = F.relu_(self.fc1(latent_x))
531
+ latent_output = interpolate(latent_x, 16)
532
+
533
+ (x1, _) = torch.max(x, dim=2)
534
+ x2 = torch.mean(x, dim=2)
535
+ x = x1 + x2
536
+ x = F.dropout(x, p=0.5, training=self.training)
537
+ x = F.relu_(self.fc1(x))
538
+ embedding = F.dropout(x, p=0.5, training=self.training)
539
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
540
+
541
+ output_dict = {
542
+ "clipwise_output": clipwise_output,
543
+ "embedding": embedding,
544
+ "fine_grained_embedding": latent_output,
545
+ }
546
+
547
+ return output_dict
548
+
549
+
550
+ class Cnn10(nn.Module):
551
+ def __init__(
552
+ self,
553
+ sample_rate,
554
+ window_size,
555
+ hop_size,
556
+ mel_bins,
557
+ fmin,
558
+ fmax,
559
+ classes_num,
560
+ enable_fusion=False,
561
+ fusion_type="None",
562
+ ):
563
+ super(Cnn10, self).__init__()
564
+
565
+ window = "hann"
566
+ center = True
567
+ pad_mode = "reflect"
568
+ ref = 1.0
569
+ amin = 1e-10
570
+ top_db = None
571
+
572
+ self.enable_fusion = enable_fusion
573
+ self.fusion_type = fusion_type
574
+
575
+ # Spectrogram extractor
576
+ self.spectrogram_extractor = Spectrogram(
577
+ n_fft=window_size,
578
+ hop_length=hop_size,
579
+ win_length=window_size,
580
+ window=window,
581
+ center=center,
582
+ pad_mode=pad_mode,
583
+ freeze_parameters=True,
584
+ )
585
+
586
+ # Logmel feature extractor
587
+ self.logmel_extractor = LogmelFilterBank(
588
+ sr=sample_rate,
589
+ n_fft=window_size,
590
+ n_mels=mel_bins,
591
+ fmin=fmin,
592
+ fmax=fmax,
593
+ ref=ref,
594
+ amin=amin,
595
+ top_db=top_db,
596
+ freeze_parameters=True,
597
+ )
598
+
599
+ # Spec augmenter
600
+ self.spec_augmenter = SpecAugmentation(
601
+ time_drop_width=64,
602
+ time_stripes_num=2,
603
+ freq_drop_width=8,
604
+ freq_stripes_num=2,
605
+ )
606
+
607
+ self.bn0 = nn.BatchNorm2d(64)
608
+
609
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
610
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
611
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
612
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
613
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
614
+
615
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
616
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
617
+
618
+ self.init_weight()
619
+
620
+ def init_weight(self):
621
+ init_bn(self.bn0)
622
+ init_layer(self.fc1)
623
+ init_layer(self.fc_audioset)
624
+
625
+ def forward(self, input, mixup_lambda=None, device=None):
626
+ """
627
+ Input: (batch_size, data_length)"""
628
+
629
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
630
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
631
+
632
+ x = x.transpose(1, 3)
633
+ x = self.bn0(x)
634
+ x = x.transpose(1, 3)
635
+
636
+ if self.training:
637
+ x = self.spec_augmenter(x)
638
+
639
+ # Mixup on spectrogram
640
+ if self.training and mixup_lambda is not None:
641
+ x = do_mixup(x, mixup_lambda)
642
+
643
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
644
+ x = F.dropout(x, p=0.2, training=self.training)
645
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
646
+ x = F.dropout(x, p=0.2, training=self.training)
647
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
648
+ x = F.dropout(x, p=0.2, training=self.training)
649
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
650
+ x = F.dropout(x, p=0.2, training=self.training)
651
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
652
+ x = F.dropout(x, p=0.2, training=self.training)
653
+ x = torch.mean(x, dim=3)
654
+
655
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
656
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
657
+ latent_x = latent_x1 + latent_x2
658
+ latent_x = latent_x.transpose(1, 2)
659
+ latent_x = F.relu_(self.fc1(latent_x))
660
+ latent_output = interpolate(latent_x, 32)
661
+
662
+ (x1, _) = torch.max(x, dim=2)
663
+ x2 = torch.mean(x, dim=2)
664
+ x = x1 + x2
665
+ x = F.dropout(x, p=0.5, training=self.training)
666
+ x = F.relu_(self.fc1(x))
667
+ embedding = F.dropout(x, p=0.5, training=self.training)
668
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
669
+
670
+ output_dict = {
671
+ "clipwise_output": clipwise_output,
672
+ "embedding": embedding,
673
+ "fine_grained_embedding": latent_output,
674
+ }
675
+
676
+ return output_dict
677
+
678
+
679
+ def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
680
+ try:
681
+ ModelProto = eval(audio_cfg.model_name)
682
+ model = ModelProto(
683
+ sample_rate=audio_cfg.sample_rate,
684
+ window_size=audio_cfg.window_size,
685
+ hop_size=audio_cfg.hop_size,
686
+ mel_bins=audio_cfg.mel_bins,
687
+ fmin=audio_cfg.fmin,
688
+ fmax=audio_cfg.fmax,
689
+ classes_num=audio_cfg.class_num,
690
+ enable_fusion=enable_fusion,
691
+ fusion_type=fusion_type,
692
+ )
693
+ return model
694
+ except Exception:
695
+ raise RuntimeError(
696
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
697
+ )
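
For reference, a minimal usage sketch of the create_pann_model factory above (not part of the diff; the config fields mirror what the factory reads, but the SimpleNamespace container and the numeric values are illustrative assumptions, not the project's official settings):

from types import SimpleNamespace
import torch

# Illustrative config; values are placeholders, not the official CLAP settings.
audio_cfg = SimpleNamespace(
    model_name="Cnn14",   # any of Cnn6 / Cnn10 / Cnn14 defined above
    sample_rate=32000,
    window_size=1024,
    hop_size=320,
    mel_bins=64,          # bn0 above is hard-coded to 64 mel bins
    fmin=50,
    fmax=14000,
    class_num=527,
)
model = create_pann_model(audio_cfg, enable_fusion=False)
model.eval()

with torch.no_grad():
    waveform = torch.randn(2, 10 * audio_cfg.sample_rate)  # (batch, samples)
    out = model({"waveform": waveform}, device=waveform.device)
# out["embedding"] is the clip-level embedding, out["fine_grained_embedding"]
# the frame-level one, and out["clipwise_output"] the sigmoid class scores.
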
audiosr/clap/open_clip/pretrained.py CHANGED
@@ -1,167 +1,167 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
-
6
- from tqdm import tqdm
7
-
8
- _RN50 = dict(
9
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
- )
13
-
14
- _RN50_quickgelu = dict(
15
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
- )
19
-
20
- _RN101 = dict(
21
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
- )
24
-
25
- _RN101_quickgelu = dict(
26
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
- )
29
-
30
- _RN50x4 = dict(
31
- openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
- )
33
-
34
- _RN50x16 = dict(
35
- openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
- )
37
-
38
- _RN50x64 = dict(
39
- openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
- )
41
-
42
- _VITB32 = dict(
43
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
- )
48
-
49
- _VITB32_quickgelu = dict(
50
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
- )
55
-
56
- _VITB16 = dict(
57
- openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
- )
59
-
60
- _VITL14 = dict(
61
- openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
- )
63
-
64
- _PRETRAINED = {
65
- "RN50": _RN50,
66
- "RN50-quickgelu": _RN50_quickgelu,
67
- "RN101": _RN101,
68
- "RN101-quickgelu": _RN101_quickgelu,
69
- "RN50x4": _RN50x4,
70
- "RN50x16": _RN50x16,
71
- "ViT-B-32": _VITB32,
72
- "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
- "ViT-B-16": _VITB16,
74
- "ViT-L-14": _VITL14,
75
- }
76
-
77
-
78
- def list_pretrained(as_str: bool = False):
79
- """returns list of pretrained models
80
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81
- """
82
- return [
83
- ":".join([k, t]) if as_str else (k, t)
84
- for k in _PRETRAINED.keys()
85
- for t in _PRETRAINED[k].keys()
86
- ]
87
-
88
-
89
- def list_pretrained_tag_models(tag: str):
90
- """return all models having the specified pretrain tag"""
91
- models = []
92
- for k in _PRETRAINED.keys():
93
- if tag in _PRETRAINED[k]:
94
- models.append(k)
95
- return models
96
-
97
-
98
- def list_pretrained_model_tags(model: str):
99
- """return all pretrain tags for the specified model architecture"""
100
- tags = []
101
- if model in _PRETRAINED:
102
- tags.extend(_PRETRAINED[model].keys())
103
- return tags
104
-
105
-
106
- def get_pretrained_url(model: str, tag: str):
107
- if model not in _PRETRAINED:
108
- return ""
109
- model_pretrained = _PRETRAINED[model]
110
- if tag not in model_pretrained:
111
- return ""
112
- return model_pretrained[tag]
113
-
114
-
115
- def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
- os.makedirs(root, exist_ok=True)
117
- filename = os.path.basename(url)
118
-
119
- if "openaipublic" in url:
120
- expected_sha256 = url.split("/")[-2]
121
- else:
122
- expected_sha256 = ""
123
-
124
- download_target = os.path.join(root, filename)
125
-
126
- if os.path.exists(download_target) and not os.path.isfile(download_target):
127
- raise RuntimeError(f"{download_target} exists and is not a regular file")
128
-
129
- if os.path.isfile(download_target):
130
- if expected_sha256:
131
- if (
132
- hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
- == expected_sha256
134
- ):
135
- return download_target
136
- else:
137
- warnings.warn(
138
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
- )
140
- else:
141
- return download_target
142
-
143
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
- with tqdm(
145
- total=int(source.info().get("Content-Length")),
146
- ncols=80,
147
- unit="iB",
148
- unit_scale=True,
149
- ) as loop:
150
- while True:
151
- buffer = source.read(8192)
152
- if not buffer:
153
- break
154
-
155
- output.write(buffer)
156
- loop.update(len(buffer))
157
-
158
- if (
159
- expected_sha256
160
- and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
- != expected_sha256
162
- ):
163
- raise RuntimeError(
164
- f"Model has been downloaded but the SHA256 checksum does not not match"
165
- )
166
-
167
- return download_target
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+
6
+ from tqdm import tqdm
7
+
8
+ _RN50 = dict(
9
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
+ )
13
+
14
+ _RN50_quickgelu = dict(
15
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
+ )
19
+
20
+ _RN101 = dict(
21
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
+ )
24
+
25
+ _RN101_quickgelu = dict(
26
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
+ )
29
+
30
+ _RN50x4 = dict(
31
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
+ )
33
+
34
+ _RN50x16 = dict(
35
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
+ )
37
+
38
+ _RN50x64 = dict(
39
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
+ )
41
+
42
+ _VITB32 = dict(
43
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
+ )
48
+
49
+ _VITB32_quickgelu = dict(
50
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
+ )
55
+
56
+ _VITB16 = dict(
57
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
+ )
59
+
60
+ _VITL14 = dict(
61
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
+ )
63
+
64
+ _PRETRAINED = {
65
+ "RN50": _RN50,
66
+ "RN50-quickgelu": _RN50_quickgelu,
67
+ "RN101": _RN101,
68
+ "RN101-quickgelu": _RN101_quickgelu,
69
+ "RN50x4": _RN50x4,
70
+ "RN50x16": _RN50x16,
71
+ "ViT-B-32": _VITB32,
72
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
+ "ViT-B-16": _VITB16,
74
+ "ViT-L-14": _VITL14,
75
+ }
76
+
77
+
78
+ def list_pretrained(as_str: bool = False):
79
+ """returns list of pretrained models
80
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81
+ """
82
+ return [
83
+ ":".join([k, t]) if as_str else (k, t)
84
+ for k in _PRETRAINED.keys()
85
+ for t in _PRETRAINED[k].keys()
86
+ ]
87
+
88
+
89
+ def list_pretrained_tag_models(tag: str):
90
+ """return all models having the specified pretrain tag"""
91
+ models = []
92
+ for k in _PRETRAINED.keys():
93
+ if tag in _PRETRAINED[k]:
94
+ models.append(k)
95
+ return models
96
+
97
+
98
+ def list_pretrained_model_tags(model: str):
99
+ """return all pretrain tags for the specified model architecture"""
100
+ tags = []
101
+ if model in _PRETRAINED:
102
+ tags.extend(_PRETRAINED[model].keys())
103
+ return tags
104
+
105
+
106
+ def get_pretrained_url(model: str, tag: str):
107
+ if model not in _PRETRAINED:
108
+ return ""
109
+ model_pretrained = _PRETRAINED[model]
110
+ if tag not in model_pretrained:
111
+ return ""
112
+ return model_pretrained[tag]
113
+
114
+
115
+ def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
+ os.makedirs(root, exist_ok=True)
117
+ filename = os.path.basename(url)
118
+
119
+ if "openaipublic" in url:
120
+ expected_sha256 = url.split("/")[-2]
121
+ else:
122
+ expected_sha256 = ""
123
+
124
+ download_target = os.path.join(root, filename)
125
+
126
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
127
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
128
+
129
+ if os.path.isfile(download_target):
130
+ if expected_sha256:
131
+ if (
132
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
+ == expected_sha256
134
+ ):
135
+ return download_target
136
+ else:
137
+ warnings.warn(
138
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
+ )
140
+ else:
141
+ return download_target
142
+
143
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
+ with tqdm(
145
+ total=int(source.info().get("Content-Length")),
146
+ ncols=80,
147
+ unit="iB",
148
+ unit_scale=True,
149
+ ) as loop:
150
+ while True:
151
+ buffer = source.read(8192)
152
+ if not buffer:
153
+ break
154
+
155
+ output.write(buffer)
156
+ loop.update(len(buffer))
157
+
158
+ if (
159
+ expected_sha256
160
+ and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
+ != expected_sha256
162
+ ):
163
+ raise RuntimeError(
164
+ f"Model has been downloaded but the SHA256 checksum does not not match"
165
+ )
166
+
167
+ return download_target
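
For reference, a hedged sketch of the intended call pattern for the helpers above (not part of the diff; the model/tag pair is one of the entries registered in _PRETRAINED):

# List the registered (model, tag) pairs, then fetch one checkpoint.
print(list_pretrained(as_str=True)[:3])  # e.g. ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']

url = get_pretrained_url("ViT-B-32", "openai")
if url:
    # download_pretrained caches under ~/.cache/clip and, for openai URLs,
    # verifies the SHA256 embedded in the URL path.
    checkpoint_path = download_pretrained(url)
    print("checkpoint saved to", checkpoint_path)
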
audiosr/clap/open_clip/timm_model.py CHANGED
@@ -1,112 +1,112 @@
1
- """ timm model adapter
2
-
3
- Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
- """
5
- from collections import OrderedDict
6
-
7
- import torch.nn as nn
8
-
9
- try:
10
- import timm
11
- from timm.models.layers import Mlp, to_2tuple
12
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
- from timm.models.layers.attention_pool2d import (
14
- AttentionPool2d as AbsAttentionPool2d,
15
- )
16
- except ImportError:
17
- timm = None
18
-
19
- from .utils import freeze_batch_norm_2d
20
-
21
-
22
- class TimmModel(nn.Module):
23
- """timm model adapter
24
- # FIXME this adapter is a work in progress, may change in ways that break weight compat
25
- """
26
-
27
- def __init__(
28
- self,
29
- model_name,
30
- embed_dim,
31
- image_size=224,
32
- pool="avg",
33
- proj="linear",
34
- drop=0.0,
35
- pretrained=False,
36
- ):
37
- super().__init__()
38
- if timm is None:
39
- raise RuntimeError("Please `pip install timm` to use timm models.")
40
-
41
- self.image_size = to_2tuple(image_size)
42
- self.trunk = timm.create_model(model_name, pretrained=pretrained)
43
- feat_size = self.trunk.default_cfg.get("pool_size", None)
44
- feature_ndim = 1 if not feat_size else 2
45
- if pool in ("abs_attn", "rot_attn"):
46
- assert feature_ndim == 2
47
- # if attn pooling used, remove both classifier and default pool
48
- self.trunk.reset_classifier(0, global_pool="")
49
- else:
50
- # reset global pool if pool config set, otherwise leave as network default
51
- reset_kwargs = dict(global_pool=pool) if pool else {}
52
- self.trunk.reset_classifier(0, **reset_kwargs)
53
- prev_chs = self.trunk.num_features
54
-
55
- head_layers = OrderedDict()
56
- if pool == "abs_attn":
57
- head_layers["pool"] = AbsAttentionPool2d(
58
- prev_chs, feat_size=feat_size, out_features=embed_dim
59
- )
60
- prev_chs = embed_dim
61
- elif pool == "rot_attn":
62
- head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63
- prev_chs = embed_dim
64
- else:
65
- assert proj, "projection layer needed if non-attention pooling is used."
66
-
67
- # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68
- if proj == "linear":
69
- head_layers["drop"] = nn.Dropout(drop)
70
- head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71
- elif proj == "mlp":
72
- head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73
-
74
- self.head = nn.Sequential(head_layers)
75
-
76
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77
- """lock modules
78
- Args:
79
- unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80
- """
81
- if not unlocked_groups:
82
- # lock full model
83
- for param in self.trunk.parameters():
84
- param.requires_grad = False
85
- if freeze_bn_stats:
86
- freeze_batch_norm_2d(self.trunk)
87
- else:
88
- # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89
- try:
90
- # FIXME import here until API stable and in an official release
91
- from timm.models.helpers import group_parameters, group_modules
92
- except ImportError:
93
- raise RuntimeError(
94
- "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95
- )
96
- matcher = self.trunk.group_matcher()
97
- gparams = group_parameters(self.trunk, matcher)
98
- max_layer_id = max(gparams.keys())
99
- max_layer_id = max_layer_id - unlocked_groups
100
- for group_idx in range(max_layer_id + 1):
101
- group = gparams[group_idx]
102
- for param in group:
103
- self.trunk.get_parameter(param).requires_grad = False
104
- if freeze_bn_stats:
105
- gmodules = group_modules(self.trunk, matcher, reverse=True)
106
- gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107
- freeze_batch_norm_2d(self.trunk, gmodules)
108
-
109
- def forward(self, x):
110
- x = self.trunk(x)
111
- x = self.head(x)
112
- return x
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch.nn as nn
8
+
9
+ try:
10
+ import timm
11
+ from timm.models.layers import Mlp, to_2tuple
12
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
+ from timm.models.layers.attention_pool2d import (
14
+ AttentionPool2d as AbsAttentionPool2d,
15
+ )
16
+ except ImportError:
17
+ timm = None
18
+
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+
22
+ class TimmModel(nn.Module):
23
+ """timm model adapter
24
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ model_name,
30
+ embed_dim,
31
+ image_size=224,
32
+ pool="avg",
33
+ proj="linear",
34
+ drop=0.0,
35
+ pretrained=False,
36
+ ):
37
+ super().__init__()
38
+ if timm is None:
39
+ raise RuntimeError("Please `pip install timm` to use timm models.")
40
+
41
+ self.image_size = to_2tuple(image_size)
42
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
43
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
44
+ feature_ndim = 1 if not feat_size else 2
45
+ if pool in ("abs_attn", "rot_attn"):
46
+ assert feature_ndim == 2
47
+ # if attn pooling used, remove both classifier and default pool
48
+ self.trunk.reset_classifier(0, global_pool="")
49
+ else:
50
+ # reset global pool if pool config set, otherwise leave as network default
51
+ reset_kwargs = dict(global_pool=pool) if pool else {}
52
+ self.trunk.reset_classifier(0, **reset_kwargs)
53
+ prev_chs = self.trunk.num_features
54
+
55
+ head_layers = OrderedDict()
56
+ if pool == "abs_attn":
57
+ head_layers["pool"] = AbsAttentionPool2d(
58
+ prev_chs, feat_size=feat_size, out_features=embed_dim
59
+ )
60
+ prev_chs = embed_dim
61
+ elif pool == "rot_attn":
62
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63
+ prev_chs = embed_dim
64
+ else:
65
+ assert proj, "projection layer needed if non-attention pooling is used."
66
+
67
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68
+ if proj == "linear":
69
+ head_layers["drop"] = nn.Dropout(drop)
70
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71
+ elif proj == "mlp":
72
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73
+
74
+ self.head = nn.Sequential(head_layers)
75
+
76
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77
+ """lock modules
78
+ Args:
79
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80
+ """
81
+ if not unlocked_groups:
82
+ # lock full model
83
+ for param in self.trunk.parameters():
84
+ param.requires_grad = False
85
+ if freeze_bn_stats:
86
+ freeze_batch_norm_2d(self.trunk)
87
+ else:
88
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89
+ try:
90
+ # FIXME import here until API stable and in an official release
91
+ from timm.models.helpers import group_parameters, group_modules
92
+ except ImportError:
93
+ raise RuntimeError(
94
+ "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95
+ )
96
+ matcher = self.trunk.group_matcher()
97
+ gparams = group_parameters(self.trunk, matcher)
98
+ max_layer_id = max(gparams.keys())
99
+ max_layer_id = max_layer_id - unlocked_groups
100
+ for group_idx in range(max_layer_id + 1):
101
+ group = gparams[group_idx]
102
+ for param in group:
103
+ self.trunk.get_parameter(param).requires_grad = False
104
+ if freeze_bn_stats:
105
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
106
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107
+ freeze_batch_norm_2d(self.trunk, gmodules)
108
+
109
+ def forward(self, x):
110
+ x = self.trunk(x)
111
+ x = self.head(x)
112
+ return x
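
For reference, a minimal sketch of wrapping a timm backbone with the adapter above (not part of the diff; "resnet50" and the shapes are illustrative, and a timm version compatible with the imports at the top of the file must be installed):

import torch

vision_tower = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
vision_tower.lock(unlocked_groups=0, freeze_bn_stats=True)  # freeze the trunk and its BN stats
with torch.no_grad():
    feats = vision_tower(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 512]) -- pooled trunk features projected to embed_dim
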
audiosr/clap/open_clip/tokenizer.py CHANGED
@@ -1,197 +1,197 @@
1
- """ CLIP tokenizer
2
-
3
- Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import gzip
6
- import html
7
- import os
8
- from functools import lru_cache
9
- from typing import Union, List
10
-
11
- import ftfy
12
- import regex as re
13
- import torch
14
-
15
-
16
- @lru_cache()
17
- def default_bpe():
18
- return os.path.join(
19
- os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20
- )
21
-
22
-
23
- @lru_cache()
24
- def bytes_to_unicode():
25
- """
26
- Returns a list of utf-8 bytes and a corresponding list of unicode strings.
27
- The reversible bpe codes work on unicode strings.
28
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30
- This is a significant percentage of your normal, say, 32K bpe vocab.
31
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32
- It also avoids mapping to whitespace/control characters that the bpe code barfs on.
33
- """
34
- bs = (
35
- list(range(ord("!"), ord("~") + 1))
36
- + list(range(ord("¡"), ord("¬") + 1))
37
- + list(range(ord("®"), ord("ÿ") + 1))
38
- )
39
- cs = bs[:]
40
- n = 0
41
- for b in range(2**8):
42
- if b not in bs:
43
- bs.append(b)
44
- cs.append(2**8 + n)
45
- n += 1
46
- cs = [chr(n) for n in cs]
47
- return dict(zip(bs, cs))
48
-
49
-
50
- def get_pairs(word):
51
- """Return set of symbol pairs in a word.
52
- Word is represented as tuple of symbols (symbols being variable-length strings).
53
- """
54
- pairs = set()
55
- prev_char = word[0]
56
- for char in word[1:]:
57
- pairs.add((prev_char, char))
58
- prev_char = char
59
- return pairs
60
-
61
-
62
- def basic_clean(text):
63
- text = ftfy.fix_text(text)
64
- text = html.unescape(html.unescape(text))
65
- return text.strip()
66
-
67
-
68
- def whitespace_clean(text):
69
- text = re.sub(r"\s+", " ", text)
70
- text = text.strip()
71
- return text
72
-
73
-
74
- class SimpleTokenizer(object):
75
- def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
- self.byte_encoder = bytes_to_unicode()
77
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
- merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
- merges = merges[1 : 49152 - 256 - 2 + 1]
80
- merges = [tuple(merge.split()) for merge in merges]
81
- vocab = list(bytes_to_unicode().values())
82
- vocab = vocab + [v + "</w>" for v in vocab]
83
- for merge in merges:
84
- vocab.append("".join(merge))
85
- if not special_tokens:
86
- special_tokens = ["<start_of_text>", "<end_of_text>"]
87
- else:
88
- special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
- vocab.extend(special_tokens)
90
- self.encoder = dict(zip(vocab, range(len(vocab))))
91
- self.decoder = {v: k for k, v in self.encoder.items()}
92
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
- self.cache = {t: t for t in special_tokens}
94
- special = "|".join(special_tokens)
95
- self.pat = re.compile(
96
- special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97
- re.IGNORECASE,
98
- )
99
-
100
- self.vocab_size = len(self.encoder)
101
- self.all_special_ids = [self.encoder[t] for t in special_tokens]
102
-
103
- def bpe(self, token):
104
- if token in self.cache:
105
- return self.cache[token]
106
- word = tuple(token[:-1]) + (token[-1] + "</w>",)
107
- pairs = get_pairs(word)
108
-
109
- if not pairs:
110
- return token + "</w>"
111
-
112
- while True:
113
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114
- if bigram not in self.bpe_ranks:
115
- break
116
- first, second = bigram
117
- new_word = []
118
- i = 0
119
- while i < len(word):
120
- try:
121
- j = word.index(first, i)
122
- new_word.extend(word[i:j])
123
- i = j
124
- except:
125
- new_word.extend(word[i:])
126
- break
127
-
128
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129
- new_word.append(first + second)
130
- i += 2
131
- else:
132
- new_word.append(word[i])
133
- i += 1
134
- new_word = tuple(new_word)
135
- word = new_word
136
- if len(word) == 1:
137
- break
138
- else:
139
- pairs = get_pairs(word)
140
- word = " ".join(word)
141
- self.cache[token] = word
142
- return word
143
-
144
- def encode(self, text):
145
- bpe_tokens = []
146
- text = whitespace_clean(basic_clean(text)).lower()
147
- for token in re.findall(self.pat, text):
148
- token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149
- bpe_tokens.extend(
150
- self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151
- )
152
- return bpe_tokens
153
-
154
- def decode(self, tokens):
155
- text = "".join([self.decoder[token] for token in tokens])
156
- text = (
157
- bytearray([self.byte_decoder[c] for c in text])
158
- .decode("utf-8", errors="replace")
159
- .replace("</w>", " ")
160
- )
161
- return text
162
-
163
-
164
- _tokenizer = SimpleTokenizer()
165
-
166
-
167
- def tokenize(
168
- texts: Union[str, List[str]], context_length: int = 77
169
- ) -> torch.LongTensor:
170
- """
171
- Returns the tokenized representation of given input string(s)
172
-
173
- Parameters
174
- ----------
175
- texts : Union[str, List[str]]
176
- An input string or a list of input strings to tokenize
177
- context_length : int
178
- The context length to use; all CLIP models use 77 as the context length
179
-
180
- Returns
181
- -------
182
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183
- """
184
- if isinstance(texts, str):
185
- texts = [texts]
186
-
187
- sot_token = _tokenizer.encoder["<start_of_text>"]
188
- eot_token = _tokenizer.encoder["<end_of_text>"]
189
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191
-
192
- for i, tokens in enumerate(all_tokens):
193
- if len(tokens) > context_length:
194
- tokens = tokens[:context_length] # Truncate
195
- result[i, : len(tokens)] = torch.tensor(tokens)
196
-
197
- return result
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+
16
+ @lru_cache()
17
+ def default_bpe():
18
+ return os.path.join(
19
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20
+ )
21
+
22
+
23
+ @lru_cache()
24
+ def bytes_to_unicode():
25
+ """
26
+ Returns a lookup table between utf-8 bytes and corresponding unicode strings.
27
+ The reversible bpe codes work on unicode strings.
28
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30
+ This is a significant percentage of your normal, say, 32K bpe vocab.
31
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32
+ It also avoids mapping to whitespace/control characters that the bpe code barfs on.
33
+ """
34
+ bs = (
35
+ list(range(ord("!"), ord("~") + 1))
36
+ + list(range(ord("¡"), ord("¬") + 1))
37
+ + list(range(ord("®"), ord("ÿ") + 1))
38
+ )
39
+ cs = bs[:]
40
+ n = 0
41
+ for b in range(2**8):
42
+ if b not in bs:
43
+ bs.append(b)
44
+ cs.append(2**8 + n)
45
+ n += 1
46
+ cs = [chr(n) for n in cs]
47
+ return dict(zip(bs, cs))
48
+
49
+
50
+ def get_pairs(word):
51
+ """Return set of symbol pairs in a word.
52
+ Word is represented as tuple of symbols (symbols being variable-length strings).
53
+ """
54
+ pairs = set()
55
+ prev_char = word[0]
56
+ for char in word[1:]:
57
+ pairs.add((prev_char, char))
58
+ prev_char = char
59
+ return pairs
60
+
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r"\s+", " ", text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ class SimpleTokenizer(object):
75
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
+ self.byte_encoder = bytes_to_unicode()
77
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
+ merges = merges[1 : 49152 - 256 - 2 + 1]
80
+ merges = [tuple(merge.split()) for merge in merges]
81
+ vocab = list(bytes_to_unicode().values())
82
+ vocab = vocab + [v + "</w>" for v in vocab]
83
+ for merge in merges:
84
+ vocab.append("".join(merge))
85
+ if not special_tokens:
86
+ special_tokens = ["<start_of_text>", "<end_of_text>"]
87
+ else:
88
+ special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
+ vocab.extend(special_tokens)
90
+ self.encoder = dict(zip(vocab, range(len(vocab))))
91
+ self.decoder = {v: k for k, v in self.encoder.items()}
92
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
+ self.cache = {t: t for t in special_tokens}
94
+ special = "|".join(special_tokens)
95
+ self.pat = re.compile(
96
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97
+ re.IGNORECASE,
98
+ )
99
+
100
+ self.vocab_size = len(self.encoder)
101
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
102
+
103
+ def bpe(self, token):
104
+ if token in self.cache:
105
+ return self.cache[token]
106
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
107
+ pairs = get_pairs(word)
108
+
109
+ if not pairs:
110
+ return token + "</w>"
111
+
112
+ while True:
113
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114
+ if bigram not in self.bpe_ranks:
115
+ break
116
+ first, second = bigram
117
+ new_word = []
118
+ i = 0
119
+ while i < len(word):
120
+ try:
121
+ j = word.index(first, i)
122
+ new_word.extend(word[i:j])
123
+ i = j
124
+ except ValueError:  # `first` does not occur in the remainder of `word`
125
+ new_word.extend(word[i:])
126
+ break
127
+
128
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129
+ new_word.append(first + second)
130
+ i += 2
131
+ else:
132
+ new_word.append(word[i])
133
+ i += 1
134
+ new_word = tuple(new_word)
135
+ word = new_word
136
+ if len(word) == 1:
137
+ break
138
+ else:
139
+ pairs = get_pairs(word)
140
+ word = " ".join(word)
141
+ self.cache[token] = word
142
+ return word
143
+
144
+ def encode(self, text):
145
+ bpe_tokens = []
146
+ text = whitespace_clean(basic_clean(text)).lower()
147
+ for token in re.findall(self.pat, text):
148
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149
+ bpe_tokens.extend(
150
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151
+ )
152
+ return bpe_tokens
153
+
154
+ def decode(self, tokens):
155
+ text = "".join([self.decoder[token] for token in tokens])
156
+ text = (
157
+ bytearray([self.byte_decoder[c] for c in text])
158
+ .decode("utf-8", errors="replace")
159
+ .replace("</w>", " ")
160
+ )
161
+ return text
162
+
163
+
164
+ _tokenizer = SimpleTokenizer()
165
+
166
+
167
+ def tokenize(
168
+ texts: Union[str, List[str]], context_length: int = 77
169
+ ) -> torch.LongTensor:
170
+ """
171
+ Returns the tokenized representation of given input string(s)
172
+
173
+ Parameters
174
+ ----------
175
+ texts : Union[str, List[str]]
176
+ An input string or a list of input strings to tokenize
177
+ context_length : int
178
+ The context length to use; all CLIP models use 77 as the context length
179
+
180
+ Returns
181
+ -------
182
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183
+ """
184
+ if isinstance(texts, str):
185
+ texts = [texts]
186
+
187
+ sot_token = _tokenizer.encoder["<start_of_text>"]
188
+ eot_token = _tokenizer.encoder["<end_of_text>"]
189
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191
+
192
+ for i, tokens in enumerate(all_tokens):
193
+ if len(tokens) > context_length:
194
+ tokens = tokens[:context_length] # Truncate
195
+ result[i, : len(tokens)] = torch.tensor(tokens)
196
+
197
+ return result
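A minimal usage sketch for the `tokenize` helper above; this is a sketch only, assuming the package is installed so that this module resolves as `audiosr.clap.open_clip.tokenizer`, and the captions are made-up examples:

# Sketch: turn free-text captions into the fixed-length (N, 77) LongTensor CLAP expects.
from audiosr.clap.open_clip.tokenizer import tokenize

captions = ["a dog barking in the distance", "rain falling on a tin roof"]  # illustrative inputs
tokens = tokenize(captions)          # <start_of_text> ... <end_of_text>, zero-padded to length 77
print(tokens.shape)                  # torch.Size([2, 77])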
audiosr/clap/open_clip/transform.py CHANGED
@@ -1,45 +1,45 @@
1
- from torchvision.transforms import (
2
- Normalize,
3
- Compose,
4
- RandomResizedCrop,
5
- InterpolationMode,
6
- ToTensor,
7
- Resize,
8
- CenterCrop,
9
- )
10
-
11
-
12
- def _convert_to_rgb(image):
13
- return image.convert("RGB")
14
-
15
-
16
- def image_transform(
17
- image_size: int,
18
- is_train: bool,
19
- mean=(0.48145466, 0.4578275, 0.40821073),
20
- std=(0.26862954, 0.26130258, 0.27577711),
21
- ):
22
- normalize = Normalize(mean=mean, std=std)
23
- if is_train:
24
- return Compose(
25
- [
26
- RandomResizedCrop(
27
- image_size,
28
- scale=(0.9, 1.0),
29
- interpolation=InterpolationMode.BICUBIC,
30
- ),
31
- _convert_to_rgb,
32
- ToTensor(),
33
- normalize,
34
- ]
35
- )
36
- else:
37
- return Compose(
38
- [
39
- Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40
- CenterCrop(image_size),
41
- _convert_to_rgb,
42
- ToTensor(),
43
- normalize,
44
- ]
45
- )
 
1
+ from torchvision.transforms import (
2
+ Normalize,
3
+ Compose,
4
+ RandomResizedCrop,
5
+ InterpolationMode,
6
+ ToTensor,
7
+ Resize,
8
+ CenterCrop,
9
+ )
10
+
11
+
12
+ def _convert_to_rgb(image):
13
+ return image.convert("RGB")
14
+
15
+
16
+ def image_transform(
17
+ image_size: int,
18
+ is_train: bool,
19
+ mean=(0.48145466, 0.4578275, 0.40821073),
20
+ std=(0.26862954, 0.26130258, 0.27577711),
21
+ ):
22
+ normalize = Normalize(mean=mean, std=std)
23
+ if is_train:
24
+ return Compose(
25
+ [
26
+ RandomResizedCrop(
27
+ image_size,
28
+ scale=(0.9, 1.0),
29
+ interpolation=InterpolationMode.BICUBIC,
30
+ ),
31
+ _convert_to_rgb,
32
+ ToTensor(),
33
+ normalize,
34
+ ]
35
+ )
36
+ else:
37
+ return Compose(
38
+ [
39
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40
+ CenterCrop(image_size),
41
+ _convert_to_rgb,
42
+ ToTensor(),
43
+ normalize,
44
+ ]
45
+ )
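A brief sketch of how `image_transform` above would be used at evaluation time; the import path follows this repository layout and the blank test image is only for illustration:

# Sketch: build the eval-time CLIP preprocessing pipeline and run it on a dummy image.
from PIL import Image
from audiosr.clap.open_clip.transform import image_transform

preprocess_val = image_transform(image_size=224, is_train=False)
img = Image.new("RGB", (640, 480))   # stand-in for a real image
x = preprocess_val(img)              # resized, center-cropped, RGB-converted, normalized tensor
print(x.shape)                       # torch.Size([3, 224, 224])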
audiosr/clap/open_clip/utils.py CHANGED
@@ -1,355 +1,355 @@
1
- import numpy as np
2
- import torch
3
- from torch import nn as nn
4
- from torchvision.ops.misc import FrozenBatchNorm2d
5
- import logging
6
- from tqdm import tqdm
7
- import random
8
- import json
9
- import os
10
- import pathlib
11
-
12
- # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
13
- dataset_split = {
14
- "audiocaps": ["train", "valid", "test"],
15
- "audioset": ["balanced_train", "unbalanced_train", "eval"],
16
- "BBCSoundEffects": ["train", "test"],
17
- "Clotho": ["train", "test", "valid"],
18
- "free_to_use_sounds": ["train", "test"],
19
- "paramount_motion": ["train", "test"],
20
- "sonniss_game_effects": ["train", "test"],
21
- "wesoundeffects": ["train", "test"],
22
- "MACS": ["train", "test"],
23
- "freesound": ["train", "test"],
24
- "FSD50K": ["train", "test", "valid"],
25
- "fsd50k_class_label": ["train", "test", "valid"],
26
- "esc50": ["train", "test"],
27
- "audiostock": ["train", "test"],
28
- "freesound_no_overlap_noesc50": ["train", "test"],
29
- "epidemic_sound_effects": ["train", "test"],
30
- "VGGSound": ["train", "test"],
31
- "urbansound8k_class_label": ["train", "test"],
32
- "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
33
- "epidemic_sound_effects_t5": ["train", "test"],
34
- "WavText5K": ["train", "test"],
35
- "esc50_no_overlap": ["train", "test"],
36
- "usd8k_no_overlap": ["train", "test"],
37
- "fsd50k_200_class_label": ["train", "test", "valid"],
38
- }
39
-
40
-
41
- def freeze_batch_norm_2d(module, module_match={}, name=""):
42
- """
43
- Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
44
- itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
45
- returned. Otherwise, the module is walked recursively and submodules are converted in place.
46
-
47
- Args:
48
- module (torch.nn.Module): Any PyTorch module.
49
- module_match (dict): Dictionary of full module names to freeze (all if empty)
50
- name (str): Full module name (prefix)
51
-
52
- Returns:
53
- torch.nn.Module: Resulting module
54
-
55
- Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
56
- """
57
- res = module
58
- is_match = True
59
- if module_match:
60
- is_match = name in module_match
61
- if is_match and isinstance(
62
- module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
63
- ):
64
- res = FrozenBatchNorm2d(module.num_features)
65
- res.num_features = module.num_features
66
- res.affine = module.affine
67
- if module.affine:
68
- res.weight.data = module.weight.data.clone().detach()
69
- res.bias.data = module.bias.data.clone().detach()
70
- res.running_mean.data = module.running_mean.data
71
- res.running_var.data = module.running_var.data
72
- res.eps = module.eps
73
- else:
74
- for child_name, child in module.named_children():
75
- full_child_name = ".".join([name, child_name]) if name else child_name
76
- new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
77
- if new_child is not child:
78
- res.add_module(child_name, new_child)
79
- return res
80
-
81
-
82
- def exist(dataset_name, dataset_type):
83
- """
84
- Check if dataset exists
85
- """
86
- if dataset_type in dataset_split[dataset_name]:
87
- return True
88
- else:
89
- return False
90
-
91
-
92
- def get_tar_path_from_dataset_name(
93
- dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
94
- ):
95
- """
96
- Get tar path from dataset name and type
97
- """
98
- output = []
99
- for n in dataset_names:
100
- if full_dataset is not None and n in full_dataset:
101
- current_dataset_types = dataset_split[n]
102
- else:
103
- current_dataset_types = dataset_types
104
- for s in current_dataset_types:
105
- tmp = []
106
- if islocal:
107
- sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
108
- if not os.path.exists(sizefilepath_):
109
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
110
- else:
111
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
112
- if not os.path.exists(sizefilepath_):
113
- continue
114
- sizes = json.load(open(sizefilepath_, "r"))
115
- for k in sizes.keys():
116
- if islocal:
117
- tmp.append(f"{dataset_path}/{n}/{s}/{k}")
118
- else:
119
- tmp.append(
120
- f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
121
- )
122
- if proportion != 1:
123
- tmp = random.sample(tmp, int(proportion * len(tmp)))
124
- output.append(tmp)
125
- return sum(output, [])
126
-
127
-
128
- def get_tar_path_from_txts(txt_path, islocal, proportion=1):
129
- """
130
- Get tar path from txt path
131
- """
132
- if isinstance(txt_path, (list, tuple)):
133
- return sum(
134
- [
135
- get_tar_path_from_txts(
136
- txt_path[i], islocal=islocal, proportion=proportion
137
- )
138
- for i in range(len(txt_path))
139
- ],
140
- [],
141
- )
142
- if isinstance(txt_path, str):
143
- with open(txt_path) as f:
144
- lines = f.readlines()
145
- if islocal:
146
- lines = [
147
- lines[i]
148
- .split("\n")[0]
149
- .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
150
- for i in range(len(lines))
151
- ]
152
- else:
153
- lines = [
154
- lines[i].split("\n")[0].replace(".tar", ".tar -")
155
- for i in range(len(lines))
156
- ]
157
- if proportion != 1:
158
- print("Sampling tars with proportion of {}".format(proportion))
159
- lines = random.sample(lines, int(proportion * len(lines)))
160
- return lines
161
-
162
-
163
- def get_mix_lambda(mixup_alpha, batch_size):
164
- mixup_lambdas = [
165
- np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
166
- ]
167
- return np.array(mixup_lambdas).astype(np.float32)
168
-
169
-
170
- def do_mixup(x, mixup_lambda):
171
- """
172
- Args:
173
- x: (batch_size , ...)
174
- mixup_lambda: (batch_size,)
175
- Returns:
176
- out: (batch_size, ...)
177
- """
178
- out = (
179
- x.transpose(0, -1) * mixup_lambda
180
- + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
181
- ).transpose(0, -1)
182
- return out
183
-
184
-
185
- def interpolate(x, ratio):
186
- """Interpolate data in time domain. This is used to compensate the
187
- resolution reduction in downsampling of a CNN.
188
-
189
- Args:
190
- x: (batch_size, time_steps, classes_num)
191
- ratio: int, ratio to interpolate
192
- Returns:
193
- upsampled: (batch_size, time_steps * ratio, classes_num)
194
- """
195
- (batch_size, time_steps, classes_num) = x.shape
196
- upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
197
- upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
198
- return upsampled
199
-
200
-
201
- def pad_framewise_output(framewise_output, frames_num):
202
- """Pad framewise_output to the same length as input frames. The pad value
203
- is the same as the value of the last frame.
204
- Args:
205
- framewise_output: (batch_size, frames_num, classes_num)
206
- frames_num: int, number of frames to pad
207
- Outputs:
208
- output: (batch_size, frames_num, classes_num)
209
- """
210
- pad = framewise_output[:, -1:, :].repeat(
211
- 1, frames_num - framewise_output.shape[1], 1
212
- )
213
- """tensor for padding"""
214
-
215
- output = torch.cat((framewise_output, pad), dim=1)
216
- """(batch_size, frames_num, classes_num)"""
217
-
218
-
219
- # def process_ipc(index_path, classes_num, filename):
220
- # # load data
221
- # logging.info("Load Data...............")
222
- # ipc = [[] for _ in range(classes_num)]
223
- # with h5py.File(index_path, "r") as f:
224
- # for i in tqdm(range(len(f["target"]))):
225
- # t_class = np.where(f["target"][i])[0]
226
- # for t in t_class:
227
- # ipc[t].append(i)
228
- # print(ipc)
229
- # np.save(filename, ipc)
230
- # logging.info("Load Data Succeed...............")
231
-
232
-
233
- def save_to_dict(s, o_={}):
234
- sp = s.split(": ")
235
- o_.update({sp[0]: float(sp[1])})
236
- return o_
237
-
238
-
239
- def get_data_from_log(txt_path):
240
- """
241
- Output dictionary from out.txt log file
242
- """
243
- with open(txt_path) as f:
244
- lines = f.readlines()
245
- val_data = {}
246
- train_data = {}
247
- train_losses = []
248
- train_losses_epoch = []
249
- for i in range(len(lines)):
250
- if "| INFO |" in lines[i]:
251
- if "Eval Epoch" in lines[i]:
252
- if "val_loss" in lines[i]:
253
- # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
254
- line = lines[i].split("Eval Epoch: ")[-1]
255
- num_epoch = int(line.split(" ")[0].split(" ")[0])
256
- d = {
257
- line.split(" ")[0]
258
- .split(" ")[1]
259
- .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
260
- }
261
- for i in range(1, len(line.split(" "))):
262
- d = save_to_dict(line.split(" ")[i], d)
263
- val_data[num_epoch] = d
264
- elif "Train Epoch" in lines[i]:
265
- num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
266
- loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
267
- train_losses.append(loss)
268
- train_losses_epoch.append(num_epoch)
269
- for i in range(len(train_losses)):
270
- train_data[i] = {
271
- "num_epoch": train_losses_epoch[i],
272
- "train_loss": train_losses[i],
273
- }
274
- return train_data, val_data
275
-
276
-
277
- def save_p(obj, filename):
278
- import pickle
279
-
280
- try:
281
- from deepdiff import DeepDiff
282
- except:
283
- os.system("pip install deepdiff")
284
- from deepdiff import DeepDiff
285
- with open(filename, "wb") as file:
286
- pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
287
- with open(filename, "rb") as file:
288
- z = pickle.load(file)
289
- assert (
290
- DeepDiff(obj, z, ignore_string_case=True) == {}
291
- ), "there is something wrong with the saving process"
292
- return
293
-
294
-
295
- def load_p(filename):
296
- import pickle
297
-
298
- with open(filename, "rb") as file:
299
- z = pickle.load(file)
300
- return z
301
-
302
-
303
- def save_json(data, name="data.json"):
304
- import json
305
-
306
- with open(name, "w") as fp:
307
- json.dump(data, fp)
308
- return
309
-
310
-
311
- def load_json(name):
312
- import json
313
-
314
- with open(name, "r") as fp:
315
- data = json.load(fp)
316
- return data
317
-
318
-
319
- def load_class_label(path):
320
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
321
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
322
- out = None
323
- if path is not None:
324
- if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
325
- out = load_p(path)
326
- elif pathlib.Path(path).suffix in [".json", ".txt"]:
327
- out = load_json(path)
328
- elif pathlib.Path(path).suffix in [".npy", ".npz"]:
329
- out = np.load(path)
330
- elif pathlib.Path(path).suffix in [".csv"]:
331
- import pandas as pd
332
-
333
- out = pd.read_csv(path)
334
- return out
335
- # if out is None:
336
- # return None
337
- # else:
338
- # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
339
- # val = Array('i', out.values(), lock=False)
340
- # return (key, val)
341
-
342
-
343
- from torch import optim
344
-
345
-
346
- def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
347
- if optimizer_name.lower() == "adamw":
348
- optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
349
- elif optimizer_name.lower() == "sgd":
350
- optimizer = optim.SGD(params, lr=lr, momentum=momentum)
351
- elif optimizer_name.lower() == "adam":
352
- optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
353
- else:
354
- raise ValueError("optimizer name is not correct")
355
- return optimizer
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torchvision.ops.misc import FrozenBatchNorm2d
5
+ import logging
6
+ from tqdm import tqdm
7
+ import random
8
+ import json
9
+ import os
10
+ import pathlib
11
+
12
+ # TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
13
+ dataset_split = {
14
+ "audiocaps": ["train", "valid", "test"],
15
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
16
+ "BBCSoundEffects": ["train", "test"],
17
+ "Clotho": ["train", "test", "valid"],
18
+ "free_to_use_sounds": ["train", "test"],
19
+ "paramount_motion": ["train", "test"],
20
+ "sonniss_game_effects": ["train", "test"],
21
+ "wesoundeffects": ["train", "test"],
22
+ "MACS": ["train", "test"],
23
+ "freesound": ["train", "test"],
24
+ "FSD50K": ["train", "test", "valid"],
25
+ "fsd50k_class_label": ["train", "test", "valid"],
26
+ "esc50": ["train", "test"],
27
+ "audiostock": ["train", "test"],
28
+ "freesound_no_overlap_noesc50": ["train", "test"],
29
+ "epidemic_sound_effects": ["train", "test"],
30
+ "VGGSound": ["train", "test"],
31
+ "urbansound8k_class_label": ["train", "test"],
32
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
33
+ "epidemic_sound_effects_t5": ["train", "test"],
34
+ "WavText5K": ["train", "test"],
35
+ "esc50_no_overlap": ["train", "test"],
36
+ "usd8k_no_overlap": ["train", "test"],
37
+ "fsd50k_200_class_label": ["train", "test", "valid"],
38
+ }
39
+
40
+
41
+ def freeze_batch_norm_2d(module, module_match={}, name=""):
42
+ """
43
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of the provided module into `FrozenBatchNorm2d`. If `module` is
44
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
45
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
46
+
47
+ Args:
48
+ module (torch.nn.Module): Any PyTorch module.
49
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
50
+ name (str): Full module name (prefix)
51
+
52
+ Returns:
53
+ torch.nn.Module: Resulting module
54
+
55
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
56
+ """
57
+ res = module
58
+ is_match = True
59
+ if module_match:
60
+ is_match = name in module_match
61
+ if is_match and isinstance(
62
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
63
+ ):
64
+ res = FrozenBatchNorm2d(module.num_features)
65
+ res.num_features = module.num_features
66
+ res.affine = module.affine
67
+ if module.affine:
68
+ res.weight.data = module.weight.data.clone().detach()
69
+ res.bias.data = module.bias.data.clone().detach()
70
+ res.running_mean.data = module.running_mean.data
71
+ res.running_var.data = module.running_var.data
72
+ res.eps = module.eps
73
+ else:
74
+ for child_name, child in module.named_children():
75
+ full_child_name = ".".join([name, child_name]) if name else child_name
76
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
77
+ if new_child is not child:
78
+ res.add_module(child_name, new_child)
79
+ return res
80
+
81
+
82
+ def exist(dataset_name, dataset_type):
83
+ """
84
+ Check whether the given split exists for the given dataset.
85
+ """
86
+ if dataset_type in dataset_split[dataset_name]:
87
+ return True
88
+ else:
89
+ return False
90
+
91
+
92
+ def get_tar_path_from_dataset_name(
93
+ dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
94
+ ):
95
+ """
96
+ Get tar path from dataset name and type
97
+ """
98
+ output = []
99
+ for n in dataset_names:
100
+ if full_dataset is not None and n in full_dataset:
101
+ current_dataset_types = dataset_split[n]
102
+ else:
103
+ current_dataset_types = dataset_types
104
+ for s in current_dataset_types:
105
+ tmp = []
106
+ if islocal:
107
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
108
+ if not os.path.exists(sizefilepath_):
109
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
110
+ else:
111
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
112
+ if not os.path.exists(sizefilepath_):
113
+ continue
114
+ sizes = json.load(open(sizefilepath_, "r"))
115
+ for k in sizes.keys():
116
+ if islocal:
117
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
118
+ else:
119
+ tmp.append(
120
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
121
+ )
122
+ if proportion != 1:
123
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
124
+ output.append(tmp)
125
+ return sum(output, [])
126
+
127
+
128
+ def get_tar_path_from_txts(txt_path, islocal, proportion=1):
129
+ """
130
+ Get tar path from txt path
131
+ """
132
+ if isinstance(txt_path, (list, tuple)):
133
+ return sum(
134
+ [
135
+ get_tar_path_from_txts(
136
+ txt_path[i], islocal=islocal, proportion=proportion
137
+ )
138
+ for i in range(len(txt_path))
139
+ ],
140
+ [],
141
+ )
142
+ if isinstance(txt_path, str):
143
+ with open(txt_path) as f:
144
+ lines = f.readlines()
145
+ if islocal:
146
+ lines = [
147
+ lines[i]
148
+ .split("\n")[0]
149
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
150
+ for i in range(len(lines))
151
+ ]
152
+ else:
153
+ lines = [
154
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
155
+ for i in range(len(lines))
156
+ ]
157
+ if proportion != 1:
158
+ print("Sampling tars with proportion of {}".format(proportion))
159
+ lines = random.sample(lines, int(proportion * len(lines)))
160
+ return lines
161
+
162
+
163
+ def get_mix_lambda(mixup_alpha, batch_size):
164
+ mixup_lambdas = [
165
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
166
+ ]
167
+ return np.array(mixup_lambdas).astype(np.float32)
168
+
169
+
170
+ def do_mixup(x, mixup_lambda):
171
+ """
172
+ Args:
173
+ x: (batch_size , ...)
174
+ mixup_lambda: (batch_size,)
175
+ Returns:
176
+ out: (batch_size, ...)
177
+ """
178
+ out = (
179
+ x.transpose(0, -1) * mixup_lambda
180
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
181
+ ).transpose(0, -1)
182
+ return out
183
+
184
+
185
+ def interpolate(x, ratio):
186
+ """Interpolate data in time domain. This is used to compensate the
187
+ resolution reduction in downsampling of a CNN.
188
+
189
+ Args:
190
+ x: (batch_size, time_steps, classes_num)
191
+ ratio: int, ratio to interpolate
192
+ Returns:
193
+ upsampled: (batch_size, time_steps * ratio, classes_num)
194
+ """
195
+ (batch_size, time_steps, classes_num) = x.shape
196
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
197
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
198
+ return upsampled
199
+
200
+
201
+ def pad_framewise_output(framewise_output, frames_num):
202
+ """Pad framewise_output to the same length as input frames. The pad value
203
+ is the same as the value of the last frame.
204
+ Args:
205
+ framewise_output: (batch_size, frames_num, classes_num)
206
+ frames_num: int, number of frames to pad
207
+ Outputs:
208
+ output: (batch_size, frames_num, classes_num)
209
+ """
210
+ pad = framewise_output[:, -1:, :].repeat(
211
+ 1, frames_num - framewise_output.shape[1], 1
212
+ )
213
+ """tensor for padding"""
214
+
215
+ output = torch.cat((framewise_output, pad), dim=1)
216
+ """(batch_size, frames_num, classes_num)"""
217
+
218
+
219
+ # def process_ipc(index_path, classes_num, filename):
220
+ # # load data
221
+ # logging.info("Load Data...............")
222
+ # ipc = [[] for _ in range(classes_num)]
223
+ # with h5py.File(index_path, "r") as f:
224
+ # for i in tqdm(range(len(f["target"]))):
225
+ # t_class = np.where(f["target"][i])[0]
226
+ # for t in t_class:
227
+ # ipc[t].append(i)
228
+ # print(ipc)
229
+ # np.save(filename, ipc)
230
+ # logging.info("Load Data Succeed...............")
231
+
232
+
233
+ def save_to_dict(s, o_={}):
234
+ sp = s.split(": ")
235
+ o_.update({sp[0]: float(sp[1])})
236
+ return o_
237
+
238
+
239
+ def get_data_from_log(txt_path):
240
+ """
241
+ Output dictionary from out.txt log file
242
+ """
243
+ with open(txt_path) as f:
244
+ lines = f.readlines()
245
+ val_data = {}
246
+ train_data = {}
247
+ train_losses = []
248
+ train_losses_epoch = []
249
+ for i in range(len(lines)):
250
+ if "| INFO |" in lines[i]:
251
+ if "Eval Epoch" in lines[i]:
252
+ if "val_loss" in lines[i]:
253
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
254
+ line = lines[i].split("Eval Epoch: ")[-1]
255
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
256
+ d = {
257
+ line.split(" ")[0]
258
+ .split(" ")[1]
259
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
260
+ }
261
+ for i in range(1, len(line.split(" "))):
262
+ d = save_to_dict(line.split(" ")[i], d)
263
+ val_data[num_epoch] = d
264
+ elif "Train Epoch" in lines[i]:
265
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
266
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
267
+ train_losses.append(loss)
268
+ train_losses_epoch.append(num_epoch)
269
+ for i in range(len(train_losses)):
270
+ train_data[i] = {
271
+ "num_epoch": train_losses_epoch[i],
272
+ "train_loss": train_losses[i],
273
+ }
274
+ return train_data, val_data
275
+
276
+
277
+ def save_p(obj, filename):
278
+ import pickle
279
+
280
+ try:
281
+ from deepdiff import DeepDiff
282
+ except ImportError:
283
+ os.system("pip install deepdiff")
284
+ from deepdiff import DeepDiff
285
+ with open(filename, "wb") as file:
286
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
287
+ with open(filename, "rb") as file:
288
+ z = pickle.load(file)
289
+ assert (
290
+ DeepDiff(obj, z, ignore_string_case=True) == {}
291
+ ), "there is something wrong with the saving process"
292
+ return
293
+
294
+
295
+ def load_p(filename):
296
+ import pickle
297
+
298
+ with open(filename, "rb") as file:
299
+ z = pickle.load(file)
300
+ return z
301
+
302
+
303
+ def save_json(data, name="data.json"):
304
+ import json
305
+
306
+ with open(name, "w") as fp:
307
+ json.dump(data, fp)
308
+ return
309
+
310
+
311
+ def load_json(name):
312
+ import json
313
+
314
+ with open(name, "r") as fp:
315
+ data = json.load(fp)
316
+ return data
317
+
318
+
319
+ def load_class_label(path):
320
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
321
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
322
+ out = None
323
+ if path is not None:
324
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
325
+ out = load_p(path)
326
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
327
+ out = load_json(path)
328
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
329
+ out = np.load(path)
330
+ elif pathlib.Path(path).suffix in [".csv"]:
331
+ import pandas as pd
332
+
333
+ out = pd.read_csv(path)
334
+ return out
335
+ # if out is None:
336
+ # return None
337
+ # else:
338
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
339
+ # val = Array('i', out.values(), lock=False)
340
+ # return (key, val)
341
+
342
+
343
+ from torch import optim
344
+
345
+
346
+ def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
347
+ if optimizer_name.lower() == "adamw":
348
+ optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
349
+ elif optimizer_name.lower() == "sgd":
350
+ optimizer = optim.SGD(params, lr=lr, momentum=momentum)
351
+ elif optimizer_name.lower() == "adam":
352
+ optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
353
+ else:
354
+ raise ValueError("optimizer name is not correct")
355
+ return optimizer
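A small sketch exercising `get_mix_lambda`/`do_mixup` and `get_optimizer` from the utilities above; tensor shapes and hyperparameters are illustrative only:

# Sketch: mixup a toy batch and build an AdamW optimizer via the helpers above.
import torch
from audiosr.clap.open_clip.utils import get_mix_lambda, do_mixup, get_optimizer

x = torch.randn(4, 64, 1001)                          # (batch, mel_bins, frames), made-up shape
lam = torch.from_numpy(get_mix_lambda(0.5, 4))        # one Beta(0.5, 0.5) weight per batch item
mixed = do_mixup(x, lam)                              # each item blended with the reversed batch
print(mixed.shape)                                    # torch.Size([4, 64, 1001])

opt = get_optimizer(torch.nn.Linear(8, 2).parameters(),
                    lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
                    momentum=0.9, optimizer_name="adamw")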
audiosr/clap/training/data.py CHANGED
@@ -1,865 +1,865 @@
1
- import json
2
- import logging
3
- import os
4
- import random
5
- from dataclasses import dataclass
6
- import numpy as np
7
- import pandas as pd
8
- import torch
9
- import torchvision.datasets as datasets
10
- from PIL import Image
11
- from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
12
- from torch.utils.data.distributed import DistributedSampler
13
- import soundfile as sf
14
- import io
15
- from pathlib import Path
16
-
17
- # import wget
18
-
19
- from audiosr.clap.open_clip.utils import get_tar_path_from_dataset_name
20
- from audiosr.clap.open_clip.utils import load_class_label
21
-
22
- try:
23
- import horovod.torch as hvd
24
- except ImportError:
25
- hvd = None
26
-
27
- try:
28
- import torchaudio
29
- except ImportError:
30
- torchaudio = None
31
-
32
- from audiosr.clap.open_clip import tokenize
33
-
34
-
35
- def tokenizer(text):
36
- return tokenize(text).squeeze(0)
37
-
38
-
39
- from transformers import RobertaTokenizer
40
-
41
- tokenize = RobertaTokenizer.from_pretrained("roberta-base")
42
-
43
-
44
- def tokenizer(text):
45
- result = tokenize(
46
- text,
47
- padding="max_length",
48
- truncation=True,
49
- max_length=77,
50
- return_tensors="pt",
51
- )
52
- return {k: v.squeeze(0) for k, v in result.items()}
53
-
54
-
55
- # initialize the audioset map
56
- _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
57
- _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
58
-
59
-
60
- def int16_to_float32(x):
61
- return (x / 32767.0).astype(np.float32)
62
-
63
-
64
- def float32_to_int16(x):
65
- x = np.clip(x, a_min=-1.0, a_max=1.0)
66
- return (x * 32767.0).astype(np.int16)
67
-
68
-
69
- # For Toy Dataset
70
- # class ToyDataset(Dataset):
71
- # def __init__(self, index_path, ipc, config, eval_mode=False):
72
- # """Toy Dataset for testing the audioset input with text labels
73
- # Parameters
74
- # ----------
75
- # index_path: str
76
- # the link to the h5 file of each audio
77
- # idc: str
78
- # the link to the npy file, the number of samples in each class
79
- # config: dict
80
- # the audio cfg file
81
- # eval_model (bool): to indicate if the dataset is a testing dataset
82
- # """
83
- # self.audio_cfg = config["audio_cfg"]
84
- # self.text_cfg = config["text_cfg"]
85
- # self.fp = h5py.File(index_path, "r")
86
- # self.ipc = np.load(ipc, allow_pickle=True)
87
- # self.total_size = len(self.fp["audio_name"])
88
- # self.classes_num = self.audio_cfg["class_num"]
89
- # self.eval_mode = eval_mode
90
-
91
- # if not eval_mode:
92
- # self.generate_queue()
93
- # else:
94
- # self.queue = []
95
- # for i in range(self.total_size):
96
- # target = self.fp["target"][i]
97
- # if np.sum(target) > 0:
98
- # self.queue.append(i)
99
- # self.total_size = len(self.queue)
100
- # logging.info("total dataset size: %d" % (self.total_size))
101
- # logging.info("class num: %d" % (self.classes_num))
102
-
103
- # def time_shifting(self, x):
104
- # frame_num = len(x)
105
- # shift_len = random.randint(0, frame_num - 1)
106
- # new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
107
- # return new_sample
108
-
109
- # def generate_queue(self):
110
- # self.queue = []
111
- # while len(self.queue) < self.total_size:
112
- # class_set = [*range(self.classes_num)]
113
- # random.shuffle(class_set)
114
- # self.queue += [
115
- # self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
116
- # ]
117
- # self.queue = self.queue[: self.total_size]
118
-
119
- # logging.info("queue regenerated:%s" % (self.queue[-5:]))
120
-
121
- # def crop_wav(self, x):
122
- # crop_size = self.audio_cfg["crop_size"]
123
- # crop_pos = random.randint(0, len(x) - crop_size - 1)
124
- # return x[crop_pos : crop_pos + crop_size]
125
-
126
- # def prompt_text(self, target):
127
- # events = _AUDIOSET_MAP[np.where(target > 0)]
128
- # event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
129
- # text = tokenize(event_text)[0]
130
- # return text
131
-
132
- # def __getitem__(self, index):
133
- # """Load waveform, text, and target of an audio clip
134
-
135
- # Parameters
136
- # ----------
137
- # index: int
138
- # the index number
139
- # Return
140
- # ------
141
- # output: dict {
142
- # "hdf5_path": str,
143
- # "index_in_hdf5": int,
144
- # "audio_name": str,
145
- # "waveform": list (audio_length,),
146
- # "target": list (class_num, ),
147
- # "text": torch.tensor (context_length,)
148
- # }
149
- # the output dictionary
150
- # """
151
- # s_index = self.queue[index]
152
-
153
- # audio_name = self.fp["audio_name"][s_index].decode()
154
- # # Hardcode here CHANGE
155
- # hdf5_path = (
156
- # self.fp["hdf5_path"][s_index]
157
- # .decode()
158
- # .replace(
159
- # "../workspace",
160
- # "/home/la/kechen/Research/ke_zsasp/workspace",
161
- # )
162
- # )
163
- # r_idx = self.fp["index_in_hdf5"][s_index]
164
- # target = self.fp["target"][s_index].astype(np.float32)
165
- # text = self.prompt_text(target)
166
- # with h5py.File(hdf5_path, "r") as f:
167
- # waveform = int16_to_float32(f["waveform"][r_idx])[
168
- # : self.audio_cfg["clip_samples"]
169
- # ]
170
- # assert (
171
- # len(waveform) == self.audio_cfg["clip_samples"]
172
- # ), "The sample length is not match"
173
- # # Time shift
174
- # # if (self.config.enable_time_shift) and (not self.eval_mode):
175
- # # waveform = self.time_shifting(waveform)
176
- # # # Label Enhance
177
- # # if (self.config.crop_size is not None) and (not self.eval_mode):
178
- # # waveform = self.crop_wav(waveform)
179
- # # # the label enhance rate is fixed 0.5
180
- # # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
181
- # # kidx = np.where(target)[0]
182
- # # for k in kidx:
183
- # # for add_key in self.class_map[k][1]:
184
- # # target[add_key] = 1.0
185
- # # if len(self.class_map[k][2]) > 0:
186
- # # add_key = random.choice(self.class_map[k][2])
187
- # # target[add_key] = 1.0
188
-
189
- # # missing the text input
190
- # mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
191
- # mel_spec = (
192
- # torch.cat(
193
- # [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
194
- # )
195
- # .cpu()
196
- # .numpy()
197
- # )
198
- # longer = random.choice([True, False])
199
- # if longer == False:
200
- # mel_spec[1:, :, :] = 0.0
201
- # data_dict = {
202
- # "hdf5_path": hdf5_path,
203
- # "index_in_hdf5": r_idx,
204
- # "audio_name": audio_name,
205
- # "waveform": waveform,
206
- # "class_label": target,
207
- # "text": text,
208
- # "longer": longer,
209
- # "mel_fusion": mel_spec,
210
- # }
211
- # return data_dict
212
-
213
- # def __len__(self):
214
- # return self.total_size
215
-
216
-
217
- class CsvDataset(Dataset):
218
- def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
219
- logging.debug(f"Loading csv data from {input_filename}.")
220
- df = pd.read_csv(input_filename, sep=sep)
221
-
222
- self.images = df[img_key].tolist()
223
- self.captions = df[caption_key].tolist()
224
- self.transforms = transforms
225
- logging.debug("Done loading data.")
226
-
227
- def __len__(self):
228
- return len(self.captions)
229
-
230
- def __getitem__(self, idx):
231
- images = self.transforms(Image.open(str(self.images[idx])))
232
- texts = tokenize([str(self.captions[idx])])[0]
233
- return images, texts
234
-
235
-
236
- @dataclass
237
- class DataInfo:
238
- dataloader: DataLoader
239
- sampler: DistributedSampler
240
-
241
-
242
- def preprocess_txt(text):
243
- return tokenize([str(text)])[0]
244
-
245
-
246
- # def get_dataset_size(shards, sizefilepath_=None, is_local=True):
247
- # if isinstance(shards, list):
248
- # size_list = []
249
- # for s in shards:
250
- # size_list.append(
251
- # get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
252
- # )
253
- # else:
254
- # if not is_local:
255
- # for n in dataset_split.keys():
256
- # if n in shards.split("/"):
257
- # break
258
- # for s in dataset_split[n]:
259
- # if s in shards.split("/"):
260
- # break
261
- # sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
262
- # shards_list = list(braceexpand.braceexpand(shards))
263
- # dir_path = os.path.dirname(shards)
264
- # if sizefilepath_ is not None:
265
- # sizes = json.load(open(sizefilepath_, "r"))
266
- # total_size = sum(
267
- # [
268
- # int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
269
- # for shard in shards_list
270
- # ]
271
- # )
272
- # else:
273
- # sizes_filename = os.path.join(dir_path, "sizes.json")
274
- # len_filename = os.path.join(dir_path, "__len__")
275
- # if os.path.exists(sizes_filename):
276
- # sizes = json.load(open(sizes_filename, "r"))
277
- # total_size = sum(
278
- # [int(sizes[os.path.basename(shard)]) for shard in shards_list]
279
- # )
280
- # elif os.path.exists(len_filename):
281
- # # FIXME this used to be eval(open(...)) but that seemed rather unsafe
282
- # total_size = ast.literal_eval(open(len_filename, "r").read())
283
- # else:
284
- # raise Exception(
285
- # "Cannot find sizes file for dataset. Please specify the path to the file."
286
- # )
287
- # # total_size = None # num samples undefined
288
- # # some common dataset sizes (at time of authors last download)
289
- # # cc3m-train: 2905954
290
- # # cc12m: 10968539
291
- # # LAION-400m: 407332084
292
- # num_shards = len(shards_list)
293
- # if isinstance(shards, list):
294
- # return sum(size_list), len(shards)
295
- # else:
296
- # return total_size, num_shards
297
-
298
-
299
- def get_imagenet(args, preprocess_fns, split):
300
- assert split in ["train", "val", "v2"]
301
- is_train = split == "train"
302
- preprocess_train, preprocess_val = preprocess_fns
303
-
304
- if split == "v2":
305
- from imagenetv2_pytorch import ImageNetV2Dataset
306
-
307
- dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
308
- else:
309
- if is_train:
310
- data_path = args.imagenet_train
311
- preprocess_fn = preprocess_train
312
- else:
313
- data_path = args.imagenet_val
314
- preprocess_fn = preprocess_val
315
- assert data_path
316
-
317
- dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
318
-
319
- if is_train:
320
- idxs = np.zeros(len(dataset.targets))
321
- target_array = np.array(dataset.targets)
322
- k = 50
323
- for c in range(1000):
324
- m = target_array == c
325
- n = len(idxs[m])
326
- arr = np.zeros(n)
327
- arr[:k] = 1
328
- np.random.shuffle(arr)
329
- idxs[m] = arr
330
-
331
- idxs = idxs.astype("int")
332
- sampler = SubsetRandomSampler(np.where(idxs)[0])
333
- else:
334
- sampler = None
335
-
336
- dataloader = torch.utils.data.DataLoader(
337
- dataset,
338
- batch_size=args.batch_size,
339
- num_workers=args.workers,
340
- sampler=sampler,
341
- )
342
-
343
- return DataInfo(dataloader, sampler)
344
-
345
-
346
- def count_samples(dataloader):
347
- os.environ["WDS_EPOCH"] = "0"
348
- n_elements, n_batches = 0, 0
349
- for images, texts in dataloader:
350
- n_batches += 1
351
- n_elements += len(images)
352
- assert len(images) == len(texts)
353
- return n_elements, n_batches
354
-
355
-
356
- def filter_no_caption(sample):
357
- return "txt" in sample
358
-
359
-
360
- def log_and_continue(exn):
361
- """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
362
- logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
363
- return True
364
-
365
-
366
- _SHARD_SHUFFLE_SIZE = 2000
367
- _SHARD_SHUFFLE_INITIAL = 500
368
- _SAMPLE_SHUFFLE_SIZE = 5000
369
- _SAMPLE_SHUFFLE_INITIAL = 1000
370
-
371
-
372
- # def sample_prop(sizefile, inputs, proportion, is_local=True):
373
- # """
374
- # Sample a proportion of the data.
375
- # """
376
- # file_path_dict = {
377
- # os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
378
- # for i in range(len(inputs))
379
- # }
380
- # sampled_filepath_dict = {}
381
- # sampled_size_dict = {}
382
- # if not is_local:
383
- # if os.path.exists("sizes.json"):
384
- # os.remove("sizes.json")
385
- # wget.download(sizefile, "sizes.json")
386
- # sizefile = "sizes.json"
387
- # with open(sizefile, "r", encoding="UTF-8") as f:
388
- # load_dict = json.load(f)
389
- # L = int(len(file_path_dict) * proportion)
390
- # subkeys = random.sample(file_path_dict.keys(), L)
391
- # for k in subkeys:
392
- # sampled_size_dict[k] = load_dict[k]
393
- # sampled_filepath_dict[k] = file_path_dict[k]
394
- # return (
395
- # sum(sampled_size_dict.values()),
396
- # L,
397
- # [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
398
- # sampled_size_dict,
399
- # )
400
-
401
-
402
- def get_mel(audio_data, audio_cfg):
403
- # mel shape: (n_mels, T)
404
- mel = torchaudio.transforms.MelSpectrogram(
405
- sample_rate=audio_cfg["sample_rate"],
406
- n_fft=audio_cfg["window_size"],
407
- win_length=audio_cfg["window_size"],
408
- hop_length=audio_cfg["hop_size"],
409
- center=True,
410
- pad_mode="reflect",
411
- power=2.0,
412
- norm=None,
413
- onesided=True,
414
- n_mels=64,
415
- f_min=audio_cfg["fmin"],
416
- f_max=audio_cfg["fmax"],
417
- ).to(audio_data.device)
418
- mel = mel(audio_data)
419
- # we use log mel spectrogram as input
420
- mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
421
- return mel.T # (T, n_mels)
422
-
423
-
424
- def get_audio_features(
425
- audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
426
- ):
427
- """
428
- Build and return a sample dict of audio features for an audio clip.
429
- audio_data: a tensor of shape (T,) containing the waveform.
430
- mel: the mel spectrogram of the waveform (see get_mel), of shape (T, n_mels).
431
- max_len: the maximum length of audio data.
432
- data_truncating: the method of truncating data.
433
- data_filling: the method of filling data.
434
- audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
435
- """
436
- sample = {}
437
-
438
- # assert audio_data.size(-1) <= max_len, str(audio_data.size())
439
-
440
- # split to three parts
441
- chunk_frames = (
442
- max_len // audio_cfg["hop_size"] + 1
443
- ) # the +1 related to how the spectrogram is computed
444
- mel = mel[:chunk_frames]
445
-
446
- audio_data = audio_data[..., :max_len]
447
- sample["mel_fusion"] = mel
448
- longer = torch.tensor([True])
449
-
450
- sample["longer"] = longer
451
- sample["waveform"] = audio_data
452
-
453
- return sample
454
-
455
-
456
- def preprocess(
457
- sample,
458
- audio_ext,
459
- text_ext,
460
- max_len,
461
- audio_cfg,
462
- class_index_dict=None,
463
- data_filling="pad",
464
- data_truncating="rand_trunc",
465
- text_augment_selection=None,
466
- ):
467
- """
468
- Preprocess a single sample for wdsdataloader.
469
- """
470
- audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
471
- audio_data = int16_to_float32(float32_to_int16(audio_data))
472
- audio_data = torch.tensor(audio_data).float()
473
-
474
- # TODO: (yusong) to be include in the future
475
- # # if torchaudio not installed, use soundfile to load audio
476
- # if torchaudio is None:
477
- # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
478
- # audio_data = torch.tensor(audio_data).float()
479
- # else:
480
- # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
481
- # with tempfile.TemporaryDirectory() as dirname:
482
- # os.makedirs(dirname, exist_ok=True)
483
- # fname = os.path.join(dirname, f"file.flac")
484
- # with open(fname, "wb") as stream:
485
- # stream.write(sample[audio_ext])
486
- # audio_data, orig_sr = torchaudio.load(fname)
487
- # audio_data = audio_data[0, :].float()
488
-
489
- sample = get_audio_features(
490
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
491
- )
492
- del sample[audio_ext]
493
-
494
- try:
495
- json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
496
- except Exception:
497
- print("sample[__url__]:", sample["__url__"])
- raise
498
-
499
- # For selecting augmented text from dataset
500
- if text_augment_selection is None or text_augment_selection == "none":
501
- texts = json_dict_raw["text"]
502
- elif text_augment_selection == "all":
503
- if "text_augment_all" in json_dict_raw.keys():
504
- texts = json_dict_raw["text_augment_all"]
505
- else:
506
- texts = json_dict_raw["text"]
507
- elif text_augment_selection == "augment_only":
508
- if "text_augment_all" in json_dict_raw.keys():
509
- if json_dict_raw["text_augment_t5"] is None:
510
- texts = json_dict_raw["text"]
511
- else:
512
- texts = json_dict_raw["text_augment_t5"]
513
- else:
514
- texts = json_dict_raw["text"]
515
- else:
516
- raise NotImplementedError(
517
- f"text_augment_selection {text_augment_selection} not implemented"
518
- )
519
- sample["full_text"] = texts
520
-
521
- if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
522
- texts = random.choice(texts)
523
- sample["raw_text"] = texts
524
- sample["text"] = tokenizer(texts) # text shape: [num_token]
525
- if class_index_dict is not None:
526
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
527
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
528
- # key, val = class_index_dict
529
- # key = key[:].split('\n')
530
- # _dict = {k: v for k, v in zip(key, val)}
531
- sample["class_label"] = np.zeros(len(class_index_dict.keys()))
532
- for x in json_dict_raw["tag"]:
533
- sample["class_label"][class_index_dict[x]] = 1
534
- sample["class_label"] = torch.tensor(sample["class_label"]).float()
535
- del sample[text_ext]
536
- sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
537
- sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
538
- sample["audio_orig_sr"] = orig_sr
539
- return sample
540
-
541
-
542
- def collate_fn(batch):
543
- """
544
- Collate function for wdsdataloader.
545
- batch: a list of dict, each dict is a sample
546
- """
547
- # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
548
- batch_dict = {}
549
- for k in batch[0].keys():
550
- if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
551
- batch_dict[k] = {}
552
- for kk in batch[0][k].keys():
553
- tmp = []
554
- for i in range(len(batch)):
555
- tmp.append(batch[i][k][kk])
556
- batch_dict[k][kk] = torch.vstack(tmp)
557
- elif isinstance(batch[0][k], torch.Tensor):
558
- batch_dict[k] = torch.stack([sample[k] for sample in batch])
559
- elif isinstance(batch[0][k], np.ndarray):
560
- batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
561
- else:
562
- batch_dict[k] = [sample[k] for sample in batch]
563
- return batch_dict
564
-
565
-
566
- # def get_wds_dataset(
567
- # args,
568
- # model_cfg,
569
- # is_train,
570
- # audio_ext="flac",
571
- # text_ext="json",
572
- # max_len=480000,
573
- # proportion=1.0,
574
- # sizefilepath_=None,
575
- # is_local=None,
576
- # ):
577
- # """
578
- # Get a dataset for wdsdataloader.
579
- # """
580
- # if is_local is None and (not args.remotedata is None):
581
- # is_local = not args.remotedata
582
-
583
- # input_shards = args.train_data if is_train else args.val_data
584
- # assert input_shards is not None
585
-
586
- # if not sizefilepath_ is None:
587
- # sizefilepath = sizefilepath_
588
- # else:
589
- # sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
590
-
591
- # if proportion != 1.0:
592
- # num_samples, num_shards, input_shards, _ = sample_prop(
593
- # sizefilepath, input_shards, proportion, is_local=is_local
594
- # )
595
- # else:
596
- # num_samples, num_shards = get_dataset_size(
597
- # input_shards, sizefilepath_=sizefilepath_, is_local=is_local
598
- # )
599
-
600
- # if not num_samples:
601
- # if is_train:
602
- # num_samples = args.train_num_samples
603
- # if not num_samples:
604
- # raise RuntimeError(
605
- # "Currently, number of dataset samples must be specified for training dataset. "
606
- # "Please specify via `--train-num-samples` if no dataset length info present."
607
- # )
608
- # else:
609
- # num_samples = (
610
- # args.val_num_samples or 0
611
- # ) # eval will just exhaust the iterator if not specified
612
-
613
- # pipeline = [wds.SimpleShardList(input_shards)]
614
- # # at this point we have an iterator over all the shards
615
- # # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
616
- # if is_train or args.parallel_eval:
617
- # pipeline.extend(
618
- # [
619
- # wds.detshuffle(
620
- # bufsize=_SHARD_SHUFFLE_SIZE,
621
- # initial=_SHARD_SHUFFLE_INITIAL,
622
- # seed=args.seed,
623
- # ),
624
- # wds.split_by_node,
625
- # wds.split_by_worker,
626
- # # at this point, we have an iterator over the shards assigned to each worker at each node
627
- # wds.tarfile_to_samples(handler=log_and_continue),
628
- # wds.shuffle(
629
- # bufsize=_SAMPLE_SHUFFLE_SIZE,
630
- # initial=_SAMPLE_SHUFFLE_INITIAL,
631
- # rng=random.Random(args.seed),
632
- # ),
633
- # # wds.repeatedly, # FIXME determine if this is beneficial
634
- # ]
635
- # )
636
- # else:
637
- # pipeline.extend(
638
- # [
639
- # wds.split_by_worker,
640
- # # at this point, we have an iterator over the shards assigned to each worker
641
- # wds.tarfile_to_samples(handler=log_and_continue),
642
- # ]
643
- # )
644
- # pipeline.append(
645
- # wds.map(
646
- # partial(
647
- # preprocess,
648
- # audio_ext=audio_ext,
649
- # text_ext=text_ext,
650
- # max_len=max_len,
651
- # audio_cfg=model_cfg["audio_cfg"],
652
- # class_index_dict=copy.deepcopy(args.class_index_dict),
653
- # data_filling=args.data_filling,
654
- # data_truncating=args.data_truncating,
655
- # text_augment_selection=args.text_augment_selection,
656
- # )
657
- # ),
658
- # )
659
-
660
- # pipeline.append(
661
- # wds.batched(
662
- # args.batch_size,
663
- # partial=not (is_train or args.parallel_eval),
664
- # collation_fn=collate_fn,
665
- # )
666
- # )
667
-
668
- # dataset = wds.DataPipeline(*pipeline)
669
- # if is_train or args.parallel_eval:
670
- # # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
671
- # # (yusong): See comments below.
672
- # # roll over and repeat a few samples to get same number of full batches on each node
673
- # global_batch_size = args.batch_size * args.world_size
674
- # num_batches = math.ceil(num_samples / global_batch_size)
675
- # num_workers = max(1, args.workers)
676
- # num_worker_batches = math.ceil(
677
- # num_batches / num_workers
678
- # ) # per dataloader worker
679
- # num_batches = num_worker_batches * num_workers
680
- # num_samples = num_batches * global_batch_size
681
- # dataset = dataset.with_epoch(
682
- # num_worker_batches
683
- # ) # each worker is iterating over this
684
- # else:
685
- # # last batches are partial, eval is done on single (master) node
686
- # num_batches = math.ceil(num_samples / args.batch_size)
687
-
688
- # kwargs = {}
689
- # if args.horovod: # multi-node training on summit
690
- # kwargs["multiprocessing_context"] = "forkserver"
691
-
692
- # dataloader = wds.WebLoader(
693
- # dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
694
- # )
695
-
696
- # # FIXME not clear which approach is better, with_epoch before vs after dataloader?
697
- # # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
698
- # # if is_train:
699
- # # # roll over and repeat a few samples to get same number of full batches on each node
700
- # # global_batch_size = args.batch_size * args.world_size
701
- # # num_batches = math.ceil(num_samples / global_batch_size)
702
- # # num_workers = max(1, args.workers)
703
- # # num_batches = math.ceil(num_batches / num_workers) * num_workers
704
- # # num_samples = num_batches * global_batch_size
705
- # # dataloader = dataloader.with_epoch(num_batches)
706
- # # else:
707
- # # # last batches are partial, eval is done on single (master) node
708
- # # num_batches = math.ceil(num_samples / args.batch_size)
709
-
710
- # # add meta-data to dataloader instance for convenience
711
- # dataloader.num_batches = num_batches
712
- # dataloader.num_samples = num_samples
713
-
714
- # return DataInfo(dataloader, None)
715
-
716
-
717
- def wds_batch_list2dict(
718
- batch,
719
- keys=[
720
- "__url__",
721
- "__key__",
722
- "waveform",
723
- "text",
724
- "raw_text",
725
- "audio_name",
726
- "text_name",
727
- "audio_orig_sr",
728
- ],
729
- ):
730
- """
731
- Return a dictionary of the batch, with keys as the names of the fields.
732
- """
733
- assert len(keys) == len(
734
- batch
735
- ), "batch must have same number of keys as keys argument"
736
- return {keys[i]: batch[i] for i in range(len(batch))}
737
-
738
-
739
- def get_csv_dataset(args, preprocess_fn, is_train):
740
- input_filename = args.train_data if is_train else args.val_data
741
- assert input_filename
742
- dataset = CsvDataset(
743
- input_filename,
744
- preprocess_fn,
745
- img_key=args.csv_img_key,
746
- caption_key=args.csv_caption_key,
747
- sep=args.csv_separator,
748
- )
749
- num_samples = len(dataset)
750
- sampler = DistributedSampler(dataset) if args.distributed and is_train else None
751
- shuffle = is_train and sampler is None
752
-
753
- dataloader = DataLoader(
754
- dataset,
755
- batch_size=args.batch_size,
756
- shuffle=shuffle,
757
- num_workers=args.workers,
758
- pin_memory=True,
759
- sampler=sampler,
760
- drop_last=is_train,
761
- )
762
- dataloader.num_samples = num_samples
763
- dataloader.num_batches = len(dataloader)
764
-
765
- return DataInfo(dataloader, sampler)
766
-
767
-
768
- def get_toy_dataset(args, model_cfg, is_train):
769
- index_path = args.train_data if is_train else args.val_data
770
- ipc_path = args.train_ipc if is_train else args.val_ipc
771
- assert index_path and ipc_path
772
- eval_mode = not is_train
773
- dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
774
-
775
- num_samples = len(dataset)
776
- sampler = (
777
- DistributedSampler(dataset, shuffle=False)
778
- if args.distributed and is_train
779
- else None
780
- )
781
-
782
- dataloader = DataLoader(
783
- dataset,
784
- batch_size=args.batch_size,
785
- shuffle=False,
786
- num_workers=args.workers,
787
- sampler=sampler,
788
- drop_last=is_train,
789
- )
790
- dataloader.num_samples = num_samples
791
- dataloader.num_batches = len(dataloader)
792
-
793
- return DataInfo(dataloader, sampler)
794
-
795
-
796
- def get_dataset_fn(data_path, dataset_type):
797
- if dataset_type == "webdataset":
798
- return get_wds_dataset
799
- elif dataset_type == "csv":
800
- return get_csv_dataset
801
- elif dataset_type == "auto":
802
- ext = data_path.split(".")[-1]
803
- if ext in ["csv", "tsv"]:
804
- return get_csv_dataset
805
- elif ext in ["tar"]:
806
- return get_wds_dataset
807
- else:
808
- raise ValueError(
809
- f"Tried to figure out dataset type, but failed for extension {ext}."
810
- )
811
- elif dataset_type == "toy":
812
- return get_toy_dataset
813
- else:
814
- raise ValueError(f"Unsupported dataset type: {dataset_type}")
815
-
816
-
817
- def get_data(args, model_cfg):
818
- data = {}
819
-
820
- args.class_index_dict = load_class_label(args.class_label_path)
821
-
822
- if args.datasetinfos is None:
823
- args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
824
- if args.dataset_type == "webdataset":
825
- args.train_data = get_tar_path_from_dataset_name(
826
- args.datasetnames,
827
- args.datasetinfos,
828
- islocal=not args.remotedata,
829
- proportion=args.dataset_proportion,
830
- dataset_path=args.datasetpath,
831
- full_dataset=args.full_train_dataset,
832
- )
833
-
834
- if args.full_train_dataset is None:
835
- args.full_train_dataset = []
836
- if args.exclude_eval_dataset is None:
837
- args.exclude_eval_dataset = []
838
- excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
839
-
840
- val_dataset_names = (
841
- [n for n in args.datasetnames if n not in excluded_eval_datasets]
842
- if excluded_eval_datasets
843
- else args.datasetnames
844
- )
845
- args.val_dataset_names = val_dataset_names
846
- args.val_data = get_tar_path_from_dataset_name(
847
- val_dataset_names,
848
- ["valid", "test", "eval"],
849
- islocal=not args.remotedata,
850
- proportion=1,
851
- dataset_path=args.datasetpath,
852
- full_dataset=None,
853
- )
854
-
855
- if args.train_data:
856
- data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
857
- args, model_cfg, is_train=True
858
- )
859
-
860
- if args.val_data:
861
- data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
862
- args, model_cfg, is_train=False
863
- )
864
-
865
- return data
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ from dataclasses import dataclass
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torchvision.datasets as datasets
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ import soundfile as sf
14
+ import io
15
+ from pathlib import Path
16
+
17
+ # import wget
18
+
19
+ from audiosr.clap.open_clip.utils import get_tar_path_from_dataset_name
20
+ from audiosr.clap.open_clip.utils import load_class_label
21
+
22
+ try:
23
+ import horovod.torch as hvd
24
+ except ImportError:
25
+ hvd = None
26
+
27
+ try:
28
+ import torchaudio
29
+ except ImportError:
30
+ torchaudio = None
31
+
32
+ from audiosr.clap.open_clip import tokenize
33
+
34
+
35
+ # NOTE: this CLIP-BPE based tokenizer is immediately overridden by the
+ # Roberta-based tokenizer defined below; it is kept here only for reference.
+ def tokenizer(text):
36
+ return tokenize(text).squeeze(0)
37
+
38
+
39
+ from transformers import RobertaTokenizer
40
+
41
+ tokenize = RobertaTokenizer.from_pretrained("roberta-base")
42
+
43
+
44
+ def tokenizer(text):
45
+ result = tokenize(
46
+ text,
47
+ padding="max_length",
48
+ truncation=True,
49
+ max_length=77,
50
+ return_tensors="pt",
51
+ )
52
+ return {k: v.squeeze(0) for k, v in result.items()}
53
+
54
+
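As an aside (illustrative only, not part of the commit): a minimal sketch of what the Roberta-based tokenizer above returns, assuming `transformers` is installed and this module is importable as audiosr.clap.training.data.

# Illustrative sketch only; relies on the Roberta-based tokenizer defined above.
from audiosr.clap.training.data import tokenizer

toks = tokenizer("a dog barking in the rain")
# Both tensors are padded/truncated to the CLIP context length of 77 tokens.
print(toks["input_ids"].shape)       # torch.Size([77])
print(toks["attention_mask"].shape)  # torch.Size([77])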
55
+ # initialize the AudioSet label map
56
+ _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
57
+ _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
58
+
59
+
60
+ def int16_to_float32(x):
61
+ return (x / 32767.0).astype(np.float32)
62
+
63
+
64
+ def float32_to_int16(x):
65
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
66
+ return (x * 32767.0).astype(np.int16)
67
+
68
+
69
+ # For Toy Dataset
70
+ # class ToyDataset(Dataset):
71
+ # def __init__(self, index_path, ipc, config, eval_mode=False):
72
+ # """Toy Dataset for testing the audioset input with text labels
73
+ # Parameters
74
+ # ----------
75
+ # index_path: str
76
+ # the link to the h5 file of each audio
77
+ # ipc: str
78
+ # the link to the npy file, the number of samples in each class
79
+ # config: dict
80
+ # the audio cfg file
81
+ # eval_mode (bool): whether the dataset is used for evaluation
82
+ # """
83
+ # self.audio_cfg = config["audio_cfg"]
84
+ # self.text_cfg = config["text_cfg"]
85
+ # self.fp = h5py.File(index_path, "r")
86
+ # self.ipc = np.load(ipc, allow_pickle=True)
87
+ # self.total_size = len(self.fp["audio_name"])
88
+ # self.classes_num = self.audio_cfg["class_num"]
89
+ # self.eval_mode = eval_mode
90
+
91
+ # if not eval_mode:
92
+ # self.generate_queue()
93
+ # else:
94
+ # self.queue = []
95
+ # for i in range(self.total_size):
96
+ # target = self.fp["target"][i]
97
+ # if np.sum(target) > 0:
98
+ # self.queue.append(i)
99
+ # self.total_size = len(self.queue)
100
+ # logging.info("total dataset size: %d" % (self.total_size))
101
+ # logging.info("class num: %d" % (self.classes_num))
102
+
103
+ # def time_shifting(self, x):
104
+ # frame_num = len(x)
105
+ # shift_len = random.randint(0, frame_num - 1)
106
+ # new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
107
+ # return new_sample
108
+
109
+ # def generate_queue(self):
110
+ # self.queue = []
111
+ # while len(self.queue) < self.total_size:
112
+ # class_set = [*range(self.classes_num)]
113
+ # random.shuffle(class_set)
114
+ # self.queue += [
115
+ # self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
116
+ # ]
117
+ # self.queue = self.queue[: self.total_size]
118
+
119
+ # logging.info("queue regenerated:%s" % (self.queue[-5:]))
120
+
121
+ # def crop_wav(self, x):
122
+ # crop_size = self.audio_cfg["crop_size"]
123
+ # crop_pos = random.randint(0, len(x) - crop_size - 1)
124
+ # return x[crop_pos : crop_pos + crop_size]
125
+
126
+ # def prompt_text(self, target):
127
+ # events = _AUDIOSET_MAP[np.where(target > 0)]
128
+ # event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
129
+ # text = tokenize(event_text)[0]
130
+ # return text
131
+
132
+ # def __getitem__(self, index):
133
+ # """Load waveform, text, and target of an audio clip
134
+
135
+ # Parameters
136
+ # ----------
137
+ # index: int
138
+ # the index number
139
+ # Return
140
+ # ------
141
+ # output: dict {
142
+ # "hdf5_path": str,
143
+ # "index_in_hdf5": int,
144
+ # "audio_name": str,
145
+ # "waveform": list (audio_length,),
146
+ # "target": list (class_num, ),
147
+ # "text": torch.tensor (context_length,)
148
+ # }
149
+ # the output dictionary
150
+ # """
151
+ # s_index = self.queue[index]
152
+
153
+ # audio_name = self.fp["audio_name"][s_index].decode()
154
+ # # Hardcode here CHANGE
155
+ # hdf5_path = (
156
+ # self.fp["hdf5_path"][s_index]
157
+ # .decode()
158
+ # .replace(
159
+ # "../workspace",
160
+ # "/home/la/kechen/Research/ke_zsasp/workspace",
161
+ # )
162
+ # )
163
+ # r_idx = self.fp["index_in_hdf5"][s_index]
164
+ # target = self.fp["target"][s_index].astype(np.float32)
165
+ # text = self.prompt_text(target)
166
+ # with h5py.File(hdf5_path, "r") as f:
167
+ # waveform = int16_to_float32(f["waveform"][r_idx])[
168
+ # : self.audio_cfg["clip_samples"]
169
+ # ]
170
+ # assert (
171
+ # len(waveform) == self.audio_cfg["clip_samples"]
172
+ # ), "The sample length is not match"
173
+ # # Time shift
174
+ # # if (self.config.enable_time_shift) and (not self.eval_mode):
175
+ # # waveform = self.time_shifting(waveform)
176
+ # # # Label Enhance
177
+ # # if (self.config.crop_size is not None) and (not self.eval_mode):
178
+ # # waveform = self.crop_wav(waveform)
179
+ # # # the label enhance rate is fixed 0.5
180
+ # # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
181
+ # # kidx = np.where(target)[0]
182
+ # # for k in kidx:
183
+ # # for add_key in self.class_map[k][1]:
184
+ # # target[add_key] = 1.0
185
+ # # if len(self.class_map[k][2]) > 0:
186
+ # # add_key = random.choice(self.class_map[k][2])
187
+ # # target[add_key] = 1.0
188
+
189
+ # # missing the text input
190
+ # mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
191
+ # mel_spec = (
192
+ # torch.cat(
193
+ # [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
194
+ # )
195
+ # .cpu()
196
+ # .numpy()
197
+ # )
198
+ # longer = random.choice([True, False])
199
+ # if longer == False:
200
+ # mel_spec[1:, :, :] = 0.0
201
+ # data_dict = {
202
+ # "hdf5_path": hdf5_path,
203
+ # "index_in_hdf5": r_idx,
204
+ # "audio_name": audio_name,
205
+ # "waveform": waveform,
206
+ # "class_label": target,
207
+ # "text": text,
208
+ # "longer": longer,
209
+ # "mel_fusion": mel_spec,
210
+ # }
211
+ # return data_dict
212
+
213
+ # def __len__(self):
214
+ # return self.total_size
215
+
216
+
217
+ class CsvDataset(Dataset):
218
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
219
+ logging.debug(f"Loading csv data from {input_filename}.")
220
+ df = pd.read_csv(input_filename, sep=sep)
221
+
222
+ self.images = df[img_key].tolist()
223
+ self.captions = df[caption_key].tolist()
224
+ self.transforms = transforms
225
+ logging.debug("Done loading data.")
226
+
227
+ def __len__(self):
228
+ return len(self.captions)
229
+
230
+ def __getitem__(self, idx):
231
+ images = self.transforms(Image.open(str(self.images[idx])))
232
+ texts = tokenize([str(self.captions[idx])])[0]
233
+ return images, texts
234
+
235
+
236
+ @dataclass
237
+ class DataInfo:
238
+ dataloader: DataLoader
239
+ sampler: DistributedSampler
240
+
241
+
242
+ def preprocess_txt(text):
243
+ return tokenize([str(text)])[0]
244
+
245
+
246
+ # def get_dataset_size(shards, sizefilepath_=None, is_local=True):
247
+ # if isinstance(shards, list):
248
+ # size_list = []
249
+ # for s in shards:
250
+ # size_list.append(
251
+ # get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
252
+ # )
253
+ # else:
254
+ # if not is_local:
255
+ # for n in dataset_split.keys():
256
+ # if n in shards.split("/"):
257
+ # break
258
+ # for s in dataset_split[n]:
259
+ # if s in shards.split("/"):
260
+ # break
261
+ # sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
262
+ # shards_list = list(braceexpand.braceexpand(shards))
263
+ # dir_path = os.path.dirname(shards)
264
+ # if sizefilepath_ is not None:
265
+ # sizes = json.load(open(sizefilepath_, "r"))
266
+ # total_size = sum(
267
+ # [
268
+ # int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
269
+ # for shard in shards_list
270
+ # ]
271
+ # )
272
+ # else:
273
+ # sizes_filename = os.path.join(dir_path, "sizes.json")
274
+ # len_filename = os.path.join(dir_path, "__len__")
275
+ # if os.path.exists(sizes_filename):
276
+ # sizes = json.load(open(sizes_filename, "r"))
277
+ # total_size = sum(
278
+ # [int(sizes[os.path.basename(shard)]) for shard in shards_list]
279
+ # )
280
+ # elif os.path.exists(len_filename):
281
+ # # FIXME this used to be eval(open(...)) but that seemed rather unsafe
282
+ # total_size = ast.literal_eval(open(len_filename, "r").read())
283
+ # else:
284
+ # raise Exception(
285
+ # "Cannot find sizes file for dataset. Please specify the path to the file."
286
+ # )
287
+ # # total_size = None # num samples undefined
288
+ # # some common dataset sizes (at time of authors last download)
289
+ # # cc3m-train: 2905954
290
+ # # cc12m: 10968539
291
+ # # LAION-400m: 407332084
292
+ # num_shards = len(shards_list)
293
+ # if isinstance(shards, list):
294
+ # return sum(size_list), len(shards)
295
+ # else:
296
+ # return total_size, num_shards
297
+
298
+
299
+ def get_imagenet(args, preprocess_fns, split):
300
+ assert split in ["train", "val", "v2"]
301
+ is_train = split == "train"
302
+ preprocess_train, preprocess_val = preprocess_fns
303
+
304
+ if split == "v2":
305
+ from imagenetv2_pytorch import ImageNetV2Dataset
306
+
307
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
308
+ else:
309
+ if is_train:
310
+ data_path = args.imagenet_train
311
+ preprocess_fn = preprocess_train
312
+ else:
313
+ data_path = args.imagenet_val
314
+ preprocess_fn = preprocess_val
315
+ assert data_path
316
+
317
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
318
+
319
+ if is_train:
320
+ idxs = np.zeros(len(dataset.targets))
321
+ target_array = np.array(dataset.targets)
322
+ k = 50
323
+ for c in range(1000):
324
+ m = target_array == c
325
+ n = len(idxs[m])
326
+ arr = np.zeros(n)
327
+ arr[:k] = 1
328
+ np.random.shuffle(arr)
329
+ idxs[m] = arr
330
+
331
+ idxs = idxs.astype("int")
332
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
333
+ else:
334
+ sampler = None
335
+
336
+ dataloader = torch.utils.data.DataLoader(
337
+ dataset,
338
+ batch_size=args.batch_size,
339
+ num_workers=args.workers,
340
+ sampler=sampler,
341
+ )
342
+
343
+ return DataInfo(dataloader, sampler)
344
+
345
+
346
+ def count_samples(dataloader):
347
+ os.environ["WDS_EPOCH"] = "0"
348
+ n_elements, n_batches = 0, 0
349
+ for images, texts in dataloader:
350
+ n_batches += 1
351
+ n_elements += len(images)
352
+ assert len(images) == len(texts)
353
+ return n_elements, n_batches
354
+
355
+
356
+ def filter_no_caption(sample):
357
+ return "txt" in sample
358
+
359
+
360
+ def log_and_continue(exn):
361
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
362
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
363
+ return True
364
+
365
+
366
+ _SHARD_SHUFFLE_SIZE = 2000
367
+ _SHARD_SHUFFLE_INITIAL = 500
368
+ _SAMPLE_SHUFFLE_SIZE = 5000
369
+ _SAMPLE_SHUFFLE_INITIAL = 1000
370
+
371
+
372
+ # def sample_prop(sizefile, inputs, proportion, is_local=True):
373
+ # """
374
+ # Sample a proportion of the data.
375
+ # """
376
+ # file_path_dict = {
377
+ # os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
378
+ # for i in range(len(inputs))
379
+ # }
380
+ # sampled_filepath_dict = {}
381
+ # sampled_size_dict = {}
382
+ # if not is_local:
383
+ # if os.path.exists("sizes.json"):
384
+ # os.remove("sizes.json")
385
+ # wget.download(sizefile, "sizes.json")
386
+ # sizefile = "sizes.json"
387
+ # with open(sizefile, "r", encoding="UTF-8") as f:
388
+ # load_dict = json.load(f)
389
+ # L = int(len(file_path_dict) * proportion)
390
+ # subkeys = random.sample(file_path_dict.keys(), L)
391
+ # for k in subkeys:
392
+ # sampled_size_dict[k] = load_dict[k]
393
+ # sampled_filepath_dict[k] = file_path_dict[k]
394
+ # return (
395
+ # sum(sampled_size_dict.values()),
396
+ # L,
397
+ # [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
398
+ # sampled_size_dict,
399
+ # )
400
+
401
+
402
+ def get_mel(audio_data, audio_cfg):
403
+ # mel shape: (n_mels, T)
404
+ mel = torchaudio.transforms.MelSpectrogram(
405
+ sample_rate=audio_cfg["sample_rate"],
406
+ n_fft=audio_cfg["window_size"],
407
+ win_length=audio_cfg["window_size"],
408
+ hop_length=audio_cfg["hop_size"],
409
+ center=True,
410
+ pad_mode="reflect",
411
+ power=2.0,
412
+ norm=None,
413
+ onesided=True,
414
+ n_mels=64,
415
+ f_min=audio_cfg["fmin"],
416
+ f_max=audio_cfg["fmax"],
417
+ ).to(audio_data.device)
418
+ mel = mel(audio_data)
419
+ # we use log mel spectrogram as input
420
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
421
+ return mel.T # (T, n_mels)
422
+
423
+
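A rough illustration of get_mel (not part of the commit); the config values below are hypothetical stand-ins for model_cfg["audio_cfg"], and torchaudio must be installed:

import torch
from audiosr.clap.training.data import get_mel

# Hypothetical audio_cfg values, for illustration only.
audio_cfg = {"sample_rate": 48000, "window_size": 1024, "hop_size": 480,
             "fmin": 50, "fmax": 14000}
wav = torch.zeros(480000)        # 10 s at 48 kHz
mel = get_mel(wav, audio_cfg)
print(mel.shape)                 # torch.Size([1001, 64]); frames = len(wav) // hop_size + 1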
424
+ def get_audio_features(
425
+ audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
426
+ ):
427
+ """
428
+ Compute audio features for one clip and return them as a dict with keys
429
+ "mel_fusion", "longer" and "waveform".
+ audio_data: a tensor of shape (T,) containing the raw waveform.
430
+ mel: the log-mel spectrogram of audio_data (e.g. as returned by get_mel).
431
+ max_len: the maximum length of audio data.
432
+ data_truncating: the method of truncating data.
433
+ data_filling: the method of filling data.
434
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
435
+ """
436
+ sample = {}
437
+
438
+ # assert audio_data.size(-1) <= max_len, str(audio_data.size())
439
+
440
+ # split to three parts
441
+ chunk_frames = (
442
+ max_len // audio_cfg["hop_size"] + 1
443
+ ) # the +1 is related to how the spectrogram frames are counted
444
+ mel = mel[:chunk_frames]
445
+
446
+ audio_data = audio_data[..., :max_len]
447
+ sample["mel_fusion"] = mel
448
+ longer = torch.tensor([True])
449
+
450
+ sample["longer"] = longer
451
+ sample["waveform"] = audio_data
452
+
453
+ return sample
454
+
455
+
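A sketch of the dict get_audio_features returns (illustrative only; hop_size and max_len below are hypothetical values):

import torch
from audiosr.clap.training.data import get_audio_features

feats = get_audio_features(
    audio_data=torch.randn(500000),       # raw waveform, longer than max_len
    mel=torch.randn(1200, 64),            # (frames, n_mels), e.g. from get_mel
    max_len=480000,
    data_truncating="rand_trunc",
    data_filling="pad",
    audio_cfg={"hop_size": 480},          # hypothetical value
)
print(feats["mel_fusion"].shape)  # torch.Size([1001, 64]); truncated to max_len // hop_size + 1 frames
print(feats["waveform"].shape)    # torch.Size([480000]); waveform truncated to max_len
print(feats["longer"])            # tensor([True]); always True in this simplified variant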
456
+ def preprocess(
457
+ sample,
458
+ audio_ext,
459
+ text_ext,
460
+ max_len,
461
+ audio_cfg,
462
+ class_index_dict=None,
463
+ data_filling="pad",
464
+ data_truncating="rand_trunc",
465
+ text_augment_selection=None,
466
+ ):
467
+ """
468
+ Preprocess a single sample for wdsdataloader.
469
+ """
470
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
471
+ audio_data = int16_to_float32(float32_to_int16(audio_data))
472
+ audio_data = torch.tensor(audio_data).float()
473
+
474
+ # TODO: (yusong) to be include in the future
475
+ # # if torchaudio not installed, use soundfile to load audio
476
+ # if torchaudio is None:
477
+ # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
478
+ # audio_data = torch.tensor(audio_data).float()
479
+ # else:
480
+ # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
481
+ # with tempfile.TemporaryDirectory() as dirname:
482
+ # os.makedirs(dirname, exist_ok=True)
483
+ # fname = os.path.join(dirname, f"file.flac")
484
+ # with open(fname, "wb") as stream:
485
+ # stream.write(sample[audio_ext])
486
+ # audio_data, orig_sr = torchaudio.load(fname)
487
+ # audio_data = audio_data[0, :].float()
488
+
489
+ # get_audio_features expects (audio_data, mel, ...), so compute the mel here and
+ # merge the returned feature dict into the existing webdataset sample.
+ mel = get_mel(audio_data, audio_cfg)
490
+ sample.update(
+ get_audio_features(
+ audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
491
+ )
+ )
492
+ del sample[audio_ext]
493
+
494
+ try:
495
+ json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
496
+ except:
497
+ print("sample[__url__]:", sample["__url__"])
498
+
499
+ # For selecting augmented text from dataset
500
+ if text_augment_selection is None or text_augment_selection == "none":
501
+ texts = json_dict_raw["text"]
502
+ elif text_augment_selection == "all":
503
+ if "text_augment_all" in json_dict_raw.keys():
504
+ texts = json_dict_raw["text_augment_all"]
505
+ else:
506
+ texts = json_dict_raw["text"]
507
+ elif text_augment_selection == "augment_only":
508
+ if "text_augment_all" in json_dict_raw.keys():
509
+ if json_dict_raw["text_augment_t5"] is None:
510
+ texts = json_dict_raw["text"]
511
+ else:
512
+ texts = json_dict_raw["text_augment_t5"]
513
+ else:
514
+ texts = json_dict_raw["text"]
515
+ else:
516
+ raise NotImplementedError(
517
+ f"text_augment_selection {text_augment_selection} not implemented"
518
+ )
519
+ sample["full_text"] = texts
520
+
521
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
522
+ texts = random.choice(texts)
523
+ sample["raw_text"] = texts
524
+ sample["text"] = tokenizer(texts) # text shape: [num_token]
525
+ if class_index_dict is not None:
526
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
527
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
528
+ # key, val = class_index_dict
529
+ # key = key[:].split('\n')
530
+ # _dict = {k: v for k, v in zip(key, val)}
531
+ sample["class_label"] = np.zeros(len(class_index_dict.keys()))
532
+ for x in json_dict_raw["tag"]:
533
+ sample["class_label"][class_index_dict[x]] = 1
534
+ sample["class_label"] = torch.tensor(sample["class_label"]).float()
535
+ del sample[text_ext]
536
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
537
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
538
+ sample["audio_orig_sr"] = orig_sr
539
+ return sample
540
+
541
+
542
+ def collate_fn(batch):
543
+ """
544
+ Collate function for wdsdataloader.
545
+ batch: a list of dict, each dict is a sample
546
+ """
547
+ # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
548
+ batch_dict = {}
549
+ for k in batch[0].keys():
550
+ if isinstance(batch[0][k], dict): # deal with BERT tokenizer output
551
+ batch_dict[k] = {}
552
+ for kk in batch[0][k].keys():
553
+ tmp = []
554
+ for i in range(len(batch)):
555
+ tmp.append(batch[i][k][kk])
556
+ batch_dict[k][kk] = torch.vstack(tmp)
557
+ elif isinstance(batch[0][k], torch.Tensor):
558
+ batch_dict[k] = torch.stack([sample[k] for sample in batch])
559
+ elif isinstance(batch[0][k], np.ndarray):
560
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
561
+ else:
562
+ batch_dict[k] = [sample[k] for sample in batch]
563
+ return batch_dict
564
+
565
+
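A sketch of the field types collate_fn handles (illustrative only, not part of the commit):

import numpy as np
import torch
from audiosr.clap.training.data import collate_fn

batch = [
    {"waveform": torch.zeros(4), "class_label": np.ones(3), "raw_text": "a",
     "text": {"input_ids": torch.zeros(77, dtype=torch.long)}},
    {"waveform": torch.ones(4), "class_label": np.zeros(3), "raw_text": "b",
     "text": {"input_ids": torch.ones(77, dtype=torch.long)}},
]
out = collate_fn(batch)
print(out["waveform"].shape)           # torch.Size([2, 4])   tensors are stacked
print(out["class_label"].shape)        # torch.Size([2, 3])   numpy arrays become stacked tensors
print(out["text"]["input_ids"].shape)  # torch.Size([2, 77])  tokenizer dicts are vstacked per key
print(out["raw_text"])                 # ['a', 'b']           everything else stays a plain list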
566
+ # def get_wds_dataset(
567
+ # args,
568
+ # model_cfg,
569
+ # is_train,
570
+ # audio_ext="flac",
571
+ # text_ext="json",
572
+ # max_len=480000,
573
+ # proportion=1.0,
574
+ # sizefilepath_=None,
575
+ # is_local=None,
576
+ # ):
577
+ # """
578
+ # Get a dataset for wdsdataloader.
579
+ # """
580
+ # if is_local is None and (not args.remotedata is None):
581
+ # is_local = not args.remotedata
582
+
583
+ # input_shards = args.train_data if is_train else args.val_data
584
+ # assert input_shards is not None
585
+
586
+ # if not sizefilepath_ is None:
587
+ # sizefilepath = sizefilepath_
588
+ # else:
589
+ # sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
590
+
591
+ # if proportion != 1.0:
592
+ # num_samples, num_shards, input_shards, _ = sample_prop(
593
+ # sizefilepath, input_shards, proportion, is_local=is_local
594
+ # )
595
+ # else:
596
+ # num_samples, num_shards = get_dataset_size(
597
+ # input_shards, sizefilepath_=sizefilepath_, is_local=is_local
598
+ # )
599
+
600
+ # if not num_samples:
601
+ # if is_train:
602
+ # num_samples = args.train_num_samples
603
+ # if not num_samples:
604
+ # raise RuntimeError(
605
+ # "Currently, number of dataset samples must be specified for training dataset. "
606
+ # "Please specify via `--train-num-samples` if no dataset length info present."
607
+ # )
608
+ # else:
609
+ # num_samples = (
610
+ # args.val_num_samples or 0
611
+ # ) # eval will just exhaust the iterator if not specified
612
+
613
+ # pipeline = [wds.SimpleShardList(input_shards)]
614
+ # # at this point we have an iterator over all the shards
615
+ # # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
616
+ # if is_train or args.parallel_eval:
617
+ # pipeline.extend(
618
+ # [
619
+ # wds.detshuffle(
620
+ # bufsize=_SHARD_SHUFFLE_SIZE,
621
+ # initial=_SHARD_SHUFFLE_INITIAL,
622
+ # seed=args.seed,
623
+ # ),
624
+ # wds.split_by_node,
625
+ # wds.split_by_worker,
626
+ # # at this point, we have an iterator over the shards assigned to each worker at each node
627
+ # wds.tarfile_to_samples(handler=log_and_continue),
628
+ # wds.shuffle(
629
+ # bufsize=_SAMPLE_SHUFFLE_SIZE,
630
+ # initial=_SAMPLE_SHUFFLE_INITIAL,
631
+ # rng=random.Random(args.seed),
632
+ # ),
633
+ # # wds.repeatedly, # FIXME determine if this is beneficial
634
+ # ]
635
+ # )
636
+ # else:
637
+ # pipeline.extend(
638
+ # [
639
+ # wds.split_by_worker,
640
+ # # at this point, we have an iterator over the shards assigned to each worker
641
+ # wds.tarfile_to_samples(handler=log_and_continue),
642
+ # ]
643
+ # )
644
+ # pipeline.append(
645
+ # wds.map(
646
+ # partial(
647
+ # preprocess,
648
+ # audio_ext=audio_ext,
649
+ # text_ext=text_ext,
650
+ # max_len=max_len,
651
+ # audio_cfg=model_cfg["audio_cfg"],
652
+ # class_index_dict=copy.deepcopy(args.class_index_dict),
653
+ # data_filling=args.data_filling,
654
+ # data_truncating=args.data_truncating,
655
+ # text_augment_selection=args.text_augment_selection,
656
+ # )
657
+ # ),
658
+ # )
659
+
660
+ # pipeline.append(
661
+ # wds.batched(
662
+ # args.batch_size,
663
+ # partial=not (is_train or args.parallel_eval),
664
+ # collation_fn=collate_fn,
665
+ # )
666
+ # )
667
+
668
+ # dataset = wds.DataPipeline(*pipeline)
669
+ # if is_train or args.parallel_eval:
670
+ # # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
671
+ # # (yusong): See comments below.
672
+ # # roll over and repeat a few samples to get same number of full batches on each node
673
+ # global_batch_size = args.batch_size * args.world_size
674
+ # num_batches = math.ceil(num_samples / global_batch_size)
675
+ # num_workers = max(1, args.workers)
676
+ # num_worker_batches = math.ceil(
677
+ # num_batches / num_workers
678
+ # ) # per dataloader worker
679
+ # num_batches = num_worker_batches * num_workers
680
+ # num_samples = num_batches * global_batch_size
681
+ # dataset = dataset.with_epoch(
682
+ # num_worker_batches
683
+ # ) # each worker is iterating over this
684
+ # else:
685
+ # # last batches are partial, eval is done on single (master) node
686
+ # num_batches = math.ceil(num_samples / args.batch_size)
687
+
688
+ # kwargs = {}
689
+ # if args.horovod: # multi-node training on summit
690
+ # kwargs["multiprocessing_context"] = "forkserver"
691
+
692
+ # dataloader = wds.WebLoader(
693
+ # dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
694
+ # )
695
+
696
+ # # FIXME not clear which approach is better, with_epoch before vs after dataloader?
697
+ # # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
698
+ # # if is_train:
699
+ # # # roll over and repeat a few samples to get same number of full batches on each node
700
+ # # global_batch_size = args.batch_size * args.world_size
701
+ # # num_batches = math.ceil(num_samples / global_batch_size)
702
+ # # num_workers = max(1, args.workers)
703
+ # # num_batches = math.ceil(num_batches / num_workers) * num_workers
704
+ # # num_samples = num_batches * global_batch_size
705
+ # # dataloader = dataloader.with_epoch(num_batches)
706
+ # # else:
707
+ # # # last batches are partial, eval is done on single (master) node
708
+ # # num_batches = math.ceil(num_samples / args.batch_size)
709
+
710
+ # # add meta-data to dataloader instance for convenience
711
+ # dataloader.num_batches = num_batches
712
+ # dataloader.num_samples = num_samples
713
+
714
+ # return DataInfo(dataloader, None)
715
+
716
+
717
+ def wds_batch_list2dict(
718
+ batch,
719
+ keys=[
720
+ "__url__",
721
+ "__key__",
722
+ "waveform",
723
+ "text",
724
+ "raw_text",
725
+ "audio_name",
726
+ "text_name",
727
+ "audio_orig_sr",
728
+ ],
729
+ ):
730
+ """
731
+ Return a dictionary of the batch, with keys as the names of the fields.
732
+ """
733
+ assert len(keys) == len(
734
+ batch
735
+ ), "batch must have same number of keys as keys argument"
736
+ return {keys[i]: batch[i] for i in range(len(batch))}
737
+
738
+
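Illustrative only: this helper simply zips a positional webdataset batch with the default key names.

import torch
from audiosr.clap.training.data import wds_batch_list2dict

batch = [["url0"], ["key0"], torch.zeros(1, 4), torch.zeros(1, 77),
         ["a dog barking"], ["x.flac"], ["x.json"], [44100]]
named = wds_batch_list2dict(batch)
print(list(named))  # ['__url__', '__key__', 'waveform', 'text', 'raw_text',
                    #  'audio_name', 'text_name', 'audio_orig_sr']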
739
+ def get_csv_dataset(args, preprocess_fn, is_train):
740
+ input_filename = args.train_data if is_train else args.val_data
741
+ assert input_filename
742
+ dataset = CsvDataset(
743
+ input_filename,
744
+ preprocess_fn,
745
+ img_key=args.csv_img_key,
746
+ caption_key=args.csv_caption_key,
747
+ sep=args.csv_separator,
748
+ )
749
+ num_samples = len(dataset)
750
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
751
+ shuffle = is_train and sampler is None
752
+
753
+ dataloader = DataLoader(
754
+ dataset,
755
+ batch_size=args.batch_size,
756
+ shuffle=shuffle,
757
+ num_workers=args.workers,
758
+ pin_memory=True,
759
+ sampler=sampler,
760
+ drop_last=is_train,
761
+ )
762
+ dataloader.num_samples = num_samples
763
+ dataloader.num_batches = len(dataloader)
764
+
765
+ return DataInfo(dataloader, sampler)
766
+
767
+
768
+ def get_toy_dataset(args, model_cfg, is_train):
769
+ index_path = args.train_data if is_train else args.val_data
770
+ ipc_path = args.train_ipc if is_train else args.val_ipc
771
+ assert index_path and ipc_path
772
+ eval_mode = not is_train
773
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
774
+
775
+ num_samples = len(dataset)
776
+ sampler = (
777
+ DistributedSampler(dataset, shuffle=False)
778
+ if args.distributed and is_train
779
+ else None
780
+ )
781
+
782
+ dataloader = DataLoader(
783
+ dataset,
784
+ batch_size=args.batch_size,
785
+ shuffle=False,
786
+ num_workers=args.workers,
787
+ sampler=sampler,
788
+ drop_last=is_train,
789
+ )
790
+ dataloader.num_samples = num_samples
791
+ dataloader.num_batches = len(dataloader)
792
+
793
+ return DataInfo(dataloader, sampler)
794
+
795
+
796
+ def get_dataset_fn(data_path, dataset_type):
797
+ if dataset_type == "webdataset":
798
+ return get_wds_dataset
799
+ elif dataset_type == "csv":
800
+ return get_csv_dataset
801
+ elif dataset_type == "auto":
802
+ ext = data_path.split(".")[-1]
803
+ if ext in ["csv", "tsv"]:
804
+ return get_csv_dataset
805
+ elif ext in ["tar"]:
806
+ return get_wds_dataset
807
+ else:
808
+ raise ValueError(
809
+ f"Tried to figure out dataset type, but failed for extension {ext}."
810
+ )
811
+ elif dataset_type == "toy":
812
+ return get_toy_dataset
813
+ else:
814
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
815
+
816
+
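Illustrative only: with dataset_type="auto" the loader is picked from the file extension (note that the webdataset branch currently points at get_wds_dataset, which is commented out in this file):

from audiosr.clap.training.data import get_dataset_fn, get_csv_dataset

fn = get_dataset_fn("some/folder/train.csv", "auto")
print(fn is get_csv_dataset)  # True: the ".csv" extension selects the CSV loader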
817
+ def get_data(args, model_cfg):
818
+ data = {}
819
+
820
+ args.class_index_dict = load_class_label(args.class_label_path)
821
+
822
+ if args.datasetinfos is None:
823
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
824
+ if args.dataset_type == "webdataset":
825
+ args.train_data = get_tar_path_from_dataset_name(
826
+ args.datasetnames,
827
+ args.datasetinfos,
828
+ islocal=not args.remotedata,
829
+ proportion=args.dataset_proportion,
830
+ dataset_path=args.datasetpath,
831
+ full_dataset=args.full_train_dataset,
832
+ )
833
+
834
+ if args.full_train_dataset is None:
835
+ args.full_train_dataset = []
836
+ if args.exclude_eval_dataset is None:
837
+ args.exclude_eval_dataset = []
838
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
839
+
840
+ val_dataset_names = (
841
+ [n for n in args.datasetnames if n not in excluded_eval_datasets]
842
+ if excluded_eval_datasets
843
+ else args.datasetnames
844
+ )
845
+ args.val_dataset_names = val_dataset_names
846
+ args.val_data = get_tar_path_from_dataset_name(
847
+ val_dataset_names,
848
+ ["valid", "test", "eval"],
849
+ islocal=not args.remotedata,
850
+ proportion=1,
851
+ dataset_path=args.datasetpath,
852
+ full_dataset=None,
853
+ )
854
+
855
+ if args.train_data:
856
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
857
+ args, model_cfg, is_train=True
858
+ )
859
+
860
+ if args.val_data:
861
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
862
+ args, model_cfg, is_train=False
863
+ )
864
+
865
+ return data
audiosr/clap/training/params.py CHANGED
@@ -1,563 +1,563 @@
1
- import argparse
2
-
3
-
4
- def get_default_params(model_name):
5
- # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
6
- model_name = model_name.lower()
7
- if "vit" in model_name:
8
- return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
9
- else:
10
- return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
11
-
12
-
13
- def parse_args():
14
- parser = argparse.ArgumentParser()
15
- parser.add_argument(
16
- "--train-data",
17
- type=str,
18
- default=None,
19
- help="Path to h5 filewith training data",
20
- )
21
- parser.add_argument(
22
- "--val-data",
23
- type=str,
24
- default=None,
25
- help="Path to h5 file with validation data",
26
- )
27
- parser.add_argument(
28
- "--freeze-text",
29
- default=False,
30
- action="store_true",
31
- help="if you need to freeze the text encoder, make this True",
32
- )
33
- parser.add_argument(
34
- "--freeze-text-after",
35
- type=int,
36
- default=-1,
37
- help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
38
- )
39
- parser.add_argument(
40
- "--train-ipc",
41
- type=str,
42
- default=None,
43
- help="Path to npy file of the number of instance per class in training data",
44
- )
45
- parser.add_argument(
46
- "--val-ipc",
47
- type=str,
48
- default=None,
49
- help="Path to npy file of the number of instance per class in validation data",
50
- )
51
- parser.add_argument(
52
- "--train-num-samples",
53
- type=int,
54
- default=None,
55
- help="Number of samples in dataset. Required for webdataset if not available in info file.",
56
- )
57
- parser.add_argument(
58
- "--val-num-samples",
59
- type=int,
60
- default=None,
61
- help="Number of samples in dataset. Useful for webdataset if not available in info file.",
62
- )
63
- parser.add_argument(
64
- "--dataset-type",
65
- choices=["webdataset", "csv", "auto", "toy"],
66
- default="auto",
67
- help="Which type of dataset to process.",
68
- )
69
- parser.add_argument(
70
- "--csv-separator",
71
- type=str,
72
- default="\t",
73
- help="For csv-like datasets, which separator to use.",
74
- )
75
- parser.add_argument(
76
- "--csv-img-key",
77
- type=str,
78
- default="filepath",
79
- help="For csv-like datasets, the name of the key for the image paths.",
80
- )
81
- parser.add_argument(
82
- "--csv-caption-key",
83
- type=str,
84
- default="title",
85
- help="For csv-like datasets, the name of the key for the captions.",
86
- )
87
- parser.add_argument(
88
- "--imagenet-val",
89
- type=str,
90
- default=None,
91
- help="Path to imagenet val set for conducting zero shot evaluation.",
92
- )
93
- parser.add_argument(
94
- "--imagenet-v2",
95
- type=str,
96
- default=None,
97
- help="Path to imagenet v2 for conducting zero shot evaluation.",
98
- )
99
- parser.add_argument(
100
- "--datasetnames",
101
- nargs="+",
102
- default=None,
103
- help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
104
- )
105
- parser.add_argument(
106
- "--full-train-dataset",
107
- nargs="+",
108
- default=None,
109
- help="Which dataset will be trained with all the subsets. (train+test)",
110
- )
111
- parser.add_argument(
112
- "--exclude-eval-dataset",
113
- nargs="+",
114
- default=None,
115
- help="Which dataset will be excluded with evaluation",
116
- )
117
- parser.add_argument(
118
- "--datasetinfos",
119
- nargs="+",
120
- default=None,
121
- help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
122
- )
123
- parser.add_argument(
124
- "--dataset-proportion",
125
- type=float,
126
- default=1.0,
127
- help="How much proportion of dataset we want to train.",
128
- )
129
- parser.add_argument(
130
- "--remotedata",
131
- default=False,
132
- action="store_true",
133
- help="if the dataset is remote, set this flag",
134
- )
135
- parser.add_argument(
136
- "--class-label-path",
137
- type=str,
138
- default=None,
139
- help="The path of the class label pickle or csv.",
140
- )
141
- parser.add_argument(
142
- "--datasetpath",
143
- type=str,
144
- default="/mnt/audio_clip/webdataset_tar",
145
- help="The path to the dataset",
146
- )
147
- parser.add_argument(
148
- "--logs",
149
- type=str,
150
- default="./logs/",
151
- help="Where to store tensorboard logs. Use None to avoid storing logs.",
152
- )
153
- parser.add_argument(
154
- "--log-local",
155
- action="store_true",
156
- default=False,
157
- help="log files on local master, otherwise global master only.",
158
- )
159
- parser.add_argument(
160
- "--name",
161
- type=str,
162
- default=None,
163
- help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
164
- )
165
- parser.add_argument(
166
- "--workers", type=int, default=1, help="Number of workers per GPU."
167
- )
168
- parser.add_argument(
169
- "--batch-size", type=int, default=64, help="Batch size per GPU."
170
- )
171
- parser.add_argument(
172
- "--epochs", type=int, default=32, help="Number of epochs to train for."
173
- )
174
- parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
175
- parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
176
- parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
177
- parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
178
- parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
179
- parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
180
-
181
- parser.add_argument(
182
- "--split-opt",
183
- action="store_true",
184
- default=False,
185
- help="Use this flag to skip the learning rate decay.",
186
- )
187
- parser.add_argument(
188
- "--lr-pretrained", type=float, default=None, help="Learning rate for text."
189
- )
190
- parser.add_argument(
191
- "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
192
- )
193
- parser.add_argument(
194
- "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
195
- )
196
- parser.add_argument(
197
- "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
198
- )
199
- parser.add_argument(
200
- "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
201
- )
202
- parser.add_argument(
203
- "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
204
- )
205
- parser.add_argument(
206
- "--lr-new", type=float, default=None, help="Learning rate for audio."
207
- )
208
- parser.add_argument(
209
- "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
210
- )
211
- parser.add_argument(
212
- "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
213
- )
214
- parser.add_argument(
215
- "--eps-new", type=float, default=None, help="Adam epsilon for audio."
216
- )
217
- parser.add_argument(
218
- "--wd-new", type=float, default=0.2, help="Weight decay for audio."
219
- )
220
- parser.add_argument(
221
- "--momentum-new", type=float, default=0.9, help="Momentum for audio."
222
- )
223
- parser.add_argument(
224
- "--warmup", type=int, default=10000, help="Number of steps to warmup for."
225
- )
226
- parser.add_argument(
227
- "--use-bn-sync",
228
- default=False,
229
- action="store_true",
230
- help="Whether to use batch norm sync.",
231
- )
232
- parser.add_argument(
233
- "--skip-scheduler",
234
- action="store_true",
235
- default=False,
236
- help="Use this flag to skip the learning rate decay.",
237
- )
238
- parser.add_argument(
239
- "--save-frequency", type=int, default=1, help="How often to save checkpoints."
240
- )
241
- parser.add_argument(
242
- "--save-top-performance",
243
- type=int,
244
- default=0,
245
- help="Save the top x performance weights if the value >0",
246
- )
247
- parser.add_argument(
248
- "--save-most-recent",
249
- action="store_true",
250
- default=False,
251
- help="Always save the most recent model trained to epoch_latest.pt.",
252
- )
253
- parser.add_argument(
254
- "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
255
- )
256
- parser.add_argument(
257
- "--val-frequency",
258
- type=int,
259
- default=1,
260
- help="How often to run evaluation with val data.",
261
- )
262
- parser.add_argument(
263
- "--resume",
264
- default=None,
265
- type=str,
266
- help="path to latest checkpoint (default: none)",
267
- )
268
- parser.add_argument(
269
- "--precision",
270
- choices=["amp", "fp16", "fp32"],
271
- default="amp",
272
- help="Floating point precision.",
273
- )
274
- parser.add_argument(
275
- "--amodel",
276
- type=str,
277
- default="RN50",
278
- help="Name of the audio backbone to use.",
279
- )
280
- parser.add_argument(
281
- "--tmodel",
282
- type=str,
283
- default="transformer",
284
- help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
285
- )
286
- parser.add_argument(
287
- "--pretrained-audio",
288
- default="",
289
- type=str,
290
- help="Use a pretrained audio model weights for the audio encoder of CLAP",
291
- )
292
- parser.add_argument(
293
- "--pretrained-text",
294
- default="",
295
- type=str,
296
- help="Use a pretrained text model weights for the text encoder of CLAP",
297
- )
298
- parser.add_argument(
299
- "--pretrained",
300
- default="",
301
- type=str,
302
- help="Use a pretrained CLIP model weights with the specified tag or file path.",
303
- )
304
- parser.add_argument(
305
- "--pretrained-image",
306
- default=False,
307
- action="store_true",
308
- help="Load imagenet pretrained weights for image tower backbone if available.",
309
- )
310
- parser.add_argument(
311
- "--lock-image",
312
- default=False,
313
- action="store_true",
314
- help="Lock full image tower by disabling gradients.",
315
- )
316
- parser.add_argument(
317
- "--lock-image-unlocked-groups",
318
- type=int,
319
- default=0,
320
- help="Leave last n image tower layer groups unlocked.",
321
- )
322
- parser.add_argument(
323
- "--lock-image-freeze-bn-stats",
324
- default=False,
325
- action="store_true",
326
- help="Freeze BatchNorm running stats in image tower for any locked layers.",
327
- )
328
- parser.add_argument(
329
- "--local-loss",
330
- default=False,
331
- action="store_true",
332
- help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
333
- )
334
- parser.add_argument(
335
- "--gather-with-grad",
336
- default=False,
337
- action="store_true",
338
- help="enable full distributed gradient for feature gather",
339
- )
340
- parser.add_argument(
341
- "--force-quick-gelu",
342
- default=False,
343
- action="store_true",
344
- help="Force use of QuickGELU activation for non-OpenAI transformer models.",
345
- )
346
- parser.add_argument(
347
- "--torchscript",
348
- default=False,
349
- action="store_true",
350
- help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
351
- )
352
- parser.add_argument(
353
- "--trace",
354
- default=False,
355
- action="store_true",
356
- help="torch.jit.trace the model for inference / eval only",
357
- )
358
- # arguments for distributed training
359
- parser.add_argument(
360
- "--dist-url",
361
- default="env://",
362
- type=str,
363
- help="url used to set up distributed training",
364
- )
365
- parser.add_argument(
366
- "--dist-backend", default="nccl", type=str, help="distributed backend"
367
- )
368
- parser.add_argument(
369
- "--report-to",
370
- default="",
371
- type=str,
372
- help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
373
- )
374
- parser.add_argument(
375
- "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
376
- )
377
- parser.add_argument(
378
- "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
379
- )
380
- parser.add_argument(
381
- "--debug",
382
- default=False,
383
- action="store_true",
384
- help="If true, more information is logged.",
385
- )
386
- parser.add_argument(
387
- "--copy-codebase",
388
- default=False,
389
- action="store_true",
390
- help="If true, we copy the entire base on the log diretory, and execute from there.",
391
- )
392
- parser.add_argument(
393
- "--horovod",
394
- default=False,
395
- action="store_true",
396
- help="Use horovod for distributed training.",
397
- )
398
- parser.add_argument(
399
- "--ddp-static-graph",
400
- default=False,
401
- action="store_true",
402
- help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
403
- )
404
- parser.add_argument(
405
- "--no-set-device-rank",
406
- default=False,
407
- action="store_true",
408
- help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
409
- )
410
- parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
411
-
412
- parser.add_argument(
413
- "--top-k-checkpoint-select-dataset",
414
- type=str,
415
- default="all",
416
- help="The dataset of selecting top-k checkpoint.",
417
- )
418
-
419
- # @R10, @R@5, @R1, mAP@10
420
- parser.add_argument(
421
- "--top-k-checkpoint-select-metric",
422
- type=str,
423
- default="_R@10",
424
- help="The metric for selecting top-k checkpoint.",
425
- )
426
- parser.add_argument(
427
- "--openai-model-cache-dir",
428
- type=str,
429
- default="~/.cache/clip",
430
- help="Directory to download OpenAI models.",
431
- )
432
- parser.add_argument(
433
- "--optimizer",
434
- type=str,
435
- default="adamw",
436
- help="can be AdamW or SGD",
437
- )
438
- parser.add_argument(
439
- "--parallel-eval",
440
- default=False,
441
- action="store_true",
442
- help="Eval in parallel (multi-GPU, multi-node).",
443
- )
444
-
445
- parser.add_argument(
446
- "--no-eval",
447
- default=False,
448
- action="store_true",
449
- help="Training without evaluation.",
450
- )
451
-
452
- parser.add_argument(
453
- "--lp-mlp",
454
- default=False,
455
- action="store_true",
456
- help="Linear Probe using MLP layer or not.",
457
- )
458
-
459
- parser.add_argument(
460
- "--lp-freeze",
461
- default=False,
462
- action="store_true",
463
- help="Linear Probe using Freeze CLAP or not",
464
- )
465
-
466
- parser.add_argument(
467
- "--lp-act",
468
- default="None",
469
- type=str,
470
- help="Options are ['relu','elu','prelu','softmax','sigmoid']",
471
- )
472
-
473
- parser.add_argument(
474
- "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
475
- )
476
-
477
- parser.add_argument(
478
- "--lp-metrics",
479
- type=str,
480
- default="map,mauc,acc",
481
- help="Metrics of Linear Probe.",
482
- )
483
-
484
- parser.add_argument(
485
- "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
486
- )
487
- parser.add_argument(
488
- "--kappa",
489
- type=float,
490
- default=0,
491
- help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
492
- )
493
-
494
- parser.add_argument(
495
- "--data-filling",
496
- type=str,
497
- default="pad",
498
- help="type of data filling when the audio length is shorter than the max length."
499
- "Can be one of the following: repeat, repeatpad, pad",
500
- )
501
- parser.add_argument(
502
- "--data-truncating",
503
- type=str,
504
- default="rand_trunc",
505
- help="type of data truncation when the audio length is longer than the max length."
506
- "Can be one of the following: rand_trunc, fusion",
507
- )
508
-
509
- parser.add_argument(
510
- "--clap-mlploss",
511
- default=False,
512
- action="store_true",
513
- help="Using MLP loss for CLAP model or not",
514
- )
515
-
516
- parser.add_argument(
517
- "--wandb-id",
518
- type=str,
519
- default=None,
520
- help="the id of wandb experiment to restore.",
521
- )
522
-
523
- parser.add_argument(
524
- "--sleep", type=float, default=0, help="sleep n seconds before start training"
525
- )
526
-
527
- # variable length processing
528
- parser.add_argument(
529
- "--enable-fusion",
530
- default=False,
531
- action="store_true",
532
- help="Enable feature funsion for variable-length data",
533
- )
534
-
535
- parser.add_argument(
536
- "--fusion-type",
537
- type=str,
538
- default="None",
539
- help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
540
- )
541
-
542
- parser.add_argument(
543
- "--mixup",
544
- default=False,
545
- action="store_true",
546
- help="Enable mixup in finetuning training.",
547
- )
548
- parser.add_argument(
549
- "--text-augment-selection",
550
- type=str,
551
- default=None,
552
- help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
553
- )
554
-
555
- args = parser.parse_args()
556
-
557
- # If some params are not passed, we use the default values based on model name.
558
- default_params = get_default_params(args.amodel)
559
- for name, val in default_params.items():
560
- if getattr(args, name) is None:
561
- setattr(args, name, val)
562
-
563
- return args
 
1
+ import argparse
2
+
3
+
4
+ def get_default_params(model_name):
5
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
6
+ model_name = model_name.lower()
7
+ if "vit" in model_name:
8
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
9
+ else:
10
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
11
+
12
+
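Illustrative only: the defaults are keyed on whether the (audio) model name contains "vit".

from audiosr.clap.training.params import get_default_params

print(get_default_params("HTSAT-tiny"))
# {'lr': 0.0005, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
print(get_default_params("ViT-B-32"))
# {'lr': 0.0005, 'beta1': 0.9, 'beta2': 0.98, 'eps': 1e-06}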
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--train-data",
17
+ type=str,
18
+ default=None,
19
+ help="Path to h5 filewith training data",
20
+ )
21
+ parser.add_argument(
22
+ "--val-data",
23
+ type=str,
24
+ default=None,
25
+ help="Path to h5 file with validation data",
26
+ )
27
+ parser.add_argument(
28
+ "--freeze-text",
29
+ default=False,
30
+ action="store_true",
31
+ help="if you need to freeze the text encoder, make this True",
32
+ )
33
+ parser.add_argument(
34
+ "--freeze-text-after",
35
+ type=int,
36
+ default=-1,
37
+ help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
38
+ )
39
+ parser.add_argument(
40
+ "--train-ipc",
41
+ type=str,
42
+ default=None,
43
+ help="Path to npy file of the number of instance per class in training data",
44
+ )
45
+ parser.add_argument(
46
+ "--val-ipc",
47
+ type=str,
48
+ default=None,
49
+ help="Path to npy file of the number of instance per class in validation data",
50
+ )
51
+ parser.add_argument(
52
+ "--train-num-samples",
53
+ type=int,
54
+ default=None,
55
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
56
+ )
57
+ parser.add_argument(
58
+ "--val-num-samples",
59
+ type=int,
60
+ default=None,
61
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
62
+ )
63
+ parser.add_argument(
64
+ "--dataset-type",
65
+ choices=["webdataset", "csv", "auto", "toy"],
66
+ default="auto",
67
+ help="Which type of dataset to process.",
68
+ )
69
+ parser.add_argument(
70
+ "--csv-separator",
71
+ type=str,
72
+ default="\t",
73
+ help="For csv-like datasets, which separator to use.",
74
+ )
75
+ parser.add_argument(
76
+ "--csv-img-key",
77
+ type=str,
78
+ default="filepath",
79
+ help="For csv-like datasets, the name of the key for the image paths.",
80
+ )
81
+ parser.add_argument(
82
+ "--csv-caption-key",
83
+ type=str,
84
+ default="title",
85
+ help="For csv-like datasets, the name of the key for the captions.",
86
+ )
87
+ parser.add_argument(
88
+ "--imagenet-val",
89
+ type=str,
90
+ default=None,
91
+ help="Path to imagenet val set for conducting zero shot evaluation.",
92
+ )
93
+ parser.add_argument(
94
+ "--imagenet-v2",
95
+ type=str,
96
+ default=None,
97
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
98
+ )
99
+ parser.add_argument(
100
+ "--datasetnames",
101
+ nargs="+",
102
+ default=None,
103
+ help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
104
+ )
105
+ parser.add_argument(
106
+ "--full-train-dataset",
107
+ nargs="+",
108
+ default=None,
109
+ help="Which dataset will be trained with all the subsets. (train+test)",
110
+ )
111
+ parser.add_argument(
112
+ "--exclude-eval-dataset",
113
+ nargs="+",
114
+ default=None,
115
+ help="Which dataset will be excluded with evaluation",
116
+ )
117
+ parser.add_argument(
118
+ "--datasetinfos",
119
+ nargs="+",
120
+ default=None,
121
+ help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
122
+ )
123
+ parser.add_argument(
124
+ "--dataset-proportion",
125
+ type=float,
126
+ default=1.0,
127
+ help="How much proportion of dataset we want to train.",
128
+ )
129
+ parser.add_argument(
130
+ "--remotedata",
131
+ default=False,
132
+ action="store_true",
133
+ help="if the dataset is remote, set this flag",
134
+ )
135
+ parser.add_argument(
136
+ "--class-label-path",
137
+ type=str,
138
+ default=None,
139
+ help="The path of the class label pickle or csv.",
140
+ )
141
+ parser.add_argument(
142
+ "--datasetpath",
143
+ type=str,
144
+ default="/mnt/audio_clip/webdataset_tar",
145
+ help="The path to the dataset",
146
+ )
147
+ parser.add_argument(
148
+ "--logs",
149
+ type=str,
150
+ default="./logs/",
151
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
152
+ )
153
+ parser.add_argument(
154
+ "--log-local",
155
+ action="store_true",
156
+ default=False,
157
+ help="log files on local master, otherwise global master only.",
158
+ )
159
+ parser.add_argument(
160
+ "--name",
161
+ type=str,
162
+ default=None,
163
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
164
+ )
165
+ parser.add_argument(
166
+ "--workers", type=int, default=1, help="Number of workers per GPU."
167
+ )
168
+ parser.add_argument(
169
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
170
+ )
171
+ parser.add_argument(
172
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
173
+ )
174
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
175
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
176
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
177
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
178
+ parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
179
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
180
+
181
+ parser.add_argument(
182
+ "--split-opt",
183
+ action="store_true",
184
+ default=False,
185
+ help="Use this flag to skip the learning rate decay.",
186
+ )
187
+ parser.add_argument(
188
+ "--lr-pretrained", type=float, default=None, help="Learning rate for text."
189
+ )
190
+ parser.add_argument(
191
+ "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
192
+ )
193
+ parser.add_argument(
194
+ "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
195
+ )
196
+ parser.add_argument(
197
+ "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
198
+ )
199
+ parser.add_argument(
200
+ "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
201
+ )
202
+ parser.add_argument(
203
+ "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
204
+ )
205
+ parser.add_argument(
206
+ "--lr-new", type=float, default=None, help="Learning rate for audio."
207
+ )
208
+ parser.add_argument(
209
+ "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
210
+ )
211
+ parser.add_argument(
212
+ "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
213
+ )
214
+ parser.add_argument(
215
+ "--eps-new", type=float, default=None, help="Adam epsilon for audio."
216
+ )
217
+ parser.add_argument(
218
+ "--wd-new", type=float, default=0.2, help="Weight decay for audio."
219
+ )
220
+ parser.add_argument(
221
+ "--momentum-new", type=float, default=0.9, help="Momentum for audio."
222
+ )
223
+ parser.add_argument(
224
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
225
+ )
226
+ parser.add_argument(
227
+ "--use-bn-sync",
228
+ default=False,
229
+ action="store_true",
230
+ help="Whether to use batch norm sync.",
231
+ )
232
+ parser.add_argument(
233
+ "--skip-scheduler",
234
+ action="store_true",
235
+ default=False,
236
+ help="Use this flag to skip the learning rate decay.",
237
+ )
238
+ parser.add_argument(
239
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
240
+ )
241
+ parser.add_argument(
242
+ "--save-top-performance",
243
+ type=int,
244
+ default=0,
245
+ help="Save the top x performance weights if the value >0",
246
+ )
247
+ parser.add_argument(
248
+ "--save-most-recent",
249
+ action="store_true",
250
+ default=False,
251
+ help="Always save the most recent model trained to epoch_latest.pt.",
252
+ )
253
+ parser.add_argument(
254
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
255
+ )
256
+ parser.add_argument(
257
+ "--val-frequency",
258
+ type=int,
259
+ default=1,
260
+ help="How often to run evaluation with val data.",
261
+ )
262
+ parser.add_argument(
263
+ "--resume",
264
+ default=None,
265
+ type=str,
266
+ help="path to latest checkpoint (default: none)",
267
+ )
268
+ parser.add_argument(
269
+ "--precision",
270
+ choices=["amp", "fp16", "fp32"],
271
+ default="amp",
272
+ help="Floating point precision.",
273
+ )
274
+ parser.add_argument(
275
+ "--amodel",
276
+ type=str,
277
+ default="RN50",
278
+ help="Name of the audio backbone to use.",
279
+ )
280
+ parser.add_argument(
281
+ "--tmodel",
282
+ type=str,
283
+ default="transformer",
284
+ help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
285
+ )
286
+ parser.add_argument(
287
+ "--pretrained-audio",
288
+ default="",
289
+ type=str,
290
+ help="Use a pretrained audio model weights for the audio encoder of CLAP",
291
+ )
292
+ parser.add_argument(
293
+ "--pretrained-text",
294
+ default="",
295
+ type=str,
296
+ help="Use a pretrained text model weights for the text encoder of CLAP",
297
+ )
298
+ parser.add_argument(
299
+ "--pretrained",
300
+ default="",
301
+ type=str,
302
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
303
+ )
304
+ parser.add_argument(
305
+ "--pretrained-image",
306
+ default=False,
307
+ action="store_true",
308
+ help="Load imagenet pretrained weights for image tower backbone if available.",
309
+ )
310
+ parser.add_argument(
311
+ "--lock-image",
312
+ default=False,
313
+ action="store_true",
314
+ help="Lock full image tower by disabling gradients.",
315
+ )
316
+ parser.add_argument(
317
+ "--lock-image-unlocked-groups",
318
+ type=int,
319
+ default=0,
320
+ help="Leave last n image tower layer groups unlocked.",
321
+ )
322
+ parser.add_argument(
323
+ "--lock-image-freeze-bn-stats",
324
+ default=False,
325
+ action="store_true",
326
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
327
+ )
328
+ parser.add_argument(
329
+ "--local-loss",
330
+ default=False,
331
+ action="store_true",
332
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
333
+ )
334
+ parser.add_argument(
335
+ "--gather-with-grad",
336
+ default=False,
337
+ action="store_true",
338
+ help="enable full distributed gradient for feature gather",
339
+ )
340
+ parser.add_argument(
341
+ "--force-quick-gelu",
342
+ default=False,
343
+ action="store_true",
344
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
345
+ )
346
+ parser.add_argument(
347
+ "--torchscript",
348
+ default=False,
349
+ action="store_true",
350
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
351
+ )
352
+ parser.add_argument(
353
+ "--trace",
354
+ default=False,
355
+ action="store_true",
356
+ help="torch.jit.trace the model for inference / eval only",
357
+ )
358
+ # arguments for distributed training
359
+ parser.add_argument(
360
+ "--dist-url",
361
+ default="env://",
362
+ type=str,
363
+ help="url used to set up distributed training",
364
+ )
365
+ parser.add_argument(
366
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
367
+ )
368
+ parser.add_argument(
369
+ "--report-to",
370
+ default="",
371
+ type=str,
372
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
373
+ )
374
+ parser.add_argument(
375
+ "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
376
+ )
377
+ parser.add_argument(
378
+ "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
379
+ )
380
+ parser.add_argument(
381
+ "--debug",
382
+ default=False,
383
+ action="store_true",
384
+ help="If true, more information is logged.",
385
+ )
386
+ parser.add_argument(
387
+ "--copy-codebase",
388
+ default=False,
389
+ action="store_true",
390
+ help="If true, we copy the entire base on the log diretory, and execute from there.",
391
+ )
392
+ parser.add_argument(
393
+ "--horovod",
394
+ default=False,
395
+ action="store_true",
396
+ help="Use horovod for distributed training.",
397
+ )
398
+ parser.add_argument(
399
+ "--ddp-static-graph",
400
+ default=False,
401
+ action="store_true",
402
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
403
+ )
404
+ parser.add_argument(
405
+ "--no-set-device-rank",
406
+ default=False,
407
+ action="store_true",
408
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
409
+ )
410
+ parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
411
+
412
+ parser.add_argument(
413
+ "--top-k-checkpoint-select-dataset",
414
+ type=str,
415
+ default="all",
416
+ help="The dataset of selecting top-k checkpoint.",
417
+ )
418
+
419
+ # R@10, R@5, R@1, mAP@10
420
+ parser.add_argument(
421
+ "--top-k-checkpoint-select-metric",
422
+ type=str,
423
+ default="_R@10",
424
+ help="The metric for selecting top-k checkpoint.",
425
+ )
426
+ parser.add_argument(
427
+ "--openai-model-cache-dir",
428
+ type=str,
429
+ default="~/.cache/clip",
430
+ help="Directory to download OpenAI models.",
431
+ )
432
+ parser.add_argument(
433
+ "--optimizer",
434
+ type=str,
435
+ default="adamw",
436
+ help="can be AdamW or SGD",
437
+ )
438
+ parser.add_argument(
439
+ "--parallel-eval",
440
+ default=False,
441
+ action="store_true",
442
+ help="Eval in parallel (multi-GPU, multi-node).",
443
+ )
444
+
445
+ parser.add_argument(
446
+ "--no-eval",
447
+ default=False,
448
+ action="store_true",
449
+ help="Training without evaluation.",
450
+ )
451
+
452
+ parser.add_argument(
453
+ "--lp-mlp",
454
+ default=False,
455
+ action="store_true",
456
+ help="Linear Probe using MLP layer or not.",
457
+ )
458
+
459
+ parser.add_argument(
460
+ "--lp-freeze",
461
+ default=False,
462
+ action="store_true",
463
+ help="Linear Probe using Freeze CLAP or not",
464
+ )
465
+
466
+ parser.add_argument(
467
+ "--lp-act",
468
+ default="None",
469
+ type=str,
470
+ help="Options are ['relu','elu','prelu','softmax','sigmoid']",
471
+ )
472
+
473
+ parser.add_argument(
474
+ "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
475
+ )
476
+
477
+ parser.add_argument(
478
+ "--lp-metrics",
479
+ type=str,
480
+ default="map,mauc,acc",
481
+ help="Metrics of Linear Probe.",
482
+ )
483
+
484
+ parser.add_argument(
485
+ "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
486
+ )
487
+ parser.add_argument(
488
+ "--kappa",
489
+ type=float,
490
+ default=0,
491
+ help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
492
+ )
493
+
494
+ parser.add_argument(
495
+ "--data-filling",
496
+ type=str,
497
+ default="pad",
498
+ help="type of data filling when the audio length is shorter than the max length."
499
+ "Can be one of the following: repeat, repeatpad, pad",
500
+ )
501
+ parser.add_argument(
502
+ "--data-truncating",
503
+ type=str,
504
+ default="rand_trunc",
505
+ help="type of data truncation when the audio length is longer than the max length."
506
+ "Can be one of the following: rand_trunc, fusion",
507
+ )
508
+
509
+ parser.add_argument(
510
+ "--clap-mlploss",
511
+ default=False,
512
+ action="store_true",
513
+ help="Using MLP loss for CLAP model or not",
514
+ )
515
+
516
+ parser.add_argument(
517
+ "--wandb-id",
518
+ type=str,
519
+ default=None,
520
+ help="the id of wandb experiment to restore.",
521
+ )
522
+
523
+ parser.add_argument(
524
+ "--sleep", type=float, default=0, help="sleep n seconds before start training"
525
+ )
526
+
527
+ # variable length processing
528
+ parser.add_argument(
529
+ "--enable-fusion",
530
+ default=False,
531
+ action="store_true",
532
+ help="Enable feature funsion for variable-length data",
533
+ )
534
+
535
+ parser.add_argument(
536
+ "--fusion-type",
537
+ type=str,
538
+ default="None",
539
+ help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
540
+ )
541
+
542
+ parser.add_argument(
543
+ "--mixup",
544
+ default=False,
545
+ action="store_true",
546
+ help="Enable mixup in finetuning training.",
547
+ )
548
+ parser.add_argument(
549
+ "--text-augment-selection",
550
+ type=str,
551
+ default=None,
552
+ help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
553
+ )
554
+
555
+ args = parser.parse_args()
556
+
557
+ # If some params are not passed, we use the default values based on model name.
558
+ default_params = get_default_params(args.amodel)
559
+ for name, val in default_params.items():
560
+ if getattr(args, name) is None:
561
+ setattr(args, name, val)
562
+
563
+ return args
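Note on the block above: the tail of parse_args() fills any optimizer hyperparameter left at None with model-dependent defaults before the namespace is returned. A minimal standalone sketch of that fill-in pattern (the get_default_params values here are illustrative placeholders keyed off the audio model name, not the module's actual defaults):

    import argparse

    def get_default_params(model_name):
        # Placeholder defaults for illustration only; the real function is model-specific.
        if "vit" in model_name.lower():
            return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}

    parser = argparse.ArgumentParser()
    parser.add_argument("--amodel", type=str, default="RN50")
    for flag in ("--lr", "--beta1", "--beta2", "--eps"):
        parser.add_argument(flag, type=float, default=None)
    args = parser.parse_args([])  # empty argv just for the sketch

    # Same pattern as the end of parse_args() above: only overwrite values still at None.
    for name, val in get_default_params(args.amodel).items():
        if getattr(args, name) is None:
            setattr(args, name, val)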
audiosr/hifigan/LICENSE CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2020 Jungil Kong
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
  SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
  SOFTWARE.
audiosr/hifigan/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- from .models_v2 import Generator
2
- from .models import Generator as Generator_old
3
-
4
-
5
- class AttrDict(dict):
6
- def __init__(self, *args, **kwargs):
7
- super(AttrDict, self).__init__(*args, **kwargs)
8
- self.__dict__ = self
 
1
+ from .models_v2 import Generator
2
+ from .models import Generator as Generator_old
3
+
4
+
5
+ class AttrDict(dict):
6
+ def __init__(self, *args, **kwargs):
7
+ super(AttrDict, self).__init__(*args, **kwargs)
8
+ self.__dict__ = self
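AttrDict above is simply a dict whose keys are also exposed as attributes, which is how the HiFi-GAN generators below read their hyperparameters (h.num_mels, h.upsample_rates, ...). A quick standalone sketch of the behavior it provides (the config keys and values are made up for illustration):

    # Minimal restatement of the AttrDict pattern above, for illustration only.
    class AttrDict(dict):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.__dict__ = self

    h = AttrDict({"num_mels": 80, "upsample_initial_channel": 512})
    assert h.num_mels == h["num_mels"] == 80   # keys double as attributes
    h.hop_size = 480                           # attribute writes land in the dict too
    assert h["hop_size"] == 480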
audiosr/hifigan/models.py CHANGED
@@ -1,174 +1,174 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.nn import Conv1d, ConvTranspose1d
5
- from torch.nn.utils import weight_norm, remove_weight_norm
6
-
7
- LRELU_SLOPE = 0.1
8
-
9
-
10
- def init_weights(m, mean=0.0, std=0.01):
11
- classname = m.__class__.__name__
12
- if classname.find("Conv") != -1:
13
- m.weight.data.normal_(mean, std)
14
-
15
-
16
- def get_padding(kernel_size, dilation=1):
17
- return int((kernel_size * dilation - dilation) / 2)
18
-
19
-
20
- class ResBlock(torch.nn.Module):
21
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
- super(ResBlock, self).__init__()
23
- self.h = h
24
- self.convs1 = nn.ModuleList(
25
- [
26
- weight_norm(
27
- Conv1d(
28
- channels,
29
- channels,
30
- kernel_size,
31
- 1,
32
- dilation=dilation[0],
33
- padding=get_padding(kernel_size, dilation[0]),
34
- )
35
- ),
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[1],
43
- padding=get_padding(kernel_size, dilation[1]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[2],
53
- padding=get_padding(kernel_size, dilation[2]),
54
- )
55
- ),
56
- ]
57
- )
58
- self.convs1.apply(init_weights)
59
-
60
- self.convs2 = nn.ModuleList(
61
- [
62
- weight_norm(
63
- Conv1d(
64
- channels,
65
- channels,
66
- kernel_size,
67
- 1,
68
- dilation=1,
69
- padding=get_padding(kernel_size, 1),
70
- )
71
- ),
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- ]
93
- )
94
- self.convs2.apply(init_weights)
95
-
96
- def forward(self, x):
97
- for c1, c2 in zip(self.convs1, self.convs2):
98
- xt = F.leaky_relu(x, LRELU_SLOPE)
99
- xt = c1(xt)
100
- xt = F.leaky_relu(xt, LRELU_SLOPE)
101
- xt = c2(xt)
102
- x = xt + x
103
- return x
104
-
105
- def remove_weight_norm(self):
106
- for l in self.convs1:
107
- remove_weight_norm(l)
108
- for l in self.convs2:
109
- remove_weight_norm(l)
110
-
111
-
112
- class Generator(torch.nn.Module):
113
- def __init__(self, h):
114
- super(Generator, self).__init__()
115
- self.h = h
116
- self.num_kernels = len(h.resblock_kernel_sizes)
117
- self.num_upsamples = len(h.upsample_rates)
118
- self.conv_pre = weight_norm(
119
- Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
- )
121
- resblock = ResBlock
122
-
123
- self.ups = nn.ModuleList()
124
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
- self.ups.append(
126
- weight_norm(
127
- ConvTranspose1d(
128
- h.upsample_initial_channel // (2**i),
129
- h.upsample_initial_channel // (2 ** (i + 1)),
130
- k,
131
- u,
132
- padding=(k - u) // 2,
133
- )
134
- )
135
- )
136
-
137
- self.resblocks = nn.ModuleList()
138
- for i in range(len(self.ups)):
139
- ch = h.upsample_initial_channel // (2 ** (i + 1))
140
- for j, (k, d) in enumerate(
141
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
- ):
143
- self.resblocks.append(resblock(h, ch, k, d))
144
-
145
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
- self.ups.apply(init_weights)
147
- self.conv_post.apply(init_weights)
148
-
149
- def forward(self, x):
150
- x = self.conv_pre(x)
151
- for i in range(self.num_upsamples):
152
- x = F.leaky_relu(x, LRELU_SLOPE)
153
- x = self.ups[i](x)
154
- xs = None
155
- for j in range(self.num_kernels):
156
- if xs is None:
157
- xs = self.resblocks[i * self.num_kernels + j](x)
158
- else:
159
- xs += self.resblocks[i * self.num_kernels + j](x)
160
- x = xs / self.num_kernels
161
- x = F.leaky_relu(x)
162
- x = self.conv_post(x)
163
- x = torch.tanh(x)
164
-
165
- return x
166
-
167
- def remove_weight_norm(self):
168
- # print("Removing weight norm...")
169
- for l in self.ups:
170
- remove_weight_norm(l)
171
- for l in self.resblocks:
172
- l.remove_weight_norm()
173
- remove_weight_norm(self.conv_pre)
174
- remove_weight_norm(self.conv_post)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class Generator(torch.nn.Module):
113
+ def __init__(self, h):
114
+ super(Generator, self).__init__()
115
+ self.h = h
116
+ self.num_kernels = len(h.resblock_kernel_sizes)
117
+ self.num_upsamples = len(h.upsample_rates)
118
+ self.conv_pre = weight_norm(
119
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
+ )
121
+ resblock = ResBlock
122
+
123
+ self.ups = nn.ModuleList()
124
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
+ self.ups.append(
126
+ weight_norm(
127
+ ConvTranspose1d(
128
+ h.upsample_initial_channel // (2**i),
129
+ h.upsample_initial_channel // (2 ** (i + 1)),
130
+ k,
131
+ u,
132
+ padding=(k - u) // 2,
133
+ )
134
+ )
135
+ )
136
+
137
+ self.resblocks = nn.ModuleList()
138
+ for i in range(len(self.ups)):
139
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
140
+ for j, (k, d) in enumerate(
141
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
+ ):
143
+ self.resblocks.append(resblock(h, ch, k, d))
144
+
145
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
+ self.ups.apply(init_weights)
147
+ self.conv_post.apply(init_weights)
148
+
149
+ def forward(self, x):
150
+ x = self.conv_pre(x)
151
+ for i in range(self.num_upsamples):
152
+ x = F.leaky_relu(x, LRELU_SLOPE)
153
+ x = self.ups[i](x)
154
+ xs = None
155
+ for j in range(self.num_kernels):
156
+ if xs is None:
157
+ xs = self.resblocks[i * self.num_kernels + j](x)
158
+ else:
159
+ xs += self.resblocks[i * self.num_kernels + j](x)
160
+ x = xs / self.num_kernels
161
+ x = F.leaky_relu(x)
162
+ x = self.conv_post(x)
163
+ x = torch.tanh(x)
164
+
165
+ return x
166
+
167
+ def remove_weight_norm(self):
168
+ # print("Removing weight norm...")
169
+ for l in self.ups:
170
+ remove_weight_norm(l)
171
+ for l in self.resblocks:
172
+ l.remove_weight_norm()
173
+ remove_weight_norm(self.conv_pre)
174
+ remove_weight_norm(self.conv_post)
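The Generator above is a standard HiFi-GAN vocoder: a mel spectrogram passes through conv_pre, a stack of transposed-convolution upsamplers interleaved with multi-receptive-field ResBlocks, and a tanh-bounded conv_post, so the output waveform length is the input frame count times the product of upsample_rates. A hedged usage sketch, with illustrative hyperparameters in the spirit of the published HiFi-GAN configs (the config AudioSR actually ships may differ):

    import torch
    from audiosr.hifigan import AttrDict, Generator_old  # Generator from models.py above

    h = AttrDict({
        "num_mels": 80,
        "upsample_initial_channel": 512,
        "upsample_rates": [8, 8, 2, 2],
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    })

    vocoder = Generator_old(h).eval()
    mel = torch.randn(1, h.num_mels, 100)   # (batch, mel bins, frames)
    with torch.no_grad():
        wav = vocoder(mel)                  # (1, 1, 100 * 8 * 8 * 2 * 2) = (1, 1, 25600)
    vocoder.remove_weight_norm()            # fold weight norm in before inference/deployment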
audiosr/hifigan/models_v2.py CHANGED
@@ -1,395 +1,395 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import torch.nn as nn
4
- from torch.nn import Conv1d, ConvTranspose1d
5
- from torch.nn.utils import weight_norm, remove_weight_norm
6
-
7
- LRELU_SLOPE = 0.1
8
-
9
-
10
- def init_weights(m, mean=0.0, std=0.01):
11
- classname = m.__class__.__name__
12
- if classname.find("Conv") != -1:
13
- m.weight.data.normal_(mean, std)
14
-
15
-
16
- def get_padding(kernel_size, dilation=1):
17
- return int((kernel_size * dilation - dilation) / 2)
18
-
19
-
20
- class ResBlock1(torch.nn.Module):
21
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
- super(ResBlock1, self).__init__()
23
- self.h = h
24
- self.convs1 = nn.ModuleList(
25
- [
26
- weight_norm(
27
- Conv1d(
28
- channels,
29
- channels,
30
- kernel_size,
31
- 1,
32
- dilation=dilation[0],
33
- padding=get_padding(kernel_size, dilation[0]),
34
- )
35
- ),
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[1],
43
- padding=get_padding(kernel_size, dilation[1]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[2],
53
- padding=get_padding(kernel_size, dilation[2]),
54
- )
55
- ),
56
- ]
57
- )
58
- self.convs1.apply(init_weights)
59
-
60
- self.convs2 = nn.ModuleList(
61
- [
62
- weight_norm(
63
- Conv1d(
64
- channels,
65
- channels,
66
- kernel_size,
67
- 1,
68
- dilation=1,
69
- padding=get_padding(kernel_size, 1),
70
- )
71
- ),
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- ]
93
- )
94
- self.convs2.apply(init_weights)
95
-
96
- def forward(self, x):
97
- for c1, c2 in zip(self.convs1, self.convs2):
98
- xt = F.leaky_relu(x, LRELU_SLOPE)
99
- xt = c1(xt)
100
- xt = F.leaky_relu(xt, LRELU_SLOPE)
101
- xt = c2(xt)
102
- x = xt + x
103
- return x
104
-
105
- def remove_weight_norm(self):
106
- for l in self.convs1:
107
- remove_weight_norm(l)
108
- for l in self.convs2:
109
- remove_weight_norm(l)
110
-
111
-
112
- class ResBlock2(torch.nn.Module):
113
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
114
- super(ResBlock2, self).__init__()
115
- self.h = h
116
- self.convs = nn.ModuleList(
117
- [
118
- weight_norm(
119
- Conv1d(
120
- channels,
121
- channels,
122
- kernel_size,
123
- 1,
124
- dilation=dilation[0],
125
- padding=get_padding(kernel_size, dilation[0]),
126
- )
127
- ),
128
- weight_norm(
129
- Conv1d(
130
- channels,
131
- channels,
132
- kernel_size,
133
- 1,
134
- dilation=dilation[1],
135
- padding=get_padding(kernel_size, dilation[1]),
136
- )
137
- ),
138
- ]
139
- )
140
- self.convs.apply(init_weights)
141
-
142
- def forward(self, x):
143
- for c in self.convs:
144
- xt = F.leaky_relu(x, LRELU_SLOPE)
145
- xt = c(xt)
146
- x = xt + x
147
- return x
148
-
149
- def remove_weight_norm(self):
150
- for l in self.convs:
151
- remove_weight_norm(l)
152
-
153
-
154
- class Generator(torch.nn.Module):
155
- def __init__(self, h):
156
- super(Generator, self).__init__()
157
- self.h = h
158
- self.num_kernels = len(h.resblock_kernel_sizes)
159
- self.num_upsamples = len(h.upsample_rates)
160
- self.conv_pre = weight_norm(
161
- Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
162
- )
163
- resblock = ResBlock1 if h.resblock == "1" else ResBlock2
164
-
165
- self.ups = nn.ModuleList()
166
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
167
- self.ups.append(
168
- weight_norm(
169
- ConvTranspose1d(
170
- h.upsample_initial_channel // (2**i),
171
- h.upsample_initial_channel // (2 ** (i + 1)),
172
- u * 2,
173
- u,
174
- padding=u // 2 + u % 2,
175
- output_padding=u % 2,
176
- )
177
- )
178
- )
179
-
180
- self.resblocks = nn.ModuleList()
181
- for i in range(len(self.ups)):
182
- ch = h.upsample_initial_channel // (2 ** (i + 1))
183
- for j, (k, d) in enumerate(
184
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
185
- ):
186
- self.resblocks.append(resblock(h, ch, k, d))
187
-
188
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
189
- self.ups.apply(init_weights)
190
- self.conv_post.apply(init_weights)
191
-
192
- def forward(self, x):
193
- # import ipdb; ipdb.set_trace()
194
- x = self.conv_pre(x)
195
- for i in range(self.num_upsamples):
196
- x = F.leaky_relu(x, LRELU_SLOPE)
197
- x = self.ups[i](x)
198
- xs = None
199
- for j in range(self.num_kernels):
200
- if xs is None:
201
- xs = self.resblocks[i * self.num_kernels + j](x)
202
- else:
203
- xs += self.resblocks[i * self.num_kernels + j](x)
204
- x = xs / self.num_kernels
205
- x = F.leaky_relu(x)
206
- x = self.conv_post(x)
207
- x = torch.tanh(x)
208
-
209
- return x
210
-
211
- def remove_weight_norm(self):
212
- # print('Removing weight norm...')
213
- for l in self.ups:
214
- remove_weight_norm(l)
215
- for l in self.resblocks:
216
- l.remove_weight_norm()
217
- remove_weight_norm(self.conv_pre)
218
- remove_weight_norm(self.conv_post)
219
-
220
-
221
- ##################################################################################################
222
-
223
- # import torch
224
- # import torch.nn as nn
225
- # import torch.nn.functional as F
226
- # from torch.nn import Conv1d, ConvTranspose1d
227
- # from torch.nn.utils import weight_norm, remove_weight_norm
228
-
229
- # LRELU_SLOPE = 0.1
230
-
231
-
232
- # def init_weights(m, mean=0.0, std=0.01):
233
- # classname = m.__class__.__name__
234
- # if classname.find("Conv") != -1:
235
- # m.weight.data.normal_(mean, std)
236
-
237
-
238
- # def get_padding(kernel_size, dilation=1):
239
- # return int((kernel_size * dilation - dilation) / 2)
240
-
241
-
242
- # class ResBlock(torch.nn.Module):
243
- # def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
244
- # super(ResBlock, self).__init__()
245
- # self.h = h
246
- # self.convs1 = nn.ModuleList(
247
- # [
248
- # weight_norm(
249
- # Conv1d(
250
- # channels,
251
- # channels,
252
- # kernel_size,
253
- # 1,
254
- # dilation=dilation[0],
255
- # padding=get_padding(kernel_size, dilation[0]),
256
- # )
257
- # ),
258
- # weight_norm(
259
- # Conv1d(
260
- # channels,
261
- # channels,
262
- # kernel_size,
263
- # 1,
264
- # dilation=dilation[1],
265
- # padding=get_padding(kernel_size, dilation[1]),
266
- # )
267
- # ),
268
- # weight_norm(
269
- # Conv1d(
270
- # channels,
271
- # channels,
272
- # kernel_size,
273
- # 1,
274
- # dilation=dilation[2],
275
- # padding=get_padding(kernel_size, dilation[2]),
276
- # )
277
- # ),
278
- # ]
279
- # )
280
- # self.convs1.apply(init_weights)
281
-
282
- # self.convs2 = nn.ModuleList(
283
- # [
284
- # weight_norm(
285
- # Conv1d(
286
- # channels,
287
- # channels,
288
- # kernel_size,
289
- # 1,
290
- # dilation=1,
291
- # padding=get_padding(kernel_size, 1),
292
- # )
293
- # ),
294
- # weight_norm(
295
- # Conv1d(
296
- # channels,
297
- # channels,
298
- # kernel_size,
299
- # 1,
300
- # dilation=1,
301
- # padding=get_padding(kernel_size, 1),
302
- # )
303
- # ),
304
- # weight_norm(
305
- # Conv1d(
306
- # channels,
307
- # channels,
308
- # kernel_size,
309
- # 1,
310
- # dilation=1,
311
- # padding=get_padding(kernel_size, 1),
312
- # )
313
- # ),
314
- # ]
315
- # )
316
- # self.convs2.apply(init_weights)
317
-
318
- # def forward(self, x):
319
- # for c1, c2 in zip(self.convs1, self.convs2):
320
- # xt = F.leaky_relu(x, LRELU_SLOPE)
321
- # xt = c1(xt)
322
- # xt = F.leaky_relu(xt, LRELU_SLOPE)
323
- # xt = c2(xt)
324
- # x = xt + x
325
- # return x
326
-
327
- # def remove_weight_norm(self):
328
- # for l in self.convs1:
329
- # remove_weight_norm(l)
330
- # for l in self.convs2:
331
- # remove_weight_norm(l)
332
-
333
- # class Generator(torch.nn.Module):
334
- # def __init__(self, h):
335
- # super(Generator, self).__init__()
336
- # self.h = h
337
- # self.num_kernels = len(h.resblock_kernel_sizes)
338
- # self.num_upsamples = len(h.upsample_rates)
339
- # self.conv_pre = weight_norm(
340
- # Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
341
- # )
342
- # resblock = ResBlock
343
-
344
- # self.ups = nn.ModuleList()
345
- # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
346
- # self.ups.append(
347
- # weight_norm(
348
- # ConvTranspose1d(
349
- # h.upsample_initial_channel // (2**i),
350
- # h.upsample_initial_channel // (2 ** (i + 1)),
351
- # k,
352
- # u,
353
- # padding=(k - u) // 2,
354
- # )
355
- # )
356
- # )
357
-
358
- # self.resblocks = nn.ModuleList()
359
- # for i in range(len(self.ups)):
360
- # ch = h.upsample_initial_channel // (2 ** (i + 1))
361
- # for j, (k, d) in enumerate(
362
- # zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
363
- # ):
364
- # self.resblocks.append(resblock(h, ch, k, d))
365
-
366
- # self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
367
- # self.ups.apply(init_weights)
368
- # self.conv_post.apply(init_weights)
369
-
370
- # def forward(self, x):
371
- # x = self.conv_pre(x)
372
- # for i in range(self.num_upsamples):
373
- # x = F.leaky_relu(x, LRELU_SLOPE)
374
- # x = self.ups[i](x)
375
- # xs = None
376
- # for j in range(self.num_kernels):
377
- # if xs is None:
378
- # xs = self.resblocks[i * self.num_kernels + j](x)
379
- # else:
380
- # xs += self.resblocks[i * self.num_kernels + j](x)
381
- # x = xs / self.num_kernels
382
- # x = F.leaky_relu(x)
383
- # x = self.conv_post(x)
384
- # x = torch.tanh(x)
385
-
386
- # return x
387
-
388
- # def remove_weight_norm(self):
389
- # print("Removing weight norm...")
390
- # for l in self.ups:
391
- # remove_weight_norm(l)
392
- # for l in self.resblocks:
393
- # l.remove_weight_norm()
394
- # remove_weight_norm(self.conv_pre)
395
- # remove_weight_norm(self.conv_post)
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock1(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock1, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class ResBlock2(torch.nn.Module):
113
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
114
+ super(ResBlock2, self).__init__()
115
+ self.h = h
116
+ self.convs = nn.ModuleList(
117
+ [
118
+ weight_norm(
119
+ Conv1d(
120
+ channels,
121
+ channels,
122
+ kernel_size,
123
+ 1,
124
+ dilation=dilation[0],
125
+ padding=get_padding(kernel_size, dilation[0]),
126
+ )
127
+ ),
128
+ weight_norm(
129
+ Conv1d(
130
+ channels,
131
+ channels,
132
+ kernel_size,
133
+ 1,
134
+ dilation=dilation[1],
135
+ padding=get_padding(kernel_size, dilation[1]),
136
+ )
137
+ ),
138
+ ]
139
+ )
140
+ self.convs.apply(init_weights)
141
+
142
+ def forward(self, x):
143
+ for c in self.convs:
144
+ xt = F.leaky_relu(x, LRELU_SLOPE)
145
+ xt = c(xt)
146
+ x = xt + x
147
+ return x
148
+
149
+ def remove_weight_norm(self):
150
+ for l in self.convs:
151
+ remove_weight_norm(l)
152
+
153
+
154
+ class Generator(torch.nn.Module):
155
+ def __init__(self, h):
156
+ super(Generator, self).__init__()
157
+ self.h = h
158
+ self.num_kernels = len(h.resblock_kernel_sizes)
159
+ self.num_upsamples = len(h.upsample_rates)
160
+ self.conv_pre = weight_norm(
161
+ Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
162
+ )
163
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
164
+
165
+ self.ups = nn.ModuleList()
166
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
167
+ self.ups.append(
168
+ weight_norm(
169
+ ConvTranspose1d(
170
+ h.upsample_initial_channel // (2**i),
171
+ h.upsample_initial_channel // (2 ** (i + 1)),
172
+ u * 2,
173
+ u,
174
+ padding=u // 2 + u % 2,
175
+ output_padding=u % 2,
176
+ )
177
+ )
178
+ )
179
+
180
+ self.resblocks = nn.ModuleList()
181
+ for i in range(len(self.ups)):
182
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
183
+ for j, (k, d) in enumerate(
184
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
185
+ ):
186
+ self.resblocks.append(resblock(h, ch, k, d))
187
+
188
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
189
+ self.ups.apply(init_weights)
190
+ self.conv_post.apply(init_weights)
191
+
192
+ def forward(self, x):
193
+ # import ipdb; ipdb.set_trace()
194
+ x = self.conv_pre(x)
195
+ for i in range(self.num_upsamples):
196
+ x = F.leaky_relu(x, LRELU_SLOPE)
197
+ x = self.ups[i](x)
198
+ xs = None
199
+ for j in range(self.num_kernels):
200
+ if xs is None:
201
+ xs = self.resblocks[i * self.num_kernels + j](x)
202
+ else:
203
+ xs += self.resblocks[i * self.num_kernels + j](x)
204
+ x = xs / self.num_kernels
205
+ x = F.leaky_relu(x)
206
+ x = self.conv_post(x)
207
+ x = torch.tanh(x)
208
+
209
+ return x
210
+
211
+ def remove_weight_norm(self):
212
+ # print('Removing weight norm...')
213
+ for l in self.ups:
214
+ remove_weight_norm(l)
215
+ for l in self.resblocks:
216
+ l.remove_weight_norm()
217
+ remove_weight_norm(self.conv_pre)
218
+ remove_weight_norm(self.conv_post)
219
+
220
+
221
+ ##################################################################################################
222
+
223
+ # import torch
224
+ # import torch.nn as nn
225
+ # import torch.nn.functional as F
226
+ # from torch.nn import Conv1d, ConvTranspose1d
227
+ # from torch.nn.utils import weight_norm, remove_weight_norm
228
+
229
+ # LRELU_SLOPE = 0.1
230
+
231
+
232
+ # def init_weights(m, mean=0.0, std=0.01):
233
+ # classname = m.__class__.__name__
234
+ # if classname.find("Conv") != -1:
235
+ # m.weight.data.normal_(mean, std)
236
+
237
+
238
+ # def get_padding(kernel_size, dilation=1):
239
+ # return int((kernel_size * dilation - dilation) / 2)
240
+
241
+
242
+ # class ResBlock(torch.nn.Module):
243
+ # def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
244
+ # super(ResBlock, self).__init__()
245
+ # self.h = h
246
+ # self.convs1 = nn.ModuleList(
247
+ # [
248
+ # weight_norm(
249
+ # Conv1d(
250
+ # channels,
251
+ # channels,
252
+ # kernel_size,
253
+ # 1,
254
+ # dilation=dilation[0],
255
+ # padding=get_padding(kernel_size, dilation[0]),
256
+ # )
257
+ # ),
258
+ # weight_norm(
259
+ # Conv1d(
260
+ # channels,
261
+ # channels,
262
+ # kernel_size,
263
+ # 1,
264
+ # dilation=dilation[1],
265
+ # padding=get_padding(kernel_size, dilation[1]),
266
+ # )
267
+ # ),
268
+ # weight_norm(
269
+ # Conv1d(
270
+ # channels,
271
+ # channels,
272
+ # kernel_size,
273
+ # 1,
274
+ # dilation=dilation[2],
275
+ # padding=get_padding(kernel_size, dilation[2]),
276
+ # )
277
+ # ),
278
+ # ]
279
+ # )
280
+ # self.convs1.apply(init_weights)
281
+
282
+ # self.convs2 = nn.ModuleList(
283
+ # [
284
+ # weight_norm(
285
+ # Conv1d(
286
+ # channels,
287
+ # channels,
288
+ # kernel_size,
289
+ # 1,
290
+ # dilation=1,
291
+ # padding=get_padding(kernel_size, 1),
292
+ # )
293
+ # ),
294
+ # weight_norm(
295
+ # Conv1d(
296
+ # channels,
297
+ # channels,
298
+ # kernel_size,
299
+ # 1,
300
+ # dilation=1,
301
+ # padding=get_padding(kernel_size, 1),
302
+ # )
303
+ # ),
304
+ # weight_norm(
305
+ # Conv1d(
306
+ # channels,
307
+ # channels,
308
+ # kernel_size,
309
+ # 1,
310
+ # dilation=1,
311
+ # padding=get_padding(kernel_size, 1),
312
+ # )
313
+ # ),
314
+ # ]
315
+ # )
316
+ # self.convs2.apply(init_weights)
317
+
318
+ # def forward(self, x):
319
+ # for c1, c2 in zip(self.convs1, self.convs2):
320
+ # xt = F.leaky_relu(x, LRELU_SLOPE)
321
+ # xt = c1(xt)
322
+ # xt = F.leaky_relu(xt, LRELU_SLOPE)
323
+ # xt = c2(xt)
324
+ # x = xt + x
325
+ # return x
326
+
327
+ # def remove_weight_norm(self):
328
+ # for l in self.convs1:
329
+ # remove_weight_norm(l)
330
+ # for l in self.convs2:
331
+ # remove_weight_norm(l)
332
+
333
+ # class Generator(torch.nn.Module):
334
+ # def __init__(self, h):
335
+ # super(Generator, self).__init__()
336
+ # self.h = h
337
+ # self.num_kernels = len(h.resblock_kernel_sizes)
338
+ # self.num_upsamples = len(h.upsample_rates)
339
+ # self.conv_pre = weight_norm(
340
+ # Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
341
+ # )
342
+ # resblock = ResBlock
343
+
344
+ # self.ups = nn.ModuleList()
345
+ # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
346
+ # self.ups.append(
347
+ # weight_norm(
348
+ # ConvTranspose1d(
349
+ # h.upsample_initial_channel // (2**i),
350
+ # h.upsample_initial_channel // (2 ** (i + 1)),
351
+ # k,
352
+ # u,
353
+ # padding=(k - u) // 2,
354
+ # )
355
+ # )
356
+ # )
357
+
358
+ # self.resblocks = nn.ModuleList()
359
+ # for i in range(len(self.ups)):
360
+ # ch = h.upsample_initial_channel // (2 ** (i + 1))
361
+ # for j, (k, d) in enumerate(
362
+ # zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
363
+ # ):
364
+ # self.resblocks.append(resblock(h, ch, k, d))
365
+
366
+ # self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
367
+ # self.ups.apply(init_weights)
368
+ # self.conv_post.apply(init_weights)
369
+
370
+ # def forward(self, x):
371
+ # x = self.conv_pre(x)
372
+ # for i in range(self.num_upsamples):
373
+ # x = F.leaky_relu(x, LRELU_SLOPE)
374
+ # x = self.ups[i](x)
375
+ # xs = None
376
+ # for j in range(self.num_kernels):
377
+ # if xs is None:
378
+ # xs = self.resblocks[i * self.num_kernels + j](x)
379
+ # else:
380
+ # xs += self.resblocks[i * self.num_kernels + j](x)
381
+ # x = xs / self.num_kernels
382
+ # x = F.leaky_relu(x)
383
+ # x = self.conv_post(x)
384
+ # x = torch.tanh(x)
385
+
386
+ # return x
387
+
388
+ # def remove_weight_norm(self):
389
+ # print("Removing weight norm...")
390
+ # for l in self.ups:
391
+ # remove_weight_norm(l)
392
+ # for l in self.resblocks:
393
+ # l.remove_weight_norm()
394
+ # remove_weight_norm(self.conv_pre)
395
+ # remove_weight_norm(self.conv_post)
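Compared with models.py, this v2 Generator feeds a 256-channel latent sequence into conv_pre instead of h.num_mels mel bins, selects between ResBlock1 and ResBlock2 via h.resblock, and derives each upsampler's kernel from the rate (u * 2) with padding u // 2 + u % 2 and output_padding u % 2, which yields an exact u-times length increase for even and odd rates alike. A small self-contained check of that upsampling arithmetic (plain torch only, no package import assumed):

    import torch
    from torch.nn import ConvTranspose1d

    # Mirrors the ConvTranspose1d construction in the v2 Generator above:
    # output length = (L - 1) * u - 2 * (u // 2 + u % 2) + 2 * u + u % 2 = L * u.
    for u in (2, 3, 4, 5):
        up = ConvTranspose1d(8, 8, u * 2, u, padding=u // 2 + u % 2, output_padding=u % 2)
        x = torch.randn(1, 8, 100)
        assert up(x).shape[-1] == 100 * u  # exact u-times upsampling for every rate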
audiosr/latent_diffusion/models/ddim.py CHANGED
@@ -1,492 +1,492 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from audiosr.latent_diffusion.modules.diffusionmodules.util import (
8
- make_ddim_sampling_parameters,
9
- make_ddim_timesteps,
10
- noise_like,
11
- extract_into_tensor,
12
- )
13
-
14
-
15
- class DDIMSampler(object):
16
- def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
17
- super().__init__()
18
- self.model = model
19
- self.ddpm_num_timesteps = model.num_timesteps
20
- self.schedule = schedule
21
- self.device = device
22
-
23
- def register_buffer(self, name, attr):
24
- if type(attr) == torch.Tensor:
25
- if attr.device != self.device:
26
- is_mps = self.device == "mps" or self.device == torch.device("mps")
27
- if is_mps and attr.dtype == torch.float64:
28
- attr = attr.to(self.device, dtype=torch.float32)
29
- else:
30
- attr = attr.to(self.device)
31
- setattr(self, name, attr)
32
-
33
- def make_schedule(
34
- self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
35
- ):
36
- self.ddim_timesteps = make_ddim_timesteps(
37
- ddim_discr_method=ddim_discretize,
38
- num_ddim_timesteps=ddim_num_steps,
39
- num_ddpm_timesteps=self.ddpm_num_timesteps,
40
- verbose=verbose,
41
- )
42
- alphas_cumprod = self.model.alphas_cumprod
43
- assert (
44
- alphas_cumprod.shape[0] == self.ddpm_num_timesteps
45
- ), "alphas have to be defined for each timestep"
46
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
47
-
48
- self.register_buffer("betas", to_torch(self.model.betas))
49
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
50
- self.register_buffer(
51
- "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
52
- )
53
-
54
- # calculations for diffusion q(x_t | x_{t-1}) and others
55
- self.register_buffer(
56
- "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
57
- )
58
- self.register_buffer(
59
- "sqrt_one_minus_alphas_cumprod",
60
- to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
61
- )
62
- self.register_buffer(
63
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
64
- )
65
- self.register_buffer(
66
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
67
- )
68
- self.register_buffer(
69
- "sqrt_recipm1_alphas_cumprod",
70
- to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
71
- )
72
-
73
- # ddim sampling parameters
74
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
75
- alphacums=alphas_cumprod.cpu(),
76
- ddim_timesteps=self.ddim_timesteps,
77
- eta=ddim_eta,
78
- verbose=verbose,
79
- )
80
- self.register_buffer("ddim_sigmas", ddim_sigmas)
81
- self.register_buffer("ddim_alphas", ddim_alphas)
82
- self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
83
- self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
84
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
85
- (1 - self.alphas_cumprod_prev)
86
- / (1 - self.alphas_cumprod)
87
- * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
88
- )
89
- self.register_buffer(
90
- "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
91
- )
92
-
93
- @torch.no_grad()
94
- def sample(
95
- self,
96
- S,
97
- batch_size,
98
- shape,
99
- conditioning=None,
100
- callback=None,
101
- normals_sequence=None,
102
- img_callback=None,
103
- quantize_x0=False,
104
- eta=0.0,
105
- mask=None,
106
- x0=None,
107
- temperature=1.0,
108
- noise_dropout=0.0,
109
- score_corrector=None,
110
- corrector_kwargs=None,
111
- verbose=True,
112
- x_T=None,
113
- log_every_t=100,
114
- unconditional_guidance_scale=1.0,
115
- unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
116
- dynamic_threshold=None,
117
- ucg_schedule=None,
118
- **kwargs,
119
- ):
120
- # if conditioning is not None:
121
- # if isinstance(conditioning, dict):
122
- # ctmp = conditioning[list(conditioning.keys())[0]]
123
- # while isinstance(ctmp, list): ctmp = ctmp[0]
124
- # cbs = ctmp.shape[0]
125
- # if cbs != batch_size:
126
- # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
127
-
128
- # elif isinstance(conditioning, list):
129
- # for ctmp in conditioning:
130
- # if ctmp.shape[0] != batch_size:
131
- # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
132
-
133
- # else:
134
- # if conditioning.shape[0] != batch_size:
135
- # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
136
-
137
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
138
- # sampling
139
- C, H, W = shape
140
- size = (batch_size, C, H, W)
141
- # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
142
-
143
- samples, intermediates = self.ddim_sampling(
144
- conditioning,
145
- size,
146
- callback=callback,
147
- img_callback=img_callback,
148
- quantize_denoised=quantize_x0,
149
- mask=mask,
150
- x0=x0,
151
- ddim_use_original_steps=False,
152
- noise_dropout=noise_dropout,
153
- temperature=temperature,
154
- score_corrector=score_corrector,
155
- corrector_kwargs=corrector_kwargs,
156
- x_T=x_T,
157
- log_every_t=log_every_t,
158
- unconditional_guidance_scale=unconditional_guidance_scale,
159
- unconditional_conditioning=unconditional_conditioning,
160
- dynamic_threshold=dynamic_threshold,
161
- ucg_schedule=ucg_schedule,
162
- )
163
- return samples, intermediates
164
-
165
- @torch.no_grad()
166
- def ddim_sampling(
167
- self,
168
- cond,
169
- shape,
170
- x_T=None,
171
- ddim_use_original_steps=False,
172
- callback=None,
173
- timesteps=None,
174
- quantize_denoised=False,
175
- mask=None,
176
- x0=None,
177
- img_callback=None,
178
- log_every_t=100,
179
- temperature=1.0,
180
- noise_dropout=0.0,
181
- score_corrector=None,
182
- corrector_kwargs=None,
183
- unconditional_guidance_scale=1.0,
184
- unconditional_conditioning=None,
185
- dynamic_threshold=None,
186
- ucg_schedule=None,
187
- ):
188
- device = self.model.betas.device
189
- b = shape[0]
190
- if x_T is None:
191
- img = torch.randn(shape, device=device)
192
- else:
193
- img = x_T
194
-
195
- if timesteps is None:
196
- timesteps = (
197
- self.ddpm_num_timesteps
198
- if ddim_use_original_steps
199
- else self.ddim_timesteps
200
- )
201
- elif timesteps is not None and not ddim_use_original_steps:
202
- subset_end = (
203
- int(
204
- min(timesteps / self.ddim_timesteps.shape[0], 1)
205
- * self.ddim_timesteps.shape[0]
206
- )
207
- - 1
208
- )
209
- timesteps = self.ddim_timesteps[:subset_end]
210
-
211
- intermediates = {"x_inter": [img], "pred_x0": [img]}
212
- time_range = (
213
- reversed(range(0, timesteps))
214
- if ddim_use_original_steps
215
- else np.flip(timesteps)
216
- )
217
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
218
- print(f"Running DDIM Sampling with {total_steps} timesteps")
219
-
220
- iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
221
-
222
- for i, step in enumerate(iterator):
223
- index = total_steps - i - 1
224
- ts = torch.full((b,), step, device=device, dtype=torch.long)
225
-
226
- if mask is not None:
227
- assert x0 is not None
228
- img_orig = self.model.q_sample(
229
- x0, ts
230
- ) # TODO: deterministic forward pass?
231
- img = img_orig * mask + (1.0 - mask) * img
232
-
233
- if ucg_schedule is not None:
234
- assert len(ucg_schedule) == len(time_range)
235
- unconditional_guidance_scale = ucg_schedule[i]
236
-
237
- outs = self.p_sample_ddim(
238
- img,
239
- cond,
240
- ts,
241
- index=index,
242
- use_original_steps=ddim_use_original_steps,
243
- quantize_denoised=quantize_denoised,
244
- temperature=temperature,
245
- noise_dropout=noise_dropout,
246
- score_corrector=score_corrector,
247
- corrector_kwargs=corrector_kwargs,
248
- unconditional_guidance_scale=unconditional_guidance_scale,
249
- unconditional_conditioning=unconditional_conditioning,
250
- dynamic_threshold=dynamic_threshold,
251
- )
252
- img, pred_x0 = outs
253
- if callback:
254
- callback(i)
255
- if img_callback:
256
- img_callback(pred_x0, i)
257
-
258
- if index % log_every_t == 0 or index == total_steps - 1:
259
- intermediates["x_inter"].append(img)
260
- intermediates["pred_x0"].append(pred_x0)
261
-
262
- return img, intermediates
263
-
264
- @torch.no_grad()
265
- def p_sample_ddim(
266
- self,
267
- x,
268
- c,
269
- t,
270
- index,
271
- repeat_noise=False,
272
- use_original_steps=False,
273
- quantize_denoised=False,
274
- temperature=1.0,
275
- noise_dropout=0.0,
276
- score_corrector=None,
277
- corrector_kwargs=None,
278
- unconditional_guidance_scale=1.0,
279
- unconditional_conditioning=None,
280
- dynamic_threshold=None,
281
- ):
282
- b, *_, device = *x.shape, x.device
283
-
284
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
285
- model_output = self.model.apply_model(x, t, c)
286
- else:
287
- x_in = x
288
- t_in = t
289
-
290
- assert isinstance(c, dict)
291
- assert isinstance(unconditional_conditioning, dict)
292
-
293
- model_t = self.model.apply_model(x_in, t_in, c)
294
-
295
- model_uncond = self.model.apply_model(
296
- x_in, t_in, unconditional_conditioning
297
- )
298
-
299
- model_output = model_uncond + unconditional_guidance_scale * (
300
- model_t - model_uncond
301
- )
302
-
303
- if self.model.parameterization == "v":
304
- e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
305
- else:
306
- e_t = model_output
307
-
308
- if score_corrector is not None:
309
- assert self.model.parameterization == "eps", "not implemented"
310
- e_t = score_corrector.modify_score(
311
- self.model, e_t, x, t, c, **corrector_kwargs
312
- )
313
-
314
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
315
- alphas_prev = (
316
- self.model.alphas_cumprod_prev
317
- if use_original_steps
318
- else self.ddim_alphas_prev
319
- )
320
- sqrt_one_minus_alphas = (
321
- self.model.sqrt_one_minus_alphas_cumprod
322
- if use_original_steps
323
- else self.ddim_sqrt_one_minus_alphas
324
- )
325
- sigmas = (
326
- self.model.ddim_sigmas_for_original_num_steps
327
- if use_original_steps
328
- else self.ddim_sigmas
329
- )
330
- # select parameters corresponding to the currently considered timestep
331
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
332
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
333
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
334
- sqrt_one_minus_at = torch.full(
335
- (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
336
- )
337
-
338
- # current prediction for x_0
339
- if self.model.parameterization != "v":
340
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
341
- else:
342
- pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
343
-
344
- if quantize_denoised:
345
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
346
-
347
- if dynamic_threshold is not None:
348
- raise NotImplementedError()
349
-
350
- # direction pointing to x_t
351
- dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
352
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
353
- if noise_dropout > 0.0:
354
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
355
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
356
- return x_prev, pred_x0
357
-
358
- @torch.no_grad()
359
- def encode(
360
- self,
361
- x0,
362
- c,
363
- t_enc,
364
- use_original_steps=False,
365
- return_intermediates=None,
366
- unconditional_guidance_scale=1.0,
367
- unconditional_conditioning=None,
368
- callback=None,
369
- ):
370
- num_reference_steps = (
371
- self.ddpm_num_timesteps
372
- if use_original_steps
373
- else self.ddim_timesteps.shape[0]
374
- )
375
-
376
- assert t_enc <= num_reference_steps
377
- num_steps = t_enc
378
-
379
- if use_original_steps:
380
- alphas_next = self.alphas_cumprod[:num_steps]
381
- alphas = self.alphas_cumprod_prev[:num_steps]
382
- else:
383
- alphas_next = self.ddim_alphas[:num_steps]
384
- alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
385
-
386
- x_next = x0
387
- intermediates = []
388
- inter_steps = []
389
- for i in tqdm(range(num_steps), desc="Encoding Image"):
390
- t = torch.full(
391
- (x0.shape[0],), i, device=self.model.device, dtype=torch.long
392
- )
393
- if unconditional_guidance_scale == 1.0:
394
- noise_pred = self.model.apply_model(x_next, t, c)
395
- else:
396
- assert unconditional_conditioning is not None
397
- e_t_uncond, noise_pred = torch.chunk(
398
- self.model.apply_model(
399
- torch.cat((x_next, x_next)),
400
- torch.cat((t, t)),
401
- torch.cat((unconditional_conditioning, c)),
402
- ),
403
- 2,
404
- )
405
- noise_pred = e_t_uncond + unconditional_guidance_scale * (
406
- noise_pred - e_t_uncond
407
- )
408
-
409
- xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
410
- weighted_noise_pred = (
411
- alphas_next[i].sqrt()
412
- * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
413
- * noise_pred
414
- )
415
- x_next = xt_weighted + weighted_noise_pred
416
- if (
417
- return_intermediates
418
- and i % (num_steps // return_intermediates) == 0
419
- and i < num_steps - 1
420
- ):
421
- intermediates.append(x_next)
422
- inter_steps.append(i)
423
- elif return_intermediates and i >= num_steps - 2:
424
- intermediates.append(x_next)
425
- inter_steps.append(i)
426
- if callback:
427
- callback(i)
428
-
429
- out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
430
- if return_intermediates:
431
- out.update({"intermediates": intermediates})
432
- return x_next, out
433
-
434
- @torch.no_grad()
435
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
436
- # fast, but does not allow for exact reconstruction
437
- # t serves as an index to gather the correct alphas
438
- if use_original_steps:
439
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
440
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
441
- else:
442
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
443
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
444
-
445
- if noise is None:
446
- noise = torch.randn_like(x0)
447
- return (
448
- extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
449
- + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
450
- )
451
-
452
- @torch.no_grad()
453
- def decode(
454
- self,
455
- x_latent,
456
- cond,
457
- t_start,
458
- unconditional_guidance_scale=1.0,
459
- unconditional_conditioning=None,
460
- use_original_steps=False,
461
- callback=None,
462
- ):
463
- timesteps = (
464
- np.arange(self.ddpm_num_timesteps)
465
- if use_original_steps
466
- else self.ddim_timesteps
467
- )
468
- timesteps = timesteps[:t_start]
469
-
470
- time_range = np.flip(timesteps)
471
- total_steps = timesteps.shape[0]
472
- print(f"Running DDIM Sampling with {total_steps} timesteps")
473
-
474
- iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
475
- x_dec = x_latent
476
- for i, step in enumerate(iterator):
477
- index = total_steps - i - 1
478
- ts = torch.full(
479
- (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
480
- )
481
- x_dec, _ = self.p_sample_ddim(
482
- x_dec,
483
- cond,
484
- ts,
485
- index=index,
486
- use_original_steps=use_original_steps,
487
- unconditional_guidance_scale=unconditional_guidance_scale,
488
- unconditional_conditioning=unconditional_conditioning,
489
- )
490
- if callback:
491
- callback(i)
492
- return x_dec
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from audiosr.latent_diffusion.modules.diffusionmodules.util import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_timesteps,
10
+ noise_like,
11
+ extract_into_tensor,
12
+ )
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
17
+ super().__init__()
18
+ self.model = model
19
+ self.ddpm_num_timesteps = model.num_timesteps
20
+ self.schedule = schedule
21
+ self.device = device
22
+
23
+ def register_buffer(self, name, attr):
24
+ if type(attr) == torch.Tensor:
25
+ if attr.device != self.device:
26
+ is_mps = self.device == "mps" or self.device == torch.device("mps")
27
+ if is_mps and attr.dtype == torch.float64:
28
+ attr = attr.to(self.device, dtype=torch.float32)
29
+ else:
30
+ attr = attr.to(self.device)
31
+ setattr(self, name, attr)
32
+
33
+ def make_schedule(
34
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
35
+ ):
36
+ self.ddim_timesteps = make_ddim_timesteps(
37
+ ddim_discr_method=ddim_discretize,
38
+ num_ddim_timesteps=ddim_num_steps,
39
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
40
+ verbose=verbose,
41
+ )
42
+ alphas_cumprod = self.model.alphas_cumprod
43
+ assert (
44
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
45
+ ), "alphas have to be defined for each timestep"
46
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
47
+
48
+ self.register_buffer("betas", to_torch(self.model.betas))
49
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
50
+ self.register_buffer(
51
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
52
+ )
53
+
54
+ # calculations for diffusion q(x_t | x_{t-1}) and others
55
+ self.register_buffer(
56
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
57
+ )
58
+ self.register_buffer(
59
+ "sqrt_one_minus_alphas_cumprod",
60
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
61
+ )
62
+ self.register_buffer(
63
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
64
+ )
65
+ self.register_buffer(
66
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
67
+ )
68
+ self.register_buffer(
69
+ "sqrt_recipm1_alphas_cumprod",
70
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
71
+ )
72
+
73
+ # ddim sampling parameters
74
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
75
+ alphacums=alphas_cumprod.cpu(),
76
+ ddim_timesteps=self.ddim_timesteps,
77
+ eta=ddim_eta,
78
+ verbose=verbose,
79
+ )
80
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
81
+ self.register_buffer("ddim_alphas", ddim_alphas)
82
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
83
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
84
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
85
+ (1 - self.alphas_cumprod_prev)
86
+ / (1 - self.alphas_cumprod)
87
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
88
+ )
89
+ self.register_buffer(
90
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
91
+ )
92
+
93
+ @torch.no_grad()
94
+ def sample(
95
+ self,
96
+ S,
97
+ batch_size,
98
+ shape,
99
+ conditioning=None,
100
+ callback=None,
101
+ normals_sequence=None,
102
+ img_callback=None,
103
+ quantize_x0=False,
104
+ eta=0.0,
105
+ mask=None,
106
+ x0=None,
107
+ temperature=1.0,
108
+ noise_dropout=0.0,
109
+ score_corrector=None,
110
+ corrector_kwargs=None,
111
+ verbose=True,
112
+ x_T=None,
113
+ log_every_t=100,
114
+ unconditional_guidance_scale=1.0,
115
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
116
+ dynamic_threshold=None,
117
+ ucg_schedule=None,
118
+ **kwargs,
119
+ ):
120
+ # if conditioning is not None:
121
+ # if isinstance(conditioning, dict):
122
+ # ctmp = conditioning[list(conditioning.keys())[0]]
123
+ # while isinstance(ctmp, list): ctmp = ctmp[0]
124
+ # cbs = ctmp.shape[0]
125
+ # if cbs != batch_size:
126
+ # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
127
+
128
+ # elif isinstance(conditioning, list):
129
+ # for ctmp in conditioning:
130
+ # if ctmp.shape[0] != batch_size:
131
+ # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
132
+
133
+ # else:
134
+ # if conditioning.shape[0] != batch_size:
135
+ # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
136
+
137
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
138
+ # sampling
139
+ C, H, W = shape
140
+ size = (batch_size, C, H, W)
141
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
142
+
143
+ samples, intermediates = self.ddim_sampling(
144
+ conditioning,
145
+ size,
146
+ callback=callback,
147
+ img_callback=img_callback,
148
+ quantize_denoised=quantize_x0,
149
+ mask=mask,
150
+ x0=x0,
151
+ ddim_use_original_steps=False,
152
+ noise_dropout=noise_dropout,
153
+ temperature=temperature,
154
+ score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ x_T=x_T,
157
+ log_every_t=log_every_t,
158
+ unconditional_guidance_scale=unconditional_guidance_scale,
159
+ unconditional_conditioning=unconditional_conditioning,
160
+ dynamic_threshold=dynamic_threshold,
161
+ ucg_schedule=ucg_schedule,
162
+ )
163
+ return samples, intermediates
164
+
165
+ @torch.no_grad()
166
+ def ddim_sampling(
167
+ self,
168
+ cond,
169
+ shape,
170
+ x_T=None,
171
+ ddim_use_original_steps=False,
172
+ callback=None,
173
+ timesteps=None,
174
+ quantize_denoised=False,
175
+ mask=None,
176
+ x0=None,
177
+ img_callback=None,
178
+ log_every_t=100,
179
+ temperature=1.0,
180
+ noise_dropout=0.0,
181
+ score_corrector=None,
182
+ corrector_kwargs=None,
183
+ unconditional_guidance_scale=1.0,
184
+ unconditional_conditioning=None,
185
+ dynamic_threshold=None,
186
+ ucg_schedule=None,
187
+ ):
188
+ device = self.model.betas.device
189
+ b = shape[0]
190
+ if x_T is None:
191
+ img = torch.randn(shape, device=device)
192
+ else:
193
+ img = x_T
194
+
195
+ if timesteps is None:
196
+ timesteps = (
197
+ self.ddpm_num_timesteps
198
+ if ddim_use_original_steps
199
+ else self.ddim_timesteps
200
+ )
201
+ elif timesteps is not None and not ddim_use_original_steps:
202
+ subset_end = (
203
+ int(
204
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
205
+ * self.ddim_timesteps.shape[0]
206
+ )
207
+ - 1
208
+ )
209
+ timesteps = self.ddim_timesteps[:subset_end]
210
+
211
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
212
+ time_range = (
213
+ reversed(range(0, timesteps))
214
+ if ddim_use_original_steps
215
+ else np.flip(timesteps)
216
+ )
217
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
218
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
219
+
220
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
221
+
222
+ for i, step in enumerate(iterator):
223
+ index = total_steps - i - 1
224
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
225
+
226
+ if mask is not None:
227
+ assert x0 is not None
228
+ img_orig = self.model.q_sample(
229
+ x0, ts
230
+ ) # TODO: deterministic forward pass?
231
+ img = img_orig * mask + (1.0 - mask) * img
232
+
233
+ if ucg_schedule is not None:
234
+ assert len(ucg_schedule) == len(time_range)
235
+ unconditional_guidance_scale = ucg_schedule[i]
236
+
237
+ outs = self.p_sample_ddim(
238
+ img,
239
+ cond,
240
+ ts,
241
+ index=index,
242
+ use_original_steps=ddim_use_original_steps,
243
+ quantize_denoised=quantize_denoised,
244
+ temperature=temperature,
245
+ noise_dropout=noise_dropout,
246
+ score_corrector=score_corrector,
247
+ corrector_kwargs=corrector_kwargs,
248
+ unconditional_guidance_scale=unconditional_guidance_scale,
249
+ unconditional_conditioning=unconditional_conditioning,
250
+ dynamic_threshold=dynamic_threshold,
251
+ )
252
+ img, pred_x0 = outs
253
+ if callback:
254
+ callback(i)
255
+ if img_callback:
256
+ img_callback(pred_x0, i)
257
+
258
+ if index % log_every_t == 0 or index == total_steps - 1:
259
+ intermediates["x_inter"].append(img)
260
+ intermediates["pred_x0"].append(pred_x0)
261
+
262
+ return img, intermediates
263
+
264
+ @torch.no_grad()
265
+ def p_sample_ddim(
266
+ self,
267
+ x,
268
+ c,
269
+ t,
270
+ index,
271
+ repeat_noise=False,
272
+ use_original_steps=False,
273
+ quantize_denoised=False,
274
+ temperature=1.0,
275
+ noise_dropout=0.0,
276
+ score_corrector=None,
277
+ corrector_kwargs=None,
278
+ unconditional_guidance_scale=1.0,
279
+ unconditional_conditioning=None,
280
+ dynamic_threshold=None,
281
+ ):
282
+ b, *_, device = *x.shape, x.device
283
+
284
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
285
+ model_output = self.model.apply_model(x, t, c)
286
+ else:
287
+ x_in = x
288
+ t_in = t
289
+
290
+ assert isinstance(c, dict)
291
+ assert isinstance(unconditional_conditioning, dict)
292
+
293
+ model_t = self.model.apply_model(x_in, t_in, c)
294
+
295
+ model_uncond = self.model.apply_model(
296
+ x_in, t_in, unconditional_conditioning
297
+ )
298
+
299
+ model_output = model_uncond + unconditional_guidance_scale * (
300
+ model_t - model_uncond
301
+ )
302
+
303
+ if self.model.parameterization == "v":
304
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
305
+ else:
306
+ e_t = model_output
307
+
308
+ if score_corrector is not None:
309
+ assert self.model.parameterization == "eps", "not implemented"
310
+ e_t = score_corrector.modify_score(
311
+ self.model, e_t, x, t, c, **corrector_kwargs
312
+ )
313
+
314
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
315
+ alphas_prev = (
316
+ self.model.alphas_cumprod_prev
317
+ if use_original_steps
318
+ else self.ddim_alphas_prev
319
+ )
320
+ sqrt_one_minus_alphas = (
321
+ self.model.sqrt_one_minus_alphas_cumprod
322
+ if use_original_steps
323
+ else self.ddim_sqrt_one_minus_alphas
324
+ )
325
+ sigmas = (
326
+ self.model.ddim_sigmas_for_original_num_steps
327
+ if use_original_steps
328
+ else self.ddim_sigmas
329
+ )
330
+ # select parameters corresponding to the currently considered timestep
331
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
332
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
333
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
334
+ sqrt_one_minus_at = torch.full(
335
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
336
+ )
337
+
338
+ # current prediction for x_0
339
+ if self.model.parameterization != "v":
340
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
341
+ else:
342
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
343
+
344
+ if quantize_denoised:
345
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
346
+
347
+ if dynamic_threshold is not None:
348
+ raise NotImplementedError()
349
+
350
+ # direction pointing to x_t
351
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
352
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
353
+ if noise_dropout > 0.0:
354
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
355
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
356
+ return x_prev, pred_x0
357
+
358
+ @torch.no_grad()
359
+ def encode(
360
+ self,
361
+ x0,
362
+ c,
363
+ t_enc,
364
+ use_original_steps=False,
365
+ return_intermediates=None,
366
+ unconditional_guidance_scale=1.0,
367
+ unconditional_conditioning=None,
368
+ callback=None,
369
+ ):
370
+ num_reference_steps = (
371
+ self.ddpm_num_timesteps
372
+ if use_original_steps
373
+ else self.ddim_timesteps.shape[0]
374
+ )
375
+
376
+ assert t_enc <= num_reference_steps
377
+ num_steps = t_enc
378
+
379
+ if use_original_steps:
380
+ alphas_next = self.alphas_cumprod[:num_steps]
381
+ alphas = self.alphas_cumprod_prev[:num_steps]
382
+ else:
383
+ alphas_next = self.ddim_alphas[:num_steps]
384
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
385
+
386
+ x_next = x0
387
+ intermediates = []
388
+ inter_steps = []
389
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
390
+ t = torch.full(
391
+ (x0.shape[0],), i, device=self.model.device, dtype=torch.long
392
+ )
393
+ if unconditional_guidance_scale == 1.0:
394
+ noise_pred = self.model.apply_model(x_next, t, c)
395
+ else:
396
+ assert unconditional_conditioning is not None
397
+ e_t_uncond, noise_pred = torch.chunk(
398
+ self.model.apply_model(
399
+ torch.cat((x_next, x_next)),
400
+ torch.cat((t, t)),
401
+ torch.cat((unconditional_conditioning, c)),
402
+ ),
403
+ 2,
404
+ )
405
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
406
+ noise_pred - e_t_uncond
407
+ )
408
+
409
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
410
+ weighted_noise_pred = (
411
+ alphas_next[i].sqrt()
412
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
413
+ * noise_pred
414
+ )
415
+ x_next = xt_weighted + weighted_noise_pred
416
+ if (
417
+ return_intermediates
418
+ and i % (num_steps // return_intermediates) == 0
419
+ and i < num_steps - 1
420
+ ):
421
+ intermediates.append(x_next)
422
+ inter_steps.append(i)
423
+ elif return_intermediates and i >= num_steps - 2:
424
+ intermediates.append(x_next)
425
+ inter_steps.append(i)
426
+ if callback:
427
+ callback(i)
428
+
429
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
430
+ if return_intermediates:
431
+ out.update({"intermediates": intermediates})
432
+ return x_next, out
433
+
434
+ @torch.no_grad()
435
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
436
+ # fast, but does not allow for exact reconstruction
437
+ # t serves as an index to gather the correct alphas
438
+ if use_original_steps:
439
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
440
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
441
+ else:
442
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
443
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
444
+
445
+ if noise is None:
446
+ noise = torch.randn_like(x0)
447
+ return (
448
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
449
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
450
+ )
451
+
452
+ @torch.no_grad()
453
+ def decode(
454
+ self,
455
+ x_latent,
456
+ cond,
457
+ t_start,
458
+ unconditional_guidance_scale=1.0,
459
+ unconditional_conditioning=None,
460
+ use_original_steps=False,
461
+ callback=None,
462
+ ):
463
+ timesteps = (
464
+ np.arange(self.ddpm_num_timesteps)
465
+ if use_original_steps
466
+ else self.ddim_timesteps
467
+ )
468
+ timesteps = timesteps[:t_start]
469
+
470
+ time_range = np.flip(timesteps)
471
+ total_steps = timesteps.shape[0]
472
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
473
+
474
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
475
+ x_dec = x_latent
476
+ for i, step in enumerate(iterator):
477
+ index = total_steps - i - 1
478
+ ts = torch.full(
479
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
480
+ )
481
+ x_dec, _ = self.p_sample_ddim(
482
+ x_dec,
483
+ cond,
484
+ ts,
485
+ index=index,
486
+ use_original_steps=use_original_steps,
487
+ unconditional_guidance_scale=unconditional_guidance_scale,
488
+ unconditional_conditioning=unconditional_conditioning,
489
+ )
490
+ if callback:
491
+ callback(i)
492
+ return x_dec
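For orientation, the update inside `p_sample_ddim` above implements the standard DDIM step: recover a prediction of x_0 from the noise estimate, add the "direction pointing to x_t" term, and, when eta > 0, a stochastic term scaled by sigma_t. Below is a minimal, self-contained sketch of that arithmetic on plain tensors; the scalar schedule values and the tensor shape are illustrative placeholders, not the values produced by `make_schedule`.

```python
import torch

# Illustrative per-step schedule values; the real a_t, a_prev, sigma_t come from
# make_ddim_sampling_parameters / make_schedule, not from these constants.
a_t, a_prev, sigma_t = 0.70, 0.80, 0.0   # sigma_t == 0 corresponds to eta == 0 (deterministic DDIM)

x = torch.randn(1, 8, 64, 16)            # current latent x_t (placeholder shape)
e_t = torch.randn_like(x)                # stand-in for the model's noise prediction

pred_x0 = (x - (1.0 - a_t) ** 0.5 * e_t) / a_t ** 0.5   # current prediction for x_0
dir_xt = (1.0 - a_prev - sigma_t ** 2) ** 0.5 * e_t     # direction pointing to x_t
noise = sigma_t * torch.randn_like(x)                    # vanishes when eta == 0

x_prev = a_prev ** 0.5 * pred_x0 + dir_xt + noise
print(x_prev.shape)  # torch.Size([1, 8, 64, 16])
```

With eta == 0 the step is deterministic, which is what makes the `encode` / `decode` pair above able to walk the same trajectory in both directions.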
audiosr/latent_diffusion/models/ddpm.py CHANGED
The diff for this file is too large to render. See raw diff
 
audiosr/latent_diffusion/models/plms.py CHANGED
@@ -1,360 +1,360 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from audiosr.latent_diffusion.modules.diffusionmodules.util import (
8
- make_ddim_sampling_parameters,
9
- make_ddim_timesteps,
10
- noise_like,
11
- )
12
-
13
-
14
- class PLMSSampler(object):
15
- def __init__(self, model, schedule="linear", **kwargs):
16
- super().__init__()
17
- self.model = model
18
- self.ddpm_num_timesteps = model.num_timesteps
19
- self.schedule = schedule
20
-
21
- def register_buffer(self, name, attr):
22
- if type(attr) == torch.Tensor:
23
- if attr.device != torch.device("cuda"):
24
- attr = attr.to(torch.device("cuda"))
25
- setattr(self, name, attr)
26
-
27
- def make_schedule(
28
- self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
29
- ):
30
- if ddim_eta != 0:
31
- ddim_eta = 0
32
- # raise ValueError('ddim_eta must be 0 for PLMS')
33
-
34
- self.ddim_timesteps = make_ddim_timesteps(
35
- ddim_discr_method=ddim_discretize,
36
- num_ddim_timesteps=ddim_num_steps,
37
- num_ddpm_timesteps=self.ddpm_num_timesteps,
38
- verbose=verbose,
39
- )
40
- alphas_cumprod = self.model.alphas_cumprod
41
- assert (
42
- alphas_cumprod.shape[0] == self.ddpm_num_timesteps
43
- ), "alphas have to be defined for each timestep"
44
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
45
-
46
- self.register_buffer("betas", to_torch(self.model.betas))
47
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
48
- self.register_buffer(
49
- "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
50
- )
51
-
52
- # calculations for diffusion q(x_t | x_{t-1}) and others
53
- self.register_buffer(
54
- "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
55
- )
56
- self.register_buffer(
57
- "sqrt_one_minus_alphas_cumprod",
58
- to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
59
- )
60
- self.register_buffer(
61
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
62
- )
63
- self.register_buffer(
64
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
65
- )
66
- self.register_buffer(
67
- "sqrt_recipm1_alphas_cumprod",
68
- to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
69
- )
70
-
71
- # ddim sampling parameters
72
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
73
- alphacums=alphas_cumprod.cpu(),
74
- ddim_timesteps=self.ddim_timesteps,
75
- eta=ddim_eta,
76
- verbose=verbose,
77
- )
78
- self.register_buffer("ddim_sigmas", ddim_sigmas)
79
- self.register_buffer("ddim_alphas", ddim_alphas)
80
- self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
81
- self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
82
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
83
- (1 - self.alphas_cumprod_prev)
84
- / (1 - self.alphas_cumprod)
85
- * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
86
- )
87
- self.register_buffer(
88
- "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
89
- )
90
-
91
- @torch.no_grad()
92
- def sample(
93
- self,
94
- S,
95
- batch_size,
96
- shape,
97
- conditioning=None,
98
- callback=None,
99
- normals_sequence=None,
100
- img_callback=None,
101
- quantize_x0=False,
102
- eta=0.0,
103
- mask=None,
104
- x0=None,
105
- temperature=1.0,
106
- noise_dropout=0.0,
107
- score_corrector=None,
108
- corrector_kwargs=None,
109
- verbose=True,
110
- x_T=None,
111
- log_every_t=100,
112
- unconditional_guidance_scale=1.0,
113
- unconditional_conditioning=None,
114
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
115
- **kwargs,
116
- ):
117
- if conditioning is not None:
118
- if isinstance(conditioning, dict):
119
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
120
- if cbs != batch_size:
121
- print(
122
- f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
123
- )
124
- else:
125
- if conditioning.shape[0] != batch_size:
126
- print(
127
- f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
128
- )
129
-
130
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
131
- # sampling
132
- C, H, W = shape
133
- size = (batch_size, C, H, W)
134
- print(f"Data shape for PLMS sampling is {size}")
135
-
136
- samples, intermediates = self.plms_sampling(
137
- conditioning,
138
- size,
139
- callback=callback,
140
- img_callback=img_callback,
141
- quantize_denoised=quantize_x0,
142
- mask=mask,
143
- x0=x0,
144
- ddim_use_original_steps=False,
145
- noise_dropout=noise_dropout,
146
- temperature=temperature,
147
- score_corrector=score_corrector,
148
- corrector_kwargs=corrector_kwargs,
149
- x_T=x_T,
150
- log_every_t=log_every_t,
151
- unconditional_guidance_scale=unconditional_guidance_scale,
152
- unconditional_conditioning=unconditional_conditioning,
153
- )
154
- return samples, intermediates
155
-
156
- @torch.no_grad()
157
- def plms_sampling(
158
- self,
159
- cond,
160
- shape,
161
- x_T=None,
162
- ddim_use_original_steps=False,
163
- callback=None,
164
- timesteps=None,
165
- quantize_denoised=False,
166
- mask=None,
167
- x0=None,
168
- img_callback=None,
169
- log_every_t=100,
170
- temperature=1.0,
171
- noise_dropout=0.0,
172
- score_corrector=None,
173
- corrector_kwargs=None,
174
- unconditional_guidance_scale=1.0,
175
- unconditional_conditioning=None,
176
- ):
177
- device = self.model.betas.device
178
- b = shape[0]
179
- if x_T is None:
180
- img = torch.randn(shape, device=device)
181
- else:
182
- img = x_T
183
-
184
- if timesteps is None:
185
- timesteps = (
186
- self.ddpm_num_timesteps
187
- if ddim_use_original_steps
188
- else self.ddim_timesteps
189
- )
190
- elif timesteps is not None and not ddim_use_original_steps:
191
- subset_end = (
192
- int(
193
- min(timesteps / self.ddim_timesteps.shape[0], 1)
194
- * self.ddim_timesteps.shape[0]
195
- )
196
- - 1
197
- )
198
- timesteps = self.ddim_timesteps[:subset_end]
199
-
200
- intermediates = {"x_inter": [img], "pred_x0": [img]}
201
- time_range = (
202
- list(reversed(range(0, timesteps)))
203
- if ddim_use_original_steps
204
- else np.flip(timesteps)
205
- )
206
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
207
- print(f"Running PLMS Sampling with {total_steps} timesteps")
208
-
209
- iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
210
- old_eps = []
211
-
212
- for i, step in enumerate(iterator):
213
- index = total_steps - i - 1
214
- ts = torch.full((b,), step, device=device, dtype=torch.long)
215
- ts_next = torch.full(
216
- (b,),
217
- time_range[min(i + 1, len(time_range) - 1)],
218
- device=device,
219
- dtype=torch.long,
220
- )
221
-
222
- if mask is not None:
223
- assert x0 is not None
224
- img_orig = self.model.q_sample(
225
- x0, ts
226
- ) # TODO: deterministic forward pass?
227
- img = img_orig * mask + (1.0 - mask) * img
228
-
229
- outs = self.p_sample_plms(
230
- img,
231
- cond,
232
- ts,
233
- index=index,
234
- use_original_steps=ddim_use_original_steps,
235
- quantize_denoised=quantize_denoised,
236
- temperature=temperature,
237
- noise_dropout=noise_dropout,
238
- score_corrector=score_corrector,
239
- corrector_kwargs=corrector_kwargs,
240
- unconditional_guidance_scale=unconditional_guidance_scale,
241
- unconditional_conditioning=unconditional_conditioning,
242
- old_eps=old_eps,
243
- t_next=ts_next,
244
- )
245
- img, pred_x0, e_t = outs
246
- old_eps.append(e_t)
247
- if len(old_eps) >= 4:
248
- old_eps.pop(0)
249
- if callback:
250
- callback(i)
251
- if img_callback:
252
- img_callback(pred_x0, i)
253
-
254
- if index % log_every_t == 0 or index == total_steps - 1:
255
- intermediates["x_inter"].append(img)
256
- intermediates["pred_x0"].append(pred_x0)
257
-
258
- return img, intermediates
259
-
260
- @torch.no_grad()
261
- def p_sample_plms(
262
- self,
263
- x,
264
- c,
265
- t,
266
- index,
267
- repeat_noise=False,
268
- use_original_steps=False,
269
- quantize_denoised=False,
270
- temperature=1.0,
271
- noise_dropout=0.0,
272
- score_corrector=None,
273
- corrector_kwargs=None,
274
- unconditional_guidance_scale=1.0,
275
- unconditional_conditioning=None,
276
- old_eps=None,
277
- t_next=None,
278
- ):
279
- b, *_, device = *x.shape, x.device
280
-
281
- def get_model_output(x, t):
282
- if (
283
- unconditional_conditioning is None
284
- or unconditional_guidance_scale == 1.0
285
- ):
286
- e_t = self.model.apply_model(x, t, c)
287
- else:
288
- x_in = torch.cat([x] * 2)
289
- t_in = torch.cat([t] * 2)
290
- c_in = torch.cat([unconditional_conditioning, c])
291
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
292
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
293
-
294
- if score_corrector is not None:
295
- assert self.model.parameterization == "eps"
296
- e_t = score_corrector.modify_score(
297
- self.model, e_t, x, t, c, **corrector_kwargs
298
- )
299
-
300
- return e_t
301
-
302
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
303
- alphas_prev = (
304
- self.model.alphas_cumprod_prev
305
- if use_original_steps
306
- else self.ddim_alphas_prev
307
- )
308
- sqrt_one_minus_alphas = (
309
- self.model.sqrt_one_minus_alphas_cumprod
310
- if use_original_steps
311
- else self.ddim_sqrt_one_minus_alphas
312
- )
313
- sigmas = (
314
- self.model.ddim_sigmas_for_original_num_steps
315
- if use_original_steps
316
- else self.ddim_sigmas
317
- )
318
-
319
- def get_x_prev_and_pred_x0(e_t, index):
320
- # select parameters corresponding to the currently considered timestep
321
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
322
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
323
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
324
- sqrt_one_minus_at = torch.full(
325
- (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
326
- )
327
-
328
- # current prediction for x_0
329
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
330
- if quantize_denoised:
331
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
332
- # direction pointing to x_t
333
- dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
334
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
335
- if noise_dropout > 0.0:
336
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
337
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
338
- return x_prev, pred_x0
339
-
340
- e_t = get_model_output(x, t)
341
- if len(old_eps) == 0:
342
- # Pseudo Improved Euler (2nd order)
343
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
344
- e_t_next = get_model_output(x_prev, t_next)
345
- e_t_prime = (e_t + e_t_next) / 2
346
- elif len(old_eps) == 1:
347
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
348
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
349
- elif len(old_eps) == 2:
350
- # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
351
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
352
- elif len(old_eps) >= 3:
353
- # 4th order Pseudo Linear Multistep (Adams-Bashforth)
354
- e_t_prime = (
355
- 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
356
- ) / 24
357
-
358
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
359
-
360
- return x_prev, pred_x0, e_t
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from audiosr.latent_diffusion.modules.diffusionmodules.util import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_timesteps,
10
+ noise_like,
11
+ )
12
+
13
+
14
+ class PLMSSampler(object):
15
+ def __init__(self, model, schedule="linear", **kwargs):
16
+ super().__init__()
17
+ self.model = model
18
+ self.ddpm_num_timesteps = model.num_timesteps
19
+ self.schedule = schedule
20
+
21
+ def register_buffer(self, name, attr):
22
+ if type(attr) == torch.Tensor:
23
+ if attr.device != torch.device("cuda"):
24
+ attr = attr.to(torch.device("cuda"))
25
+ setattr(self, name, attr)
26
+
27
+ def make_schedule(
28
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
29
+ ):
30
+ if ddim_eta != 0:
31
+ ddim_eta = 0
32
+ # raise ValueError('ddim_eta must be 0 for PLMS')
33
+
34
+ self.ddim_timesteps = make_ddim_timesteps(
35
+ ddim_discr_method=ddim_discretize,
36
+ num_ddim_timesteps=ddim_num_steps,
37
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
38
+ verbose=verbose,
39
+ )
40
+ alphas_cumprod = self.model.alphas_cumprod
41
+ assert (
42
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
43
+ ), "alphas have to be defined for each timestep"
44
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
45
+
46
+ self.register_buffer("betas", to_torch(self.model.betas))
47
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
48
+ self.register_buffer(
49
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
50
+ )
51
+
52
+ # calculations for diffusion q(x_t | x_{t-1}) and others
53
+ self.register_buffer(
54
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
55
+ )
56
+ self.register_buffer(
57
+ "sqrt_one_minus_alphas_cumprod",
58
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
59
+ )
60
+ self.register_buffer(
61
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
62
+ )
63
+ self.register_buffer(
64
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
65
+ )
66
+ self.register_buffer(
67
+ "sqrt_recipm1_alphas_cumprod",
68
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
69
+ )
70
+
71
+ # ddim sampling parameters
72
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
73
+ alphacums=alphas_cumprod.cpu(),
74
+ ddim_timesteps=self.ddim_timesteps,
75
+ eta=ddim_eta,
76
+ verbose=verbose,
77
+ )
78
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
79
+ self.register_buffer("ddim_alphas", ddim_alphas)
80
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
81
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
82
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
83
+ (1 - self.alphas_cumprod_prev)
84
+ / (1 - self.alphas_cumprod)
85
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
86
+ )
87
+ self.register_buffer(
88
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
89
+ )
90
+
91
+ @torch.no_grad()
92
+ def sample(
93
+ self,
94
+ S,
95
+ batch_size,
96
+ shape,
97
+ conditioning=None,
98
+ callback=None,
99
+ normals_sequence=None,
100
+ img_callback=None,
101
+ quantize_x0=False,
102
+ eta=0.0,
103
+ mask=None,
104
+ x0=None,
105
+ temperature=1.0,
106
+ noise_dropout=0.0,
107
+ score_corrector=None,
108
+ corrector_kwargs=None,
109
+ verbose=True,
110
+ x_T=None,
111
+ log_every_t=100,
112
+ unconditional_guidance_scale=1.0,
113
+ unconditional_conditioning=None,
114
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
115
+ **kwargs,
116
+ ):
117
+ if conditioning is not None:
118
+ if isinstance(conditioning, dict):
119
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
120
+ if cbs != batch_size:
121
+ print(
122
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
123
+ )
124
+ else:
125
+ if conditioning.shape[0] != batch_size:
126
+ print(
127
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
128
+ )
129
+
130
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
131
+ # sampling
132
+ C, H, W = shape
133
+ size = (batch_size, C, H, W)
134
+ print(f"Data shape for PLMS sampling is {size}")
135
+
136
+ samples, intermediates = self.plms_sampling(
137
+ conditioning,
138
+ size,
139
+ callback=callback,
140
+ img_callback=img_callback,
141
+ quantize_denoised=quantize_x0,
142
+ mask=mask,
143
+ x0=x0,
144
+ ddim_use_original_steps=False,
145
+ noise_dropout=noise_dropout,
146
+ temperature=temperature,
147
+ score_corrector=score_corrector,
148
+ corrector_kwargs=corrector_kwargs,
149
+ x_T=x_T,
150
+ log_every_t=log_every_t,
151
+ unconditional_guidance_scale=unconditional_guidance_scale,
152
+ unconditional_conditioning=unconditional_conditioning,
153
+ )
154
+ return samples, intermediates
155
+
156
+ @torch.no_grad()
157
+ def plms_sampling(
158
+ self,
159
+ cond,
160
+ shape,
161
+ x_T=None,
162
+ ddim_use_original_steps=False,
163
+ callback=None,
164
+ timesteps=None,
165
+ quantize_denoised=False,
166
+ mask=None,
167
+ x0=None,
168
+ img_callback=None,
169
+ log_every_t=100,
170
+ temperature=1.0,
171
+ noise_dropout=0.0,
172
+ score_corrector=None,
173
+ corrector_kwargs=None,
174
+ unconditional_guidance_scale=1.0,
175
+ unconditional_conditioning=None,
176
+ ):
177
+ device = self.model.betas.device
178
+ b = shape[0]
179
+ if x_T is None:
180
+ img = torch.randn(shape, device=device)
181
+ else:
182
+ img = x_T
183
+
184
+ if timesteps is None:
185
+ timesteps = (
186
+ self.ddpm_num_timesteps
187
+ if ddim_use_original_steps
188
+ else self.ddim_timesteps
189
+ )
190
+ elif timesteps is not None and not ddim_use_original_steps:
191
+ subset_end = (
192
+ int(
193
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
194
+ * self.ddim_timesteps.shape[0]
195
+ )
196
+ - 1
197
+ )
198
+ timesteps = self.ddim_timesteps[:subset_end]
199
+
200
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
201
+ time_range = (
202
+ list(reversed(range(0, timesteps)))
203
+ if ddim_use_original_steps
204
+ else np.flip(timesteps)
205
+ )
206
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
207
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
208
+
209
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
210
+ old_eps = []
211
+
212
+ for i, step in enumerate(iterator):
213
+ index = total_steps - i - 1
214
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
215
+ ts_next = torch.full(
216
+ (b,),
217
+ time_range[min(i + 1, len(time_range) - 1)],
218
+ device=device,
219
+ dtype=torch.long,
220
+ )
221
+
222
+ if mask is not None:
223
+ assert x0 is not None
224
+ img_orig = self.model.q_sample(
225
+ x0, ts
226
+ ) # TODO: deterministic forward pass?
227
+ img = img_orig * mask + (1.0 - mask) * img
228
+
229
+ outs = self.p_sample_plms(
230
+ img,
231
+ cond,
232
+ ts,
233
+ index=index,
234
+ use_original_steps=ddim_use_original_steps,
235
+ quantize_denoised=quantize_denoised,
236
+ temperature=temperature,
237
+ noise_dropout=noise_dropout,
238
+ score_corrector=score_corrector,
239
+ corrector_kwargs=corrector_kwargs,
240
+ unconditional_guidance_scale=unconditional_guidance_scale,
241
+ unconditional_conditioning=unconditional_conditioning,
242
+ old_eps=old_eps,
243
+ t_next=ts_next,
244
+ )
245
+ img, pred_x0, e_t = outs
246
+ old_eps.append(e_t)
247
+ if len(old_eps) >= 4:
248
+ old_eps.pop(0)
249
+ if callback:
250
+ callback(i)
251
+ if img_callback:
252
+ img_callback(pred_x0, i)
253
+
254
+ if index % log_every_t == 0 or index == total_steps - 1:
255
+ intermediates["x_inter"].append(img)
256
+ intermediates["pred_x0"].append(pred_x0)
257
+
258
+ return img, intermediates
259
+
260
+ @torch.no_grad()
261
+ def p_sample_plms(
262
+ self,
263
+ x,
264
+ c,
265
+ t,
266
+ index,
267
+ repeat_noise=False,
268
+ use_original_steps=False,
269
+ quantize_denoised=False,
270
+ temperature=1.0,
271
+ noise_dropout=0.0,
272
+ score_corrector=None,
273
+ corrector_kwargs=None,
274
+ unconditional_guidance_scale=1.0,
275
+ unconditional_conditioning=None,
276
+ old_eps=None,
277
+ t_next=None,
278
+ ):
279
+ b, *_, device = *x.shape, x.device
280
+
281
+ def get_model_output(x, t):
282
+ if (
283
+ unconditional_conditioning is None
284
+ or unconditional_guidance_scale == 1.0
285
+ ):
286
+ e_t = self.model.apply_model(x, t, c)
287
+ else:
288
+ x_in = torch.cat([x] * 2)
289
+ t_in = torch.cat([t] * 2)
290
+ c_in = torch.cat([unconditional_conditioning, c])
291
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
292
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
293
+
294
+ if score_corrector is not None:
295
+ assert self.model.parameterization == "eps"
296
+ e_t = score_corrector.modify_score(
297
+ self.model, e_t, x, t, c, **corrector_kwargs
298
+ )
299
+
300
+ return e_t
301
+
302
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
303
+ alphas_prev = (
304
+ self.model.alphas_cumprod_prev
305
+ if use_original_steps
306
+ else self.ddim_alphas_prev
307
+ )
308
+ sqrt_one_minus_alphas = (
309
+ self.model.sqrt_one_minus_alphas_cumprod
310
+ if use_original_steps
311
+ else self.ddim_sqrt_one_minus_alphas
312
+ )
313
+ sigmas = (
314
+ self.model.ddim_sigmas_for_original_num_steps
315
+ if use_original_steps
316
+ else self.ddim_sigmas
317
+ )
318
+
319
+ def get_x_prev_and_pred_x0(e_t, index):
320
+ # select parameters corresponding to the currently considered timestep
321
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
322
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
323
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
324
+ sqrt_one_minus_at = torch.full(
325
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
326
+ )
327
+
328
+ # current prediction for x_0
329
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
330
+ if quantize_denoised:
331
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
332
+ # direction pointing to x_t
333
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
334
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
335
+ if noise_dropout > 0.0:
336
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
337
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
338
+ return x_prev, pred_x0
339
+
340
+ e_t = get_model_output(x, t)
341
+ if len(old_eps) == 0:
342
+ # Pseudo Improved Euler (2nd order)
343
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
344
+ e_t_next = get_model_output(x_prev, t_next)
345
+ e_t_prime = (e_t + e_t_next) / 2
346
+ elif len(old_eps) == 1:
347
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
348
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
349
+ elif len(old_eps) == 2:
350
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
351
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
352
+ elif len(old_eps) >= 3:
353
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
354
+ e_t_prime = (
355
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
356
+ ) / 24
357
+
358
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
359
+
360
+ return x_prev, pred_x0, e_t
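The branching on `len(old_eps)` at the end of `p_sample_plms` above is a pseudo linear multistep (Adams-Bashforth) scheme: once earlier noise estimates have been buffered, the current estimate is replaced by a fixed weighted combination before the usual x_prev update. A small sketch of just that combination is shown here; the random tensors are placeholders standing in for real noise predictions.

```python
import torch

def combine_eps(e_t, old_eps):
    # Mirrors the Adams-Bashforth weights used in p_sample_plms above.
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    if len(old_eps) >= 3:
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
    # With an empty buffer the sampler instead averages e_t with a look-ahead
    # estimate at t_next (the "Pseudo Improved Euler" branch), which needs the model.
    return e_t

shape = (1, 8, 64, 16)                               # placeholder latent shape
old_eps = [torch.randn(shape) for _ in range(3)]     # stand-ins for buffered predictions
e_t = torch.randn(shape)
print(combine_eps(e_t, old_eps).shape)               # torch.Size([1, 8, 64, 16])
```

Because `plms_sampling` pops the oldest entry once the buffer reaches four, after the first few steps every update uses the fourth-order weights.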
audiosr/latent_diffusion/modules/attention.py CHANGED
@@ -1,467 +1,467 @@
1
- from inspect import isfunction
2
- import math
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn, einsum
6
- from einops import rearrange, repeat
7
-
8
- from audiosr.latent_diffusion.modules.diffusionmodules.util import checkpoint
9
-
10
-
11
- def exists(val):
12
- return val is not None
13
-
14
-
15
- def uniq(arr):
16
- return {el: True for el in arr}.keys()
17
-
18
-
19
- def default(val, d):
20
- if exists(val):
21
- return val
22
- return d() if isfunction(d) else d
23
-
24
-
25
- def max_neg_value(t):
26
- return -torch.finfo(t.dtype).max
27
-
28
-
29
- def init_(tensor):
30
- dim = tensor.shape[-1]
31
- std = 1 / math.sqrt(dim)
32
- tensor.uniform_(-std, std)
33
- return tensor
34
-
35
-
36
- # feedforward
37
- class GEGLU(nn.Module):
38
- def __init__(self, dim_in, dim_out):
39
- super().__init__()
40
- self.proj = nn.Linear(dim_in, dim_out * 2)
41
-
42
- def forward(self, x):
43
- x, gate = self.proj(x).chunk(2, dim=-1)
44
- return x * F.gelu(gate)
45
-
46
-
47
- class FeedForward(nn.Module):
48
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
- super().__init__()
50
- inner_dim = int(dim * mult)
51
- dim_out = default(dim_out, dim)
52
- project_in = (
53
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
- if not glu
55
- else GEGLU(dim, inner_dim)
56
- )
57
-
58
- self.net = nn.Sequential(
59
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
- )
61
-
62
- def forward(self, x):
63
- return self.net(x)
64
-
65
-
66
- def zero_module(module):
67
- """
68
- Zero out the parameters of a module and return it.
69
- """
70
- for p in module.parameters():
71
- p.detach().zero_()
72
- return module
73
-
74
-
75
- def Normalize(in_channels):
76
- return torch.nn.GroupNorm(
77
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
- )
79
-
80
-
81
- class LinearAttention(nn.Module):
82
- def __init__(self, dim, heads=4, dim_head=32):
83
- super().__init__()
84
- self.heads = heads
85
- hidden_dim = dim_head * heads
86
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
-
89
- def forward(self, x):
90
- b, c, h, w = x.shape
91
- qkv = self.to_qkv(x)
92
- q, k, v = rearrange(
93
- qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
- )
95
- k = k.softmax(dim=-1)
96
- context = torch.einsum("bhdn,bhen->bhde", k, v)
97
- out = torch.einsum("bhde,bhdn->bhen", context, q)
98
- out = rearrange(
99
- out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
- )
101
- return self.to_out(out)
102
-
103
-
104
- class SpatialSelfAttention(nn.Module):
105
- def __init__(self, in_channels):
106
- super().__init__()
107
- self.in_channels = in_channels
108
-
109
- self.norm = Normalize(in_channels)
110
- self.q = torch.nn.Conv2d(
111
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
- )
113
- self.k = torch.nn.Conv2d(
114
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
- )
116
- self.v = torch.nn.Conv2d(
117
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
- )
119
- self.proj_out = torch.nn.Conv2d(
120
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
- )
122
-
123
- def forward(self, x):
124
- h_ = x
125
- h_ = self.norm(h_)
126
- q = self.q(h_)
127
- k = self.k(h_)
128
- v = self.v(h_)
129
-
130
- # compute attention
131
- b, c, h, w = q.shape
132
- q = rearrange(q, "b c h w -> b (h w) c")
133
- k = rearrange(k, "b c h w -> b c (h w)")
134
- w_ = torch.einsum("bij,bjk->bik", q, k)
135
-
136
- w_ = w_ * (int(c) ** (-0.5))
137
- w_ = torch.nn.functional.softmax(w_, dim=2)
138
-
139
- # attend to values
140
- v = rearrange(v, "b c h w -> b c (h w)")
141
- w_ = rearrange(w_, "b i j -> b j i")
142
- h_ = torch.einsum("bij,bjk->bik", v, w_)
143
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
- h_ = self.proj_out(h_)
145
-
146
- return x + h_
147
-
148
-
149
- # class CrossAttention(nn.Module):
150
- # """
151
- # ### Cross Attention Layer
152
- # This falls-back to self-attention when conditional embeddings are not specified.
153
- # """
154
-
155
- # use_flash_attention: bool = True
156
-
157
- # # use_flash_attention: bool = False
158
- # def __init__(
159
- # self,
160
- # query_dim,
161
- # context_dim=None,
162
- # heads=8,
163
- # dim_head=64,
164
- # dropout=0.0,
165
- # is_inplace: bool = True,
166
- # ):
167
- # # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
168
- # """
169
- # :param d_model: is the input embedding size
170
- # :param n_heads: is the number of attention heads
171
- # :param d_head: is the size of a attention head
172
- # :param d_cond: is the size of the conditional embeddings
173
- # :param is_inplace: specifies whether to perform the attention softmax computation inplace to
174
- # save memory
175
- # """
176
- # super().__init__()
177
-
178
- # self.is_inplace = is_inplace
179
- # self.n_heads = heads
180
- # self.d_head = dim_head
181
-
182
- # # Attention scaling factor
183
- # self.scale = dim_head**-0.5
184
-
185
- # # The normal self-attention layer
186
- # if context_dim is None:
187
- # context_dim = query_dim
188
-
189
- # # Query, key and value mappings
190
- # d_attn = dim_head * heads
191
- # self.to_q = nn.Linear(query_dim, d_attn, bias=False)
192
- # self.to_k = nn.Linear(context_dim, d_attn, bias=False)
193
- # self.to_v = nn.Linear(context_dim, d_attn, bias=False)
194
-
195
- # # Final linear layer
196
- # self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
197
-
198
- # # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
199
- # # Flash attention is only used if it's installed
200
- # # and `CrossAttention.use_flash_attention` is set to `True`.
201
- # try:
202
- # # You can install flash attention by cloning their Github repo,
203
- # # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
204
- # # and then running `python setup.py install`
205
- # from flash_attn.flash_attention import FlashAttention
206
-
207
- # self.flash = FlashAttention()
208
- # # Set the scale for scaled dot-product attention.
209
- # self.flash.softmax_scale = self.scale
210
- # # Set to `None` if it's not installed
211
- # except ImportError:
212
- # self.flash = None
213
-
214
- # def forward(self, x, context=None, mask=None):
215
- # """
216
- # :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
217
- # :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
218
- # """
219
-
220
- # # If `cond` is `None` we perform self attention
221
- # has_cond = context is not None
222
- # if not has_cond:
223
- # context = x
224
-
225
- # # Get query, key and value vectors
226
- # q = self.to_q(x)
227
- # k = self.to_k(context)
228
- # v = self.to_v(context)
229
-
230
- # # Use flash attention if it's available and the head size is less than or equal to `128`
231
- # if (
232
- # CrossAttention.use_flash_attention
233
- # and self.flash is not None
234
- # and not has_cond
235
- # and self.d_head <= 128
236
- # ):
237
- # return self.flash_attention(q, k, v)
238
- # # Otherwise, fallback to normal attention
239
- # else:
240
- # return self.normal_attention(q, k, v)
241
-
242
- # def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
243
- # """
244
- # #### Flash Attention
245
- # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
246
- # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
247
- # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
- # """
249
-
250
- # # Get batch size and number of elements along sequence axis (`width * height`)
251
- # batch_size, seq_len, _ = q.shape
252
-
253
- # # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
254
- # # shape `[batch_size, seq_len, 3, n_heads * d_head]`
255
- # qkv = torch.stack((q, k, v), dim=2)
256
- # # Split the heads
257
- # qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
258
-
259
- # # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
260
- # # fit this size.
261
- # if self.d_head <= 32:
262
- # pad = 32 - self.d_head
263
- # elif self.d_head <= 64:
264
- # pad = 64 - self.d_head
265
- # elif self.d_head <= 128:
266
- # pad = 128 - self.d_head
267
- # else:
268
- # raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
269
-
270
- # # Pad the heads
271
- # if pad:
272
- # qkv = torch.cat(
273
- # (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
274
- # )
275
-
276
- # # Compute attention
277
- # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
278
- # # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
279
- # # TODO here I add the dtype changing
280
- # out, _ = self.flash(qkv.type(torch.float16))
281
- # # Truncate the extra head size
282
- # out = out[:, :, :, : self.d_head].float()
283
- # # Reshape to `[batch_size, seq_len, n_heads * d_head]`
284
- # out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
285
-
286
- # # Map to `[batch_size, height * width, d_model]` with a linear layer
287
- # return self.to_out(out)
288
-
289
- # def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
290
- # """
291
- # #### Normal Attention
292
-
293
- # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
294
- # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
295
- # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
296
- # """
297
-
298
- # # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
299
- # q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
300
- # k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
301
- # v = v.view(*v.shape[:2], self.n_heads, -1)
302
-
303
- # # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
304
- # attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
305
-
306
- # # Compute softmax
307
- # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
308
- # if self.is_inplace:
309
- # half = attn.shape[0] // 2
310
- # attn[half:] = attn[half:].softmax(dim=-1)
311
- # attn[:half] = attn[:half].softmax(dim=-1)
312
- # else:
313
- # attn = attn.softmax(dim=-1)
314
-
315
- # # Compute attention output
316
- # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
317
- # # attn: [bs, 20, 64, 1]
318
- # # v: [bs, 1, 20, 32]
319
- # out = torch.einsum("bhij,bjhd->bihd", attn, v)
320
- # # Reshape to `[batch_size, height * width, n_heads * d_head]`
321
- # out = out.reshape(*out.shape[:2], -1)
322
- # # Map to `[batch_size, height * width, d_model]` with a linear layer
323
- # return self.to_out(out)
324
-
325
-
326
- class CrossAttention(nn.Module):
327
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
328
- super().__init__()
329
- inner_dim = dim_head * heads
330
- context_dim = default(context_dim, query_dim)
331
-
332
- self.scale = dim_head**-0.5
333
- self.heads = heads
334
-
335
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
336
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
337
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
338
-
339
- self.to_out = nn.Sequential(
340
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
341
- )
342
-
343
- def forward(self, x, context=None, mask=None):
344
- h = self.heads
345
-
346
- q = self.to_q(x)
347
- context = default(context, x)
348
-
349
- k = self.to_k(context)
350
- v = self.to_v(context)
351
-
352
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
353
-
354
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
355
-
356
- if exists(mask):
357
- mask = rearrange(mask, "b ... -> b (...)")
358
- max_neg_value = -torch.finfo(sim.dtype).max
359
- mask = repeat(mask, "b j -> (b h) () j", h=h)
360
- sim.masked_fill_(~(mask == 1), max_neg_value)
361
-
362
- # attention, what we cannot get enough of
363
- attn = sim.softmax(dim=-1)
364
-
365
- out = einsum("b i j, b j d -> b i d", attn, v)
366
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
367
- return self.to_out(out)
368
-
369
-
370
- class BasicTransformerBlock(nn.Module):
371
- def __init__(
372
- self,
373
- dim,
374
- n_heads,
375
- d_head,
376
- dropout=0.0,
377
- context_dim=None,
378
- gated_ff=True,
379
- checkpoint=True,
380
- ):
381
- super().__init__()
382
- self.attn1 = CrossAttention(
383
- query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
384
- ) # is a self-attention
385
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
386
- self.attn2 = CrossAttention(
387
- query_dim=dim,
388
- context_dim=context_dim,
389
- heads=n_heads,
390
- dim_head=d_head,
391
- dropout=dropout,
392
- ) # is self-attn if context is none
393
- self.norm1 = nn.LayerNorm(dim)
394
- self.norm2 = nn.LayerNorm(dim)
395
- self.norm3 = nn.LayerNorm(dim)
396
- self.checkpoint = checkpoint
397
-
398
- def forward(self, x, context=None, mask=None):
399
- if context is None:
400
- return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
401
- else:
402
- return checkpoint(
403
- self._forward, (x, context, mask), self.parameters(), self.checkpoint
404
- )
405
-
406
- def _forward(self, x, context=None, mask=None):
407
- x = self.attn1(self.norm1(x)) + x
408
- x = self.attn2(self.norm2(x), context=context, mask=mask) + x
409
- x = self.ff(self.norm3(x)) + x
410
- return x
411
-
412
-
413
- class SpatialTransformer(nn.Module):
414
- """
415
- Transformer block for image-like data.
416
- First, project the input (aka embedding)
417
- and reshape to b, t, d.
418
- Then apply standard transformer action.
419
- Finally, reshape to image
420
- """
421
-
422
- def __init__(
423
- self,
424
- in_channels,
425
- n_heads,
426
- d_head,
427
- depth=1,
428
- dropout=0.0,
429
- context_dim=None,
430
- ):
431
- super().__init__()
432
-
433
- context_dim = context_dim
434
-
435
- self.in_channels = in_channels
436
- inner_dim = n_heads * d_head
437
- self.norm = Normalize(in_channels)
438
-
439
- self.proj_in = nn.Conv2d(
440
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
441
- )
442
-
443
- self.transformer_blocks = nn.ModuleList(
444
- [
445
- BasicTransformerBlock(
446
- inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
447
- )
448
- for d in range(depth)
449
- ]
450
- )
451
-
452
- self.proj_out = zero_module(
453
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
454
- )
455
-
456
- def forward(self, x, context=None, mask=None):
457
- # note: if no context is given, cross-attention defaults to self-attention
458
- b, c, h, w = x.shape
459
- x_in = x
460
- x = self.norm(x)
461
- x = self.proj_in(x)
462
- x = rearrange(x, "b c h w -> b (h w) c")
463
- for block in self.transformer_blocks:
464
- x = block(x, context=context, mask=mask)
465
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
466
- x = self.proj_out(x)
467
- return x + x_in
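The `CrossAttention.forward` above (the file is re-added below) is plain scaled dot-product attention over flattened positions, written with `einsum`/`rearrange`. A compact, self-contained sketch of that pattern follows; the head count, head dimension and token counts are made-up sizes, not the model's configuration, and the to_q/to_k/to_v projections are omitted.

```python
import torch
from torch import einsum
from einops import rearrange

heads, dim_head = 8, 64                     # placeholder sizes
q = torch.randn(2, 256, heads * dim_head)   # queries after to_q: (batch, query tokens, inner_dim)
k = torch.randn(2, 77, heads * dim_head)    # keys after to_k, from the context sequence
v = torch.randn(2, 77, heads * dim_head)    # values after to_v

# fold the head dimension into the batch, as in CrossAttention.forward
q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=heads) for t in (q, k, v))

sim = einsum("b i d, b j d -> b i j", q, k) * dim_head ** -0.5   # scaled dot products
attn = sim.softmax(dim=-1)                                        # attention weights over context
out = einsum("b i j, b j d -> b i d", attn, v)                    # weighted sum of values
out = rearrange(out, "(b h) n d -> b n (h d)", h=heads)           # back to (batch, tokens, inner_dim)
print(out.shape)  # torch.Size([2, 256, 512])
```

When no context is passed, `default(context, x)` makes this the self-attention path, which is how `BasicTransformerBlock.attn1` uses the same class.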
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from audiosr.latent_diffusion.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = (
53
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
+ if not glu
55
+ else GEGLU(dim, inner_dim)
56
+ )
57
+
58
+ self.net = nn.Sequential(
59
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def Normalize(in_channels):
76
+ return torch.nn.GroupNorm(
77
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
+ )
79
+
80
+
81
+ class LinearAttention(nn.Module):
82
+ def __init__(self, dim, heads=4, dim_head=32):
83
+ super().__init__()
84
+ self.heads = heads
85
+ hidden_dim = dim_head * heads
86
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
+
89
+ def forward(self, x):
90
+ b, c, h, w = x.shape
91
+ qkv = self.to_qkv(x)
92
+ q, k, v = rearrange(
93
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
+ )
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
97
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
98
+ out = rearrange(
99
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
+ )
101
+ return self.to_out(out)
102
+
103
+
104
+ class SpatialSelfAttention(nn.Module):
105
+ def __init__(self, in_channels):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+
109
+ self.norm = Normalize(in_channels)
110
+ self.q = torch.nn.Conv2d(
111
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+ self.k = torch.nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
+ )
116
+ self.v = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
+ )
119
+ self.proj_out = torch.nn.Conv2d(
120
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
+ )
122
+
123
+ def forward(self, x):
124
+ h_ = x
125
+ h_ = self.norm(h_)
126
+ q = self.q(h_)
127
+ k = self.k(h_)
128
+ v = self.v(h_)
129
+
130
+ # compute attention
131
+ b, c, h, w = q.shape
132
+ q = rearrange(q, "b c h w -> b (h w) c")
133
+ k = rearrange(k, "b c h w -> b c (h w)")
134
+ w_ = torch.einsum("bij,bjk->bik", q, k)
135
+
136
+ w_ = w_ * (int(c) ** (-0.5))
137
+ w_ = torch.nn.functional.softmax(w_, dim=2)
138
+
139
+ # attend to values
140
+ v = rearrange(v, "b c h w -> b c (h w)")
141
+ w_ = rearrange(w_, "b i j -> b j i")
142
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
143
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
+ h_ = self.proj_out(h_)
145
+
146
+ return x + h_
147
+
148
+
149
+ # class CrossAttention(nn.Module):
150
+ # """
151
+ # ### Cross Attention Layer
152
+ # This falls back to self-attention when conditional embeddings are not specified.
153
+ # """
154
+
155
+ # use_flash_attention: bool = True
156
+
157
+ # # use_flash_attention: bool = False
158
+ # def __init__(
159
+ # self,
160
+ # query_dim,
161
+ # context_dim=None,
162
+ # heads=8,
163
+ # dim_head=64,
164
+ # dropout=0.0,
165
+ # is_inplace: bool = True,
166
+ # ):
167
+ # # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
168
+ # """
169
+ # :param d_model: is the input embedding size
170
+ # :param n_heads: is the number of attention heads
171
+ # :param d_head: is the size of an attention head
172
+ # :param d_cond: is the size of the conditional embeddings
173
+ # :param is_inplace: specifies whether to perform the attention softmax computation inplace to
174
+ # save memory
175
+ # """
176
+ # super().__init__()
177
+
178
+ # self.is_inplace = is_inplace
179
+ # self.n_heads = heads
180
+ # self.d_head = dim_head
181
+
182
+ # # Attention scaling factor
183
+ # self.scale = dim_head**-0.5
184
+
185
+ # # The normal self-attention layer
186
+ # if context_dim is None:
187
+ # context_dim = query_dim
188
+
189
+ # # Query, key and value mappings
190
+ # d_attn = dim_head * heads
191
+ # self.to_q = nn.Linear(query_dim, d_attn, bias=False)
192
+ # self.to_k = nn.Linear(context_dim, d_attn, bias=False)
193
+ # self.to_v = nn.Linear(context_dim, d_attn, bias=False)
194
+
195
+ # # Final linear layer
196
+ # self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
197
+
198
+ # # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
199
+ # # Flash attention is only used if it's installed
200
+ # # and `CrossAttention.use_flash_attention` is set to `True`.
201
+ # try:
202
+ # # You can install flash attention by cloning their Github repo,
203
+ # # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
204
+ # # and then running `python setup.py install`
205
+ # from flash_attn.flash_attention import FlashAttention
206
+
207
+ # self.flash = FlashAttention()
208
+ # # Set the scale for scaled dot-product attention.
209
+ # self.flash.softmax_scale = self.scale
210
+ # # Set to `None` if it's not installed
211
+ # except ImportError:
212
+ # self.flash = None
213
+
214
+ # def forward(self, x, context=None, mask=None):
215
+ # """
216
+ # :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
217
+ # :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
218
+ # """
219
+
220
+ # # If `cond` is `None` we perform self attention
221
+ # has_cond = context is not None
222
+ # if not has_cond:
223
+ # context = x
224
+
225
+ # # Get query, key and value vectors
226
+ # q = self.to_q(x)
227
+ # k = self.to_k(context)
228
+ # v = self.to_v(context)
229
+
230
+ # # Use flash attention if it's available and the head size is less than or equal to `128`
231
+ # if (
232
+ # CrossAttention.use_flash_attention
233
+ # and self.flash is not None
234
+ # and not has_cond
235
+ # and self.d_head <= 128
236
+ # ):
237
+ # return self.flash_attention(q, k, v)
238
+ # # Otherwise, fallback to normal attention
239
+ # else:
240
+ # return self.normal_attention(q, k, v)
241
+
242
+ # def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
243
+ # """
244
+ # #### Flash Attention
245
+ # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
246
+ # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
247
+ # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
+ # """
249
+
250
+ # # Get batch size and number of elements along sequence axis (`width * height`)
251
+ # batch_size, seq_len, _ = q.shape
252
+
253
+ # # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
254
+ # # shape `[batch_size, seq_len, 3, n_heads * d_head]`
255
+ # qkv = torch.stack((q, k, v), dim=2)
256
+ # # Split the heads
257
+ # qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
258
+
259
+ # # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
260
+ # # fit this size.
261
+ # if self.d_head <= 32:
262
+ # pad = 32 - self.d_head
263
+ # elif self.d_head <= 64:
264
+ # pad = 64 - self.d_head
265
+ # elif self.d_head <= 128:
266
+ # pad = 128 - self.d_head
267
+ # else:
268
+ # raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
269
+
270
+ # # Pad the heads
271
+ # if pad:
272
+ # qkv = torch.cat(
273
+ # (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
274
+ # )
275
+
276
+ # # Compute attention
277
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
278
+ # # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
279
+ # # TODO here I add the dtype changing
280
+ # out, _ = self.flash(qkv.type(torch.float16))
281
+ # # Truncate the extra head size
282
+ # out = out[:, :, :, : self.d_head].float()
283
+ # # Reshape to `[batch_size, seq_len, n_heads * d_head]`
284
+ # out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
285
+
286
+ # # Map to `[batch_size, height * width, d_model]` with a linear layer
287
+ # return self.to_out(out)
288
+
289
+ # def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
290
+ # """
291
+ # #### Normal Attention
292
+
293
+ # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
294
+ # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
295
+ # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
296
+ # """
297
+
298
+ # # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
299
+ # q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
300
+ # k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
301
+ # v = v.view(*v.shape[:2], self.n_heads, -1)
302
+
303
+ # # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
304
+ # attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
305
+
306
+ # # Compute softmax
307
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
308
+ # if self.is_inplace:
309
+ # half = attn.shape[0] // 2
310
+ # attn[half:] = attn[half:].softmax(dim=-1)
311
+ # attn[:half] = attn[:half].softmax(dim=-1)
312
+ # else:
313
+ # attn = attn.softmax(dim=-1)
314
+
315
+ # # Compute attention output
316
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
317
+ # # attn: [bs, 20, 64, 1]
318
+ # # v: [bs, 1, 20, 32]
319
+ # out = torch.einsum("bhij,bjhd->bihd", attn, v)
320
+ # # Reshape to `[batch_size, height * width, n_heads * d_head]`
321
+ # out = out.reshape(*out.shape[:2], -1)
322
+ # # Map to `[batch_size, height * width, d_model]` with a linear layer
323
+ # return self.to_out(out)
324
+
325
+
326
+ class CrossAttention(nn.Module):
327
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
328
+ super().__init__()
329
+ inner_dim = dim_head * heads
330
+ context_dim = default(context_dim, query_dim)
331
+
332
+ self.scale = dim_head**-0.5
333
+ self.heads = heads
334
+
335
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
336
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
337
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
338
+
339
+ self.to_out = nn.Sequential(
340
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
341
+ )
342
+
343
+ def forward(self, x, context=None, mask=None):
344
+ h = self.heads
345
+
346
+ q = self.to_q(x)
347
+ context = default(context, x)
348
+
349
+ k = self.to_k(context)
350
+ v = self.to_v(context)
351
+
352
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
353
+
354
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
355
+
356
+ if exists(mask):
357
+ mask = rearrange(mask, "b ... -> b (...)")
358
+ max_neg_value = -torch.finfo(sim.dtype).max
359
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
360
+ sim.masked_fill_(~(mask == 1), max_neg_value)
361
+
362
+ # attention, what we cannot get enough of
363
+ attn = sim.softmax(dim=-1)
364
+
365
+ out = einsum("b i j, b j d -> b i d", attn, v)
366
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
367
+ return self.to_out(out)
368
+
369
+
370
+ class BasicTransformerBlock(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ n_heads,
375
+ d_head,
376
+ dropout=0.0,
377
+ context_dim=None,
378
+ gated_ff=True,
379
+ checkpoint=True,
380
+ ):
381
+ super().__init__()
382
+ self.attn1 = CrossAttention(
383
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
384
+ ) # is a self-attention
385
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
386
+ self.attn2 = CrossAttention(
387
+ query_dim=dim,
388
+ context_dim=context_dim,
389
+ heads=n_heads,
390
+ dim_head=d_head,
391
+ dropout=dropout,
392
+ ) # is self-attn if context is none
393
+ self.norm1 = nn.LayerNorm(dim)
394
+ self.norm2 = nn.LayerNorm(dim)
395
+ self.norm3 = nn.LayerNorm(dim)
396
+ self.checkpoint = checkpoint
397
+
398
+ def forward(self, x, context=None, mask=None):
399
+ if context is None:
400
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
401
+ else:
402
+ return checkpoint(
403
+ self._forward, (x, context, mask), self.parameters(), self.checkpoint
404
+ )
405
+
406
+ def _forward(self, x, context=None, mask=None):
407
+ x = self.attn1(self.norm1(x)) + x
408
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
409
+ x = self.ff(self.norm3(x)) + x
410
+ return x
411
+
412
+
413
+ class SpatialTransformer(nn.Module):
414
+ """
415
+ Transformer block for image-like data.
416
+ First, project the input (aka embedding)
417
+ and reshape to b, t, d.
418
+ Then apply standard transformer action.
419
+ Finally, reshape to image
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ in_channels,
425
+ n_heads,
426
+ d_head,
427
+ depth=1,
428
+ dropout=0.0,
429
+ context_dim=None,
430
+ ):
431
+ super().__init__()
432
+
433
+ context_dim = context_dim
434
+
435
+ self.in_channels = in_channels
436
+ inner_dim = n_heads * d_head
437
+ self.norm = Normalize(in_channels)
438
+
439
+ self.proj_in = nn.Conv2d(
440
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
441
+ )
442
+
443
+ self.transformer_blocks = nn.ModuleList(
444
+ [
445
+ BasicTransformerBlock(
446
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
447
+ )
448
+ for d in range(depth)
449
+ ]
450
+ )
451
+
452
+ self.proj_out = zero_module(
453
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
454
+ )
455
+
456
+ def forward(self, x, context=None, mask=None):
457
+ # note: if no context is given, cross-attention defaults to self-attention
458
+ b, c, h, w = x.shape
459
+ x_in = x
460
+ x = self.norm(x)
461
+ x = self.proj_in(x)
462
+ x = rearrange(x, "b c h w -> b (h w) c")
463
+ for block in self.transformer_blocks:
464
+ x = block(x, context=context, mask=mask)
465
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
466
+ x = self.proj_out(x)
467
+ return x + x_in
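A minimal shape-check sketch for the attention modules above. Illustrative only: it assumes this file is importable as audiosr.latent_diffusion.modules.attention (matching the package layout used by its own imports) and that torch and einops are installed; the tensor sizes are arbitrary.

import torch
from audiosr.latent_diffusion.modules.attention import CrossAttention, SpatialTransformer

# CrossAttention with no context falls back to self-attention over x.
attn = CrossAttention(query_dim=64, heads=4, dim_head=16)
x = torch.randn(2, 32, 64)            # [batch, tokens, query_dim]
print(attn(x).shape)                  # torch.Size([2, 32, 64])

# SpatialTransformer flattens (h, w) into tokens, cross-attends over an
# external conditioning sequence, then reshapes back to image-like features.
st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=128)
feats = torch.randn(2, 64, 16, 8)     # [batch, channels, height, width]
cond = torch.randn(2, 10, 128)        # [batch, cond_tokens, context_dim]
print(st(feats, context=cond).shape)  # torch.Size([2, 64, 16, 8])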
audiosr/latent_diffusion/modules/audiomae/AudioMAE.py CHANGED
@@ -1,149 +1,149 @@
1
- """
2
- Reference Repo: https://github.com/facebookresearch/AudioMAE
3
- """
4
-
5
- import torch
6
- import torch.nn as nn
7
- from timm.models.layers import to_2tuple
8
- import audiosr.latent_diffusion.modules.audiomae.models_vit as models_vit
9
- import audiosr.latent_diffusion.modules.audiomae.models_mae as models_mae
10
-
11
- # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
12
-
13
-
14
- class PatchEmbed_new(nn.Module):
15
- """Flexible Image to Patch Embedding"""
16
-
17
- def __init__(
18
- self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
19
- ):
20
- super().__init__()
21
- img_size = to_2tuple(img_size)
22
- patch_size = to_2tuple(patch_size)
23
- stride = to_2tuple(stride)
24
-
25
- self.img_size = img_size
26
- self.patch_size = patch_size
27
-
28
- self.proj = nn.Conv2d(
29
- in_chans, embed_dim, kernel_size=patch_size, stride=stride
30
- ) # with overlapped patches
31
- # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
32
-
33
- # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
34
- # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
35
- _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
36
- self.patch_hw = (h, w)
37
- self.num_patches = h * w
38
-
39
- def get_output_shape(self, img_size):
40
- # todo: don't be lazy..
41
- return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
42
-
43
- def forward(self, x):
44
- B, C, H, W = x.shape
45
- # FIXME look at relaxing size constraints
46
- # assert H == self.img_size[0] and W == self.img_size[1], \
47
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
48
- x = self.proj(x)
49
- x = x.flatten(2).transpose(1, 2)
50
- return x
51
-
52
-
53
- class AudioMAE(nn.Module):
54
- """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
55
-
56
- def __init__(
57
- self,
58
- ):
59
- super().__init__()
60
- model = models_vit.__dict__["vit_base_patch16"](
61
- num_classes=527,
62
- drop_path_rate=0.1,
63
- global_pool=True,
64
- mask_2d=True,
65
- use_custom_patch=False,
66
- )
67
-
68
- img_size = (1024, 128)
69
- emb_dim = 768
70
-
71
- model.patch_embed = PatchEmbed_new(
72
- img_size=img_size,
73
- patch_size=(16, 16),
74
- in_chans=1,
75
- embed_dim=emb_dim,
76
- stride=16,
77
- )
78
- num_patches = model.patch_embed.num_patches
79
- # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
80
- model.pos_embed = nn.Parameter(
81
- torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
82
- ) # fixed sin-cos embedding
83
-
84
- # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
85
- # checkpoint = torch.load(checkpoint_path, map_location='cpu')
86
- # msg = model.load_state_dict(checkpoint['model'], strict=False)
87
- # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
88
-
89
- self.model = model
90
-
91
- def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
92
- """
93
- x: mel fbank [Batch, 1, T, F]
94
- mask_t_prob: 'T masking ratio (percentage of removed patches).'
95
- mask_f_prob: 'F masking ratio (percentage of removed patches).'
96
- """
97
- return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
98
-
99
-
100
- class Vanilla_AudioMAE(nn.Module):
101
- """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
102
-
103
- def __init__(
104
- self,
105
- ):
106
- super().__init__()
107
- model = models_mae.__dict__["mae_vit_base_patch16"](
108
- in_chans=1, audio_exp=True, img_size=(1024, 128)
109
- )
110
-
111
- # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
112
- # checkpoint = torch.load(checkpoint_path, map_location='cpu')
113
- # msg = model.load_state_dict(checkpoint['model'], strict=False)
114
-
115
- # Skip the missing keys of decoder modules (not required)
116
- # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
117
-
118
- self.model = model.eval()
119
-
120
- def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
121
- """
122
- x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
123
- mask_ratio: 'masking ratio (percentage of removed patches).'
124
- """
125
- with torch.no_grad():
126
- # embed: [B, 513, 768] for mask_ratio=0.0
127
- if no_mask:
128
- if no_average:
129
- raise RuntimeError("This function is deprecated")
130
- embed = self.model.forward_encoder_no_random_mask_no_average(
131
- x
132
- ) # mask_ratio
133
- else:
134
- embed = self.model.forward_encoder_no_mask(x) # mask_ratio
135
- else:
136
- raise RuntimeError("This function is deprecated")
137
- embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
138
- return embed
139
-
140
-
141
- if __name__ == "__main__":
142
- model = Vanilla_AudioMAE().cuda()
143
- input = torch.randn(4, 1, 1024, 128).cuda()
144
- print("The first run")
145
- embed = model(input, mask_ratio=0.0, no_mask=True)
146
- print(embed)
147
- print("The second run")
148
- embed = model(input, mask_ratio=0.0)
149
- print(embed)
 
1
+ """
2
+ Reference Repo: https://github.com/facebookresearch/AudioMAE
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.models.layers import to_2tuple
8
+ import audiosr.latent_diffusion.modules.audiomae.models_vit as models_vit
9
+ import audiosr.latent_diffusion.modules.audiomae.models_mae as models_mae
10
+
11
+ # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
12
+
13
+
14
+ class PatchEmbed_new(nn.Module):
15
+ """Flexible Image to Patch Embedding"""
16
+
17
+ def __init__(
18
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
19
+ ):
20
+ super().__init__()
21
+ img_size = to_2tuple(img_size)
22
+ patch_size = to_2tuple(patch_size)
23
+ stride = to_2tuple(stride)
24
+
25
+ self.img_size = img_size
26
+ self.patch_size = patch_size
27
+
28
+ self.proj = nn.Conv2d(
29
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
30
+ ) # with overlapped patches
31
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
32
+
33
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
34
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
35
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
36
+ self.patch_hw = (h, w)
37
+ self.num_patches = h * w
38
+
39
+ def get_output_shape(self, img_size):
40
+ # todo: don't be lazy..
41
+ return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
42
+
43
+ def forward(self, x):
44
+ B, C, H, W = x.shape
45
+ # FIXME look at relaxing size constraints
46
+ # assert H == self.img_size[0] and W == self.img_size[1], \
47
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
48
+ x = self.proj(x)
49
+ x = x.flatten(2).transpose(1, 2)
50
+ return x
51
+
52
+
53
+ class AudioMAE(nn.Module):
54
+ """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
55
+
56
+ def __init__(
57
+ self,
58
+ ):
59
+ super().__init__()
60
+ model = models_vit.__dict__["vit_base_patch16"](
61
+ num_classes=527,
62
+ drop_path_rate=0.1,
63
+ global_pool=True,
64
+ mask_2d=True,
65
+ use_custom_patch=False,
66
+ )
67
+
68
+ img_size = (1024, 128)
69
+ emb_dim = 768
70
+
71
+ model.patch_embed = PatchEmbed_new(
72
+ img_size=img_size,
73
+ patch_size=(16, 16),
74
+ in_chans=1,
75
+ embed_dim=emb_dim,
76
+ stride=16,
77
+ )
78
+ num_patches = model.patch_embed.num_patches
79
+ # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
80
+ model.pos_embed = nn.Parameter(
81
+ torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
82
+ ) # fixed sin-cos embedding
83
+
84
+ # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
85
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
86
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
87
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
88
+
89
+ self.model = model
90
+
91
+ def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
92
+ """
93
+ x: mel fbank [Batch, 1, T, F]
94
+ mask_t_prob: 'T masking ratio (percentage of removed patches).'
95
+ mask_f_prob: 'F masking ratio (percentage of removed patches).'
96
+ """
97
+ return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
98
+
99
+
100
+ class Vanilla_AudioMAE(nn.Module):
101
+ """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
102
+
103
+ def __init__(
104
+ self,
105
+ ):
106
+ super().__init__()
107
+ model = models_mae.__dict__["mae_vit_base_patch16"](
108
+ in_chans=1, audio_exp=True, img_size=(1024, 128)
109
+ )
110
+
111
+ # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
112
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
113
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
114
+
115
+ # Skip the missing keys of decoder modules (not required)
116
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
117
+
118
+ self.model = model.eval()
119
+
120
+ def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
121
+ """
122
+ x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
123
+ mask_ratio: 'masking ratio (percentage of removed patches).'
124
+ """
125
+ with torch.no_grad():
126
+ # embed: [B, 513, 768] for mask_ratio=0.0
127
+ if no_mask:
128
+ if no_average:
129
+ raise RuntimeError("This function is deprecated")
130
+ embed = self.model.forward_encoder_no_random_mask_no_average(
131
+ x
132
+ ) # mask_ratio
133
+ else:
134
+ embed = self.model.forward_encoder_no_mask(x) # mask_ratio
135
+ else:
136
+ raise RuntimeError("This function is deprecated")
137
+ embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
138
+ return embed
139
+
140
+
141
+ if __name__ == "__main__":
142
+ model = Vanilla_AudioMAE().cuda()
143
+ input = torch.randn(4, 1, 1024, 128).cuda()
144
+ print("The first run")
145
+ embed = model(input, mask_ratio=0.0, no_mask=True)
146
+ print(embed)
147
+ print("The second run")
148
+ embed = model(input, mask_ratio=0.0)
149
+ print(embed)
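For reference, Vanilla_AudioMAE above consumes a mel fbank of shape [B, 1, 1024, 128] and, with no_mask=True, returns one 768-dimensional token per 16x16 patch plus a cls token (64 * 8 + 1 = 513 tokens). A minimal CPU sketch, assuming the module path below (it mirrors the file's own __main__, but note the backbone is randomly initialised here since __init__ loads no checkpoint):

import torch
from audiosr.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE

model = Vanilla_AudioMAE()            # randomly initialised AudioMAE encoder
mel = torch.randn(2, 1, 1024, 128)    # [batch, 1, time frames, mel bins]
emb = model(mel, no_mask=True)        # encoder pass without random masking
print(emb.shape)                      # torch.Size([2, 513, 768])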
audiosr/latent_diffusion/modules/audiomae/models_mae.py CHANGED
@@ -1,613 +1,613 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # --------------------------------------------------------
7
- # References:
8
- # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
- # DeiT: https://github.com/facebookresearch/deit
10
- # --------------------------------------------------------
11
-
12
- from functools import partial
13
-
14
- import torch
15
- import torch.nn as nn
16
-
17
- from timm.models.vision_transformer import Block
18
- from audiosr.latent_diffusion.modules.audiomae.util.pos_embed import (
19
- get_2d_sincos_pos_embed,
20
- get_2d_sincos_pos_embed_flexible,
21
- )
22
- from audiosr.latent_diffusion.modules.audiomae.util.patch_embed import (
23
- PatchEmbed_new,
24
- PatchEmbed_org,
25
- )
26
-
27
-
28
- class MaskedAutoencoderViT(nn.Module):
29
- """Masked Autoencoder with VisionTransformer backbone"""
30
-
31
- def __init__(
32
- self,
33
- img_size=224,
34
- patch_size=16,
35
- stride=10,
36
- in_chans=3,
37
- embed_dim=1024,
38
- depth=24,
39
- num_heads=16,
40
- decoder_embed_dim=512,
41
- decoder_depth=8,
42
- decoder_num_heads=16,
43
- mlp_ratio=4.0,
44
- norm_layer=nn.LayerNorm,
45
- norm_pix_loss=False,
46
- audio_exp=False,
47
- alpha=0.0,
48
- temperature=0.2,
49
- mode=0,
50
- contextual_depth=8,
51
- use_custom_patch=False,
52
- split_pos=False,
53
- pos_trainable=False,
54
- use_nce=False,
55
- beta=4.0,
56
- decoder_mode=0,
57
- mask_t_prob=0.6,
58
- mask_f_prob=0.5,
59
- mask_2d=False,
60
- epoch=0,
61
- no_shift=False,
62
- ):
63
- super().__init__()
64
-
65
- self.audio_exp = audio_exp
66
- self.embed_dim = embed_dim
67
- self.decoder_embed_dim = decoder_embed_dim
68
- # --------------------------------------------------------------------------
69
- # MAE encoder specifics
70
- if use_custom_patch:
71
- print(
72
- f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
73
- )
74
- self.patch_embed = PatchEmbed_new(
75
- img_size=img_size,
76
- patch_size=patch_size,
77
- in_chans=in_chans,
78
- embed_dim=embed_dim,
79
- stride=stride,
80
- )
81
- else:
82
- self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
83
- self.use_custom_patch = use_custom_patch
84
- num_patches = self.patch_embed.num_patches
85
-
86
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
87
-
88
- # self.split_pos = split_pos # not useful
89
- self.pos_embed = nn.Parameter(
90
- torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
91
- ) # fixed sin-cos embedding
92
-
93
- self.encoder_depth = depth
94
- self.contextual_depth = contextual_depth
95
- self.blocks = nn.ModuleList(
96
- [
97
- Block(
98
- embed_dim,
99
- num_heads,
100
- mlp_ratio,
101
- qkv_bias=True,
102
- norm_layer=norm_layer,
103
- ) # qk_scale=None
104
- for i in range(depth)
105
- ]
106
- )
107
- self.norm = norm_layer(embed_dim)
108
-
109
- # --------------------------------------------------------------------------
110
- # MAE decoder specifics
111
- self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
112
-
113
- self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
114
- self.decoder_pos_embed = nn.Parameter(
115
- torch.zeros(1, num_patches + 1, decoder_embed_dim),
116
- requires_grad=pos_trainable,
117
- ) # fixed sin-cos embedding
118
-
119
- self.no_shift = no_shift
120
-
121
- self.decoder_mode = decoder_mode
122
- if (
123
- self.use_custom_patch
124
- ): # overlapped patches as in AST. Similar performance yet compute heavy
125
- window_size = (6, 6)
126
- feat_size = (102, 12)
127
- else:
128
- window_size = (4, 4)
129
- feat_size = (64, 8)
130
- if self.decoder_mode == 1:
131
- decoder_modules = []
132
- for index in range(16):
133
- if self.no_shift:
134
- shift_size = (0, 0)
135
- else:
136
- if (index % 2) == 0:
137
- shift_size = (0, 0)
138
- else:
139
- shift_size = (2, 0)
140
- # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
141
- decoder_modules.append(
142
- SwinTransformerBlock(
143
- dim=decoder_embed_dim,
144
- num_heads=16,
145
- feat_size=feat_size,
146
- window_size=window_size,
147
- shift_size=shift_size,
148
- mlp_ratio=mlp_ratio,
149
- drop=0.0,
150
- drop_attn=0.0,
151
- drop_path=0.0,
152
- extra_norm=False,
153
- sequential_attn=False,
154
- norm_layer=norm_layer, # nn.LayerNorm,
155
- )
156
- )
157
- self.decoder_blocks = nn.ModuleList(decoder_modules)
158
- else:
159
- # Transformer
160
- self.decoder_blocks = nn.ModuleList(
161
- [
162
- Block(
163
- decoder_embed_dim,
164
- decoder_num_heads,
165
- mlp_ratio,
166
- qkv_bias=True,
167
- norm_layer=norm_layer,
168
- ) # qk_scale=None,
169
- for i in range(decoder_depth)
170
- ]
171
- )
172
-
173
- self.decoder_norm = norm_layer(decoder_embed_dim)
174
- self.decoder_pred = nn.Linear(
175
- decoder_embed_dim, patch_size**2 * in_chans, bias=True
176
- ) # decoder to patch
177
-
178
- # --------------------------------------------------------------------------
179
-
180
- self.norm_pix_loss = norm_pix_loss
181
-
182
- self.patch_size = patch_size
183
- self.stride = stride
184
-
185
- # audio exps
186
- self.alpha = alpha
187
- self.T = temperature
188
- self.mode = mode
189
- self.use_nce = use_nce
190
- self.beta = beta
191
-
192
- self.log_softmax = nn.LogSoftmax(dim=-1)
193
-
194
- self.mask_t_prob = mask_t_prob
195
- self.mask_f_prob = mask_f_prob
196
- self.mask_2d = mask_2d
197
-
198
- self.epoch = epoch
199
-
200
- self.initialize_weights()
201
-
202
- def initialize_weights(self):
203
- # initialization
204
- # initialize (and freeze) pos_embed by sin-cos embedding
205
- if self.audio_exp:
206
- pos_embed = get_2d_sincos_pos_embed_flexible(
207
- self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
208
- )
209
- else:
210
- pos_embed = get_2d_sincos_pos_embed(
211
- self.pos_embed.shape[-1],
212
- int(self.patch_embed.num_patches**0.5),
213
- cls_token=True,
214
- )
215
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
216
-
217
- if self.audio_exp:
218
- decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
219
- self.decoder_pos_embed.shape[-1],
220
- self.patch_embed.patch_hw,
221
- cls_token=True,
222
- )
223
- else:
224
- decoder_pos_embed = get_2d_sincos_pos_embed(
225
- self.decoder_pos_embed.shape[-1],
226
- int(self.patch_embed.num_patches**0.5),
227
- cls_token=True,
228
- )
229
- self.decoder_pos_embed.data.copy_(
230
- torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
231
- )
232
-
233
- # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
234
- w = self.patch_embed.proj.weight.data
235
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
236
-
237
- # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
238
- torch.nn.init.normal_(self.cls_token, std=0.02)
239
- torch.nn.init.normal_(self.mask_token, std=0.02)
240
-
241
- # initialize nn.Linear and nn.LayerNorm
242
- self.apply(self._init_weights)
243
-
244
- def _init_weights(self, m):
245
- if isinstance(m, nn.Linear):
246
- # we use xavier_uniform following official JAX ViT:
247
- torch.nn.init.xavier_uniform_(m.weight)
248
- if isinstance(m, nn.Linear) and m.bias is not None:
249
- nn.init.constant_(m.bias, 0)
250
- elif isinstance(m, nn.LayerNorm):
251
- nn.init.constant_(m.bias, 0)
252
- nn.init.constant_(m.weight, 1.0)
253
-
254
- def patchify(self, imgs):
255
- """
256
- imgs: (N, 3, H, W)
257
- x: (N, L, patch_size**2 *3)
258
- L = (H/p)*(W/p)
259
- """
260
- p = self.patch_embed.patch_size[0]
261
- # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
262
-
263
- if self.audio_exp:
264
- if self.use_custom_patch: # overlapped patch
265
- h, w = self.patch_embed.patch_hw
266
- # todo: fixed h/w patch size and stride size. Make hw custom in the future
267
- x = imgs.unfold(2, self.patch_size, self.stride).unfold(
268
- 3, self.patch_size, self.stride
269
- ) # n,1,H,W -> n,1,h,w,p,p
270
- x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
271
- # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
272
- # x = torch.einsum('nchpwq->nhwpqc', x)
273
- # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
274
- else:
275
- h = imgs.shape[2] // p
276
- w = imgs.shape[3] // p
277
- # h,w = self.patch_embed.patch_hw
278
- x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
279
- x = torch.einsum("nchpwq->nhwpqc", x)
280
- x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
281
- else:
282
- h = w = imgs.shape[2] // p
283
- x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
284
- x = torch.einsum("nchpwq->nhwpqc", x)
285
- x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
286
-
287
- return x
288
-
289
- def unpatchify(self, x):
290
- """
291
- x: (N, L, patch_size**2 *3)
292
- specs: (N, 1, H, W)
293
- """
294
- p = self.patch_embed.patch_size[0]
295
- h = 1024 // p
296
- w = 128 // p
297
- x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
298
- x = torch.einsum("nhwpqc->nchpwq", x)
299
- specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
300
- return specs
301
-
302
- def random_masking(self, x, mask_ratio):
303
- """
304
- Perform per-sample random masking by per-sample shuffling.
305
- Per-sample shuffling is done by argsort random noise.
306
- x: [N, L, D], sequence
307
- """
308
- N, L, D = x.shape # batch, length, dim
309
- len_keep = int(L * (1 - mask_ratio))
310
-
311
- noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
312
-
313
- # sort noise for each sample
314
- ids_shuffle = torch.argsort(
315
- noise, dim=1
316
- ) # ascend: small is keep, large is remove
317
- ids_restore = torch.argsort(ids_shuffle, dim=1)
318
-
319
- # keep the first subset
320
- ids_keep = ids_shuffle[:, :len_keep]
321
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
322
-
323
- # generate the binary mask: 0 is keep, 1 is remove
324
- mask = torch.ones([N, L], device=x.device)
325
- mask[:, :len_keep] = 0
326
- # unshuffle to get the binary mask
327
- mask = torch.gather(mask, dim=1, index=ids_restore)
328
-
329
- return x_masked, mask, ids_restore
330
-
331
- def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
332
- """
333
- 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
334
- Perform per-sample random masking by per-sample shuffling.
335
- Per-sample shuffling is done by argsort random noise.
336
- x: [N, L, D], sequence
337
- """
338
- N, L, D = x.shape # batch, length, dim
339
- if self.use_custom_patch: # overlapped patch
340
- T = 101
341
- F = 12
342
- else:
343
- T = 64
344
- F = 8
345
- # x = x.reshape(N, T, F, D)
346
- len_keep_t = int(T * (1 - mask_t_prob))
347
- len_keep_f = int(F * (1 - mask_f_prob))
348
-
349
- # noise for mask in time
350
- noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
351
- # sort noise for each sample along time
352
- ids_shuffle_t = torch.argsort(
353
- noise_t, dim=1
354
- ) # ascend: small is keep, large is remove
355
- ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
356
- ids_keep_t = ids_shuffle_t[:, :len_keep_t]
357
- # noise mask in freq
358
- noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
359
- ids_shuffle_f = torch.argsort(
360
- noise_f, dim=1
361
- ) # ascend: small is keep, large is remove
362
- ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
363
- ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
364
-
365
- # generate the binary mask: 0 is keep, 1 is remove
366
- # mask in freq
367
- mask_f = torch.ones(N, F, device=x.device)
368
- mask_f[:, :len_keep_f] = 0
369
- mask_f = (
370
- torch.gather(mask_f, dim=1, index=ids_restore_f)
371
- .unsqueeze(1)
372
- .repeat(1, T, 1)
373
- ) # N,T,F
374
- # mask in time
375
- mask_t = torch.ones(N, T, device=x.device)
376
- mask_t[:, :len_keep_t] = 0
377
- mask_t = (
378
- torch.gather(mask_t, dim=1, index=ids_restore_t)
379
- .unsqueeze(1)
380
- .repeat(1, F, 1)
381
- .permute(0, 2, 1)
382
- ) # N,T,F
383
- mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
384
-
385
- # get masked x
386
- id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
387
- id2res = id2res + 999 * mask # add a large value for masked elements
388
- id2res2 = torch.argsort(id2res.flatten(start_dim=1))
389
- ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
390
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
391
-
392
- ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
393
- mask = mask.flatten(start_dim=1)
394
-
395
- return x_masked, mask, ids_restore
396
-
397
- def forward_encoder(self, x, mask_ratio, mask_2d=False):
398
- # embed patches
399
- x = self.patch_embed(x)
400
- # add pos embed w/o cls token
401
- x = x + self.pos_embed[:, 1:, :]
402
-
403
- # masking: length -> length * mask_ratio
404
- if mask_2d:
405
- x, mask, ids_restore = self.random_masking_2d(
406
- x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
407
- )
408
- else:
409
- x, mask, ids_restore = self.random_masking(x, mask_ratio)
410
-
411
- # append cls token
412
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
413
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
414
- x = torch.cat((cls_tokens, x), dim=1)
415
-
416
- # apply Transformer blocks
417
- for blk in self.blocks:
418
- x = blk(x)
419
- x = self.norm(x)
420
-
421
- return x, mask, ids_restore, None
422
-
423
- def forward_encoder_no_random_mask_no_average(self, x):
424
- # embed patches
425
- x = self.patch_embed(x)
426
- # add pos embed w/o cls token
427
- x = x + self.pos_embed[:, 1:, :]
428
-
429
- # masking: length -> length * mask_ratio
430
- # if mask_2d:
431
- # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
432
- # else:
433
- # x, mask, ids_restore = self.random_masking(x, mask_ratio)
434
-
435
- # append cls token
436
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
437
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
438
- x = torch.cat((cls_tokens, x), dim=1)
439
-
440
- # apply Transformer blocks
441
- for blk in self.blocks:
442
- x = blk(x)
443
- x = self.norm(x)
444
-
445
- return x
446
-
447
- def forward_encoder_no_mask(self, x):
448
- # embed patches
449
- x = self.patch_embed(x)
450
-
451
- # add pos embed w/o cls token
452
- x = x + self.pos_embed[:, 1:, :]
453
-
454
- # masking: length -> length * mask_ratio
455
- # x, mask, ids_restore = self.random_masking(x, mask_ratio)
456
- # append cls token
457
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
458
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
459
- x = torch.cat((cls_tokens, x), dim=1)
460
-
461
- # apply Transformer blocks
462
- contextual_embs = []
463
- for n, blk in enumerate(self.blocks):
464
- x = blk(x)
465
- if n > self.contextual_depth:
466
- contextual_embs.append(self.norm(x))
467
- # x = self.norm(x)
468
- contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
469
-
470
- return contextual_emb
471
-
472
- def forward_decoder(self, x, ids_restore):
473
- # embed tokens
474
- x = self.decoder_embed(x)
475
-
476
- # append mask tokens to sequence
477
- mask_tokens = self.mask_token.repeat(
478
- x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
479
- )
480
- x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
481
- x_ = torch.gather(
482
- x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
483
- ) # unshuffle
484
- x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
485
-
486
- # add pos embed
487
- x = x + self.decoder_pos_embed
488
-
489
- if self.decoder_mode != 0:
490
- B, L, D = x.shape
491
- x = x[:, 1:, :]
492
- if self.use_custom_patch:
493
- x = x.reshape(B, 101, 12, D)
494
- x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack
495
- x = x.reshape(B, 1224, D)
496
- if self.decoder_mode > 3: # mvit
497
- x = self.decoder_blocks(x)
498
- else:
499
- # apply Transformer blocks
500
- for blk in self.decoder_blocks:
501
- x = blk(x)
502
- x = self.decoder_norm(x)
503
-
504
- # predictor projection
505
- pred = self.decoder_pred(x)
506
-
507
- # remove cls token
508
- if self.decoder_mode != 0:
509
- if self.use_custom_patch:
510
- pred = pred.reshape(B, 102, 12, 256)
511
- pred = pred[:, :101, :, :]
512
- pred = pred.reshape(B, 1212, 256)
513
- else:
514
- pred = pred
515
- else:
516
- pred = pred[:, 1:, :]
517
- return pred, None, None # emb, emb_pixel
518
-
519
- def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
520
- """
521
- imgs: [N, 3, H, W]
522
- pred: [N, L, p*p*3]
523
- mask: [N, L], 0 is keep, 1 is remove,
524
- """
525
- target = self.patchify(imgs)
526
- if norm_pix_loss:
527
- mean = target.mean(dim=-1, keepdim=True)
528
- var = target.var(dim=-1, keepdim=True)
529
- target = (target - mean) / (var + 1.0e-6) ** 0.5
530
-
531
- loss = (pred - target) ** 2
532
- loss = loss.mean(dim=-1) # [N, L], mean loss per patch
533
-
534
- loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
535
- return loss
536
-
537
- def forward(self, imgs, mask_ratio=0.8):
538
- emb_enc, mask, ids_restore, _ = self.forward_encoder(
539
- imgs, mask_ratio, mask_2d=self.mask_2d
540
- )
541
- pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
542
- loss_recon = self.forward_loss(
543
- imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
544
- )
545
- loss_contrastive = torch.FloatTensor([0.0]).cuda()
546
- return loss_recon, pred, mask, loss_contrastive
547
-
548
-
549
- def mae_vit_small_patch16_dec512d8b(**kwargs):
550
- model = MaskedAutoencoderViT(
551
- patch_size=16,
552
- embed_dim=384,
553
- depth=12,
554
- num_heads=6,
555
- decoder_embed_dim=512,
556
- decoder_num_heads=16,
557
- mlp_ratio=4,
558
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
559
- **kwargs,
560
- )
561
- return model
562
-
563
-
564
- def mae_vit_base_patch16_dec512d8b(**kwargs):
565
- model = MaskedAutoencoderViT(
566
- patch_size=16,
567
- embed_dim=768,
568
- depth=12,
569
- num_heads=12,
570
- decoder_embed_dim=512,
571
- decoder_num_heads=16,
572
- mlp_ratio=4,
573
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
574
- **kwargs,
575
- )
576
- return model
577
-
578
-
579
- def mae_vit_large_patch16_dec512d8b(**kwargs):
580
- model = MaskedAutoencoderViT(
581
- patch_size=16,
582
- embed_dim=1024,
583
- depth=24,
584
- num_heads=16,
585
- decoder_embed_dim=512,
586
- decoder_num_heads=16,
587
- mlp_ratio=4,
588
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
589
- **kwargs,
590
- )
591
- return model
592
-
593
-
594
- def mae_vit_huge_patch14_dec512d8b(**kwargs):
595
- model = MaskedAutoencoderViT(
596
- patch_size=14,
597
- embed_dim=1280,
598
- depth=32,
599
- num_heads=16,
600
- decoder_embed_dim=512,
601
- decoder_num_heads=16,
602
- mlp_ratio=4,
603
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
604
- **kwargs,
605
- )
606
- return model
607
-
608
-
609
- # set recommended archs
610
- mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
611
- mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
612
- mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
613
- mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from timm.models.vision_transformer import Block
18
+ from audiosr.latent_diffusion.modules.audiomae.util.pos_embed import (
19
+ get_2d_sincos_pos_embed,
20
+ get_2d_sincos_pos_embed_flexible,
21
+ )
22
+ from audiosr.latent_diffusion.modules.audiomae.util.patch_embed import (
23
+ PatchEmbed_new,
24
+ PatchEmbed_org,
25
+ )
26
+
27
+
28
+ class MaskedAutoencoderViT(nn.Module):
29
+ """Masked Autoencoder with VisionTransformer backbone"""
30
+
31
+ def __init__(
32
+ self,
33
+ img_size=224,
34
+ patch_size=16,
35
+ stride=10,
36
+ in_chans=3,
37
+ embed_dim=1024,
38
+ depth=24,
39
+ num_heads=16,
40
+ decoder_embed_dim=512,
41
+ decoder_depth=8,
42
+ decoder_num_heads=16,
43
+ mlp_ratio=4.0,
44
+ norm_layer=nn.LayerNorm,
45
+ norm_pix_loss=False,
46
+ audio_exp=False,
47
+ alpha=0.0,
48
+ temperature=0.2,
49
+ mode=0,
50
+ contextual_depth=8,
51
+ use_custom_patch=False,
52
+ split_pos=False,
53
+ pos_trainable=False,
54
+ use_nce=False,
55
+ beta=4.0,
56
+ decoder_mode=0,
57
+ mask_t_prob=0.6,
58
+ mask_f_prob=0.5,
59
+ mask_2d=False,
60
+ epoch=0,
61
+ no_shift=False,
62
+ ):
63
+ super().__init__()
64
+
65
+ self.audio_exp = audio_exp
66
+ self.embed_dim = embed_dim
67
+ self.decoder_embed_dim = decoder_embed_dim
68
+ # --------------------------------------------------------------------------
69
+ # MAE encoder specifics
70
+ if use_custom_patch:
71
+ print(
72
+ f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
73
+ )
74
+ self.patch_embed = PatchEmbed_new(
75
+ img_size=img_size,
76
+ patch_size=patch_size,
77
+ in_chans=in_chans,
78
+ embed_dim=embed_dim,
79
+ stride=stride,
80
+ )
81
+ else:
82
+ self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
83
+ self.use_custom_patch = use_custom_patch
84
+ num_patches = self.patch_embed.num_patches
85
+
86
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
87
+
88
+ # self.split_pos = split_pos # not useful
89
+ self.pos_embed = nn.Parameter(
90
+ torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
91
+ ) # fixed sin-cos embedding
92
+
93
+ self.encoder_depth = depth
94
+ self.contextual_depth = contextual_depth
95
+ self.blocks = nn.ModuleList(
96
+ [
97
+ Block(
98
+ embed_dim,
99
+ num_heads,
100
+ mlp_ratio,
101
+ qkv_bias=True,
102
+ norm_layer=norm_layer,
103
+ ) # qk_scale=None
104
+ for i in range(depth)
105
+ ]
106
+ )
107
+ self.norm = norm_layer(embed_dim)
108
+
109
+ # --------------------------------------------------------------------------
110
+ # MAE decoder specifics
111
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
112
+
113
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
114
+ self.decoder_pos_embed = nn.Parameter(
115
+ torch.zeros(1, num_patches + 1, decoder_embed_dim),
116
+ requires_grad=pos_trainable,
117
+ ) # fixed sin-cos embedding
118
+
119
+ self.no_shift = no_shift
120
+
121
+ self.decoder_mode = decoder_mode
122
+ if (
123
+ self.use_custom_patch
124
+ ): # overlapped patches as in AST. Similar performance yet compute heavy
125
+ window_size = (6, 6)
126
+ feat_size = (102, 12)
127
+ else:
128
+ window_size = (4, 4)
129
+ feat_size = (64, 8)
130
+ if self.decoder_mode == 1:
131
+ decoder_modules = []
132
+ for index in range(16):
133
+ if self.no_shift:
134
+ shift_size = (0, 0)
135
+ else:
136
+ if (index % 2) == 0:
137
+ shift_size = (0, 0)
138
+ else:
139
+ shift_size = (2, 0)
140
+ # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
141
+ decoder_modules.append(
142
+ SwinTransformerBlock(
143
+ dim=decoder_embed_dim,
144
+ num_heads=16,
145
+ feat_size=feat_size,
146
+ window_size=window_size,
147
+ shift_size=shift_size,
148
+ mlp_ratio=mlp_ratio,
149
+ drop=0.0,
150
+ drop_attn=0.0,
151
+ drop_path=0.0,
152
+ extra_norm=False,
153
+ sequential_attn=False,
154
+ norm_layer=norm_layer, # nn.LayerNorm,
155
+ )
156
+ )
157
+ self.decoder_blocks = nn.ModuleList(decoder_modules)
158
+ else:
159
+ # Transformer
160
+ self.decoder_blocks = nn.ModuleList(
161
+ [
162
+ Block(
163
+ decoder_embed_dim,
164
+ decoder_num_heads,
165
+ mlp_ratio,
166
+ qkv_bias=True,
167
+ norm_layer=norm_layer,
168
+ ) # qk_scale=None,
169
+ for i in range(decoder_depth)
170
+ ]
171
+ )
172
+
173
+ self.decoder_norm = norm_layer(decoder_embed_dim)
174
+ self.decoder_pred = nn.Linear(
175
+ decoder_embed_dim, patch_size**2 * in_chans, bias=True
176
+ ) # decoder to patch
177
+
178
+ # --------------------------------------------------------------------------
179
+
180
+ self.norm_pix_loss = norm_pix_loss
181
+
182
+ self.patch_size = patch_size
183
+ self.stride = stride
184
+
185
+ # audio exps
186
+ self.alpha = alpha
187
+ self.T = temperature
188
+ self.mode = mode
189
+ self.use_nce = use_nce
190
+ self.beta = beta
191
+
192
+ self.log_softmax = nn.LogSoftmax(dim=-1)
193
+
194
+ self.mask_t_prob = mask_t_prob
195
+ self.mask_f_prob = mask_f_prob
196
+ self.mask_2d = mask_2d
197
+
198
+ self.epoch = epoch
199
+
200
+ self.initialize_weights()
201
+
202
+ def initialize_weights(self):
203
+ # initialization
204
+ # initialize (and freeze) pos_embed by sin-cos embedding
205
+ if self.audio_exp:
206
+ pos_embed = get_2d_sincos_pos_embed_flexible(
207
+ self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
208
+ )
209
+ else:
210
+ pos_embed = get_2d_sincos_pos_embed(
211
+ self.pos_embed.shape[-1],
212
+ int(self.patch_embed.num_patches**0.5),
213
+ cls_token=True,
214
+ )
215
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
216
+
217
+ if self.audio_exp:
218
+ decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
219
+ self.decoder_pos_embed.shape[-1],
220
+ self.patch_embed.patch_hw,
221
+ cls_token=True,
222
+ )
223
+ else:
224
+ decoder_pos_embed = get_2d_sincos_pos_embed(
225
+ self.decoder_pos_embed.shape[-1],
226
+ int(self.patch_embed.num_patches**0.5),
227
+ cls_token=True,
228
+ )
229
+ self.decoder_pos_embed.data.copy_(
230
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
231
+ )
232
+
233
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
234
+ w = self.patch_embed.proj.weight.data
235
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
236
+
237
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
238
+ torch.nn.init.normal_(self.cls_token, std=0.02)
239
+ torch.nn.init.normal_(self.mask_token, std=0.02)
240
+
241
+ # initialize nn.Linear and nn.LayerNorm
242
+ self.apply(self._init_weights)
243
+
244
+ def _init_weights(self, m):
245
+ if isinstance(m, nn.Linear):
246
+ # we use xavier_uniform following official JAX ViT:
247
+ torch.nn.init.xavier_uniform_(m.weight)
248
+ if isinstance(m, nn.Linear) and m.bias is not None:
249
+ nn.init.constant_(m.bias, 0)
250
+ elif isinstance(m, nn.LayerNorm):
251
+ nn.init.constant_(m.bias, 0)
252
+ nn.init.constant_(m.weight, 1.0)
253
+
254
+ def patchify(self, imgs):
255
+ """
256
+ imgs: (N, 3, H, W)
257
+ x: (N, L, patch_size**2 *3)
258
+ L = (H/p)*(W/p)
259
+ """
260
+ p = self.patch_embed.patch_size[0]
261
+ # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
262
+
263
+ if self.audio_exp:
264
+ if self.use_custom_patch: # overlapped patch
265
+ h, w = self.patch_embed.patch_hw
266
+ # todo: fixed h/w patch size and stride size. Make hw custom in the future
267
+ x = imgs.unfold(2, self.patch_size, self.stride).unfold(
268
+ 3, self.patch_size, self.stride
269
+ ) # n,1,H,W -> n,1,h,w,p,p
270
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
271
+ # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
272
+ # x = torch.einsum('nchpwq->nhwpqc', x)
273
+ # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
274
+ else:
275
+ h = imgs.shape[2] // p
276
+ w = imgs.shape[3] // p
277
+ # h,w = self.patch_embed.patch_hw
278
+ x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
279
+ x = torch.einsum("nchpwq->nhwpqc", x)
280
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
281
+ else:
282
+ h = w = imgs.shape[2] // p
283
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
284
+ x = torch.einsum("nchpwq->nhwpqc", x)
285
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
286
+
287
+ return x
288
+
289
+ def unpatchify(self, x):
290
+ """
291
+ x: (N, L, patch_size**2 *3)
292
+ specs: (N, 1, H, W)
293
+ """
294
+ p = self.patch_embed.patch_size[0]
295
+ h = 1024 // p
296
+ w = 128 // p
297
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
298
+ x = torch.einsum("nhwpqc->nchpwq", x)
299
+ specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
300
+ return specs
301
+
302
+ def random_masking(self, x, mask_ratio):
303
+ """
304
+ Perform per-sample random masking by per-sample shuffling.
305
+ Per-sample shuffling is done by argsort random noise.
306
+ x: [N, L, D], sequence
307
+ """
308
+ N, L, D = x.shape # batch, length, dim
309
+ len_keep = int(L * (1 - mask_ratio))
310
+
311
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
312
+
313
+ # sort noise for each sample
314
+ ids_shuffle = torch.argsort(
315
+ noise, dim=1
316
+ ) # ascend: small is keep, large is remove
317
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
318
+
319
+ # keep the first subset
320
+ ids_keep = ids_shuffle[:, :len_keep]
321
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
322
+
323
+ # generate the binary mask: 0 is keep, 1 is remove
324
+ mask = torch.ones([N, L], device=x.device)
325
+ mask[:, :len_keep] = 0
326
+ # unshuffle to get the binary mask
327
+ mask = torch.gather(mask, dim=1, index=ids_restore)
328
+
329
+ return x_masked, mask, ids_restore
330
+
331
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
332
+ """
333
+ 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
334
+ Perform per-sample random masking by per-sample shuffling.
335
+ Per-sample shuffling is done by argsort random noise.
336
+ x: [N, L, D], sequence
337
+ """
338
+ N, L, D = x.shape # batch, length, dim
339
+ if self.use_custom_patch: # overlapped patch
340
+ T = 101
341
+ F = 12
342
+ else:
343
+ T = 64
344
+ F = 8
345
+ # x = x.reshape(N, T, F, D)
346
+ len_keep_t = int(T * (1 - mask_t_prob))
347
+ len_keep_f = int(F * (1 - mask_f_prob))
348
+
349
+ # noise for mask in time
350
+ noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
351
+ # sort noise for each sample along time
352
+ ids_shuffle_t = torch.argsort(
353
+ noise_t, dim=1
354
+ ) # ascend: small is keep, large is remove
355
+ ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
356
+ ids_keep_t = ids_shuffle_t[:, :len_keep_t]
357
+ # noise mask in freq
358
+ noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
359
+ ids_shuffle_f = torch.argsort(
360
+ noise_f, dim=1
361
+ ) # ascend: small is keep, large is remove
362
+ ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
363
+ ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
364
+
365
+ # generate the binary mask: 0 is keep, 1 is remove
366
+ # mask in freq
367
+ mask_f = torch.ones(N, F, device=x.device)
368
+ mask_f[:, :len_keep_f] = 0
369
+ mask_f = (
370
+ torch.gather(mask_f, dim=1, index=ids_restore_f)
371
+ .unsqueeze(1)
372
+ .repeat(1, T, 1)
373
+ ) # N,T,F
374
+ # mask in time
375
+ mask_t = torch.ones(N, T, device=x.device)
376
+ mask_t[:, :len_keep_t] = 0
377
+ mask_t = (
378
+ torch.gather(mask_t, dim=1, index=ids_restore_t)
379
+ .unsqueeze(1)
380
+ .repeat(1, F, 1)
381
+ .permute(0, 2, 1)
382
+ ) # N,T,F
383
+ mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
384
+
385
+ # get masked x
386
+ id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
387
+ id2res = id2res + 999 * mask # add a large value for masked elements
388
+ id2res2 = torch.argsort(id2res.flatten(start_dim=1))
389
+ ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
390
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
391
+
392
+ ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
393
+ mask = mask.flatten(start_dim=1)
394
+
395
+ return x_masked, mask, ids_restore
396
+
397
+ def forward_encoder(self, x, mask_ratio, mask_2d=False):
398
+ # embed patches
399
+ x = self.patch_embed(x)
400
+ # add pos embed w/o cls token
401
+ x = x + self.pos_embed[:, 1:, :]
402
+
403
+ # masking: length -> length * mask_ratio
404
+ if mask_2d:
405
+ x, mask, ids_restore = self.random_masking_2d(
406
+ x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
407
+ )
408
+ else:
409
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
410
+
411
+ # append cls token
412
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
413
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
414
+ x = torch.cat((cls_tokens, x), dim=1)
415
+
416
+ # apply Transformer blocks
417
+ for blk in self.blocks:
418
+ x = blk(x)
419
+ x = self.norm(x)
420
+
421
+ return x, mask, ids_restore, None
422
+
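The encoder reserves index 0 of the positional table for the cls token and adds positions 1..L to the patch tokens before masking. A toy sketch of that bookkeeping (shapes are assumptions, not the model's real dimensions):

import torch

N, L, D = 2, 8, 16
patches = torch.randn(N, L, D)                     # stand-in for the patch_embed output
pos_embed = torch.randn(1, L + 1, D)               # index 0 is reserved for the cls token
cls_token = torch.randn(1, 1, D)

x = patches + pos_embed[:, 1:, :]                  # patch tokens get positions 1..L
cls = cls_token + pos_embed[:, :1, :]              # cls token gets position 0
x = torch.cat([cls.expand(N, -1, -1), x], dim=1)   # prepend cls to every sample
print(x.shape)                                     # torch.Size([2, 9, 16])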
423
+ def forward_encoder_no_random_mask_no_average(self, x):
424
+ # embed patches
425
+ x = self.patch_embed(x)
426
+ # add pos embed w/o cls token
427
+ x = x + self.pos_embed[:, 1:, :]
428
+
429
+ # masking: length -> length * mask_ratio
430
+ # if mask_2d:
431
+ # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
432
+ # else:
433
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
434
+
435
+ # append cls token
436
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
437
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
438
+ x = torch.cat((cls_tokens, x), dim=1)
439
+
440
+ # apply Transformer blocks
441
+ for blk in self.blocks:
442
+ x = blk(x)
443
+ x = self.norm(x)
444
+
445
+ return x
446
+
447
+ def forward_encoder_no_mask(self, x):
448
+ # embed patches
449
+ x = self.patch_embed(x)
450
+
451
+ # add pos embed w/o cls token
452
+ x = x + self.pos_embed[:, 1:, :]
453
+
454
+ # masking: length -> length * mask_ratio
455
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
456
+ # append cls token
457
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
458
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
459
+ x = torch.cat((cls_tokens, x), dim=1)
460
+
461
+ # apply Transformer blocks
462
+ contextual_embs = []
463
+ for n, blk in enumerate(self.blocks):
464
+ x = blk(x)
465
+ if n > self.contextual_depth:
466
+ contextual_embs.append(self.norm(x))
467
+ # x = self.norm(x)
468
+ contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
469
+
470
+ return contextual_emb
471
+
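forward_encoder_no_mask averages the layer-normalized outputs of the blocks deeper than contextual_depth rather than taking only the last layer. A standalone sketch of that averaging (the nn.Linear layers are placeholders standing in for Transformer blocks; sizes are assumptions):

import torch
import torch.nn as nn

blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(6))   # placeholders for Transformer blocks
norm = nn.LayerNorm(16)
contextual_depth = 3

x = torch.randn(2, 10, 16)                                    # [N, tokens, dim]
contextual_embs = []
for n, blk in enumerate(blocks):
    x = blk(x)
    if n > contextual_depth:                                  # collect only the deeper layers
        contextual_embs.append(norm(x))
contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
print(len(contextual_embs), contextual_emb.shape)             # 2 torch.Size([2, 10, 16])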
472
+ def forward_decoder(self, x, ids_restore):
473
+ # embed tokens
474
+ x = self.decoder_embed(x)
475
+
476
+ # append mask tokens to sequence
477
+ mask_tokens = self.mask_token.repeat(
478
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
479
+ )
480
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
481
+ x_ = torch.gather(
482
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
483
+ ) # unshuffle
484
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
485
+
486
+ # add pos embed
487
+ x = x + self.decoder_pos_embed
488
+
489
+ if self.decoder_mode != 0:
490
+ B, L, D = x.shape
491
+ x = x[:, 1:, :]
492
+ if self.use_custom_patch:
493
+ x = x.reshape(B, 101, 12, D)
494
+ x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack
495
+ x = x.reshape(B, 1224, D)
496
+ if self.decoder_mode > 3: # mvit
497
+ x = self.decoder_blocks(x)
498
+ else:
499
+ # apply Transformer blocks
500
+ for blk in self.decoder_blocks:
501
+ x = blk(x)
502
+ x = self.decoder_norm(x)
503
+
504
+ # predictor projection
505
+ pred = self.decoder_pred(x)
506
+
507
+ # remove cls token
508
+ if self.decoder_mode != 0:
509
+ if self.use_custom_patch:
510
+ pred = pred.reshape(B, 102, 12, 256)
511
+ pred = pred[:, :101, :, :]
512
+ pred = pred.reshape(B, 1212, 256)
513
+ else:
514
+ pred = pred
515
+ else:
516
+ pred = pred[:, 1:, :]
517
+ return pred, None, None # emb, emb_pixel
518
+
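Before decoding, the visible tokens are padded with learned mask tokens and unshuffled back to the original patch order via ids_restore. A toy sketch of that step (sizes are assumptions). Note also that the custom-patch branch above pads the 101x12 token grid to 102x12 = 1224 tokens for the decoder and trims back to 101x12 = 1212 predictions afterwards.

import torch

N, L, D = 1, 6, 4
len_keep = 2
ids_shuffle = torch.randperm(L).unsqueeze(0)                        # stand-in for the encoder's shuffle
ids_restore = torch.argsort(ids_shuffle, dim=1)
x_visible = torch.randn(N, len_keep, D)                             # decoder-embedded visible tokens (cls removed)
mask_token = torch.zeros(1, 1, D)                                   # a learned parameter in the real model

x_ = torch.cat([x_visible, mask_token.repeat(N, L - len_keep, 1)], dim=1)    # still in shuffled order
x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))          # back to input order
print(x_.shape)                                                     # torch.Size([1, 6, 4])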
519
+ def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
520
+ """
521
+ imgs: [N, 3, H, W]
522
+ pred: [N, L, p*p*3]
523
+ mask: [N, L], 0 is keep, 1 is remove,
524
+ """
525
+ target = self.patchify(imgs)
526
+ if norm_pix_loss:
527
+ mean = target.mean(dim=-1, keepdim=True)
528
+ var = target.var(dim=-1, keepdim=True)
529
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
530
+
531
+ loss = (pred - target) ** 2
532
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
533
+
534
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
535
+ return loss
536
+
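The reconstruction loss optionally normalizes each target patch to zero mean and unit variance, then averages the per-patch MSE over masked patches only. A runnable sketch with random stand-in tensors (shapes are assumptions):

import torch

N, L, P = 2, 8, 16                                 # batch, patches, pixels per patch
target = torch.randn(N, L, P)
pred = torch.randn(N, L, P)
mask = (torch.rand(N, L) > 0.25).float()           # 1 = masked (to be reconstructed)

mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5   # per-patch normalization (norm_pix_loss=True)

loss = ((pred - target) ** 2).mean(dim=-1)         # [N, L]: mean squared error per patch
loss = (loss * mask).sum() / mask.sum()            # average over masked patches only
print(loss.item())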
537
+ def forward(self, imgs, mask_ratio=0.8):
538
+ emb_enc, mask, ids_restore, _ = self.forward_encoder(
539
+ imgs, mask_ratio, mask_2d=self.mask_2d
540
+ )
541
+ pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
542
+ loss_recon = self.forward_loss(
543
+ imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
544
+ )
545
+ loss_contrastive = torch.FloatTensor([0.0]).cuda()
546
+ return loss_recon, pred, mask, loss_contrastive
547
+
548
+
549
+ def mae_vit_small_patch16_dec512d8b(**kwargs):
550
+ model = MaskedAutoencoderViT(
551
+ patch_size=16,
552
+ embed_dim=384,
553
+ depth=12,
554
+ num_heads=6,
555
+ decoder_embed_dim=512,
556
+ decoder_num_heads=16,
557
+ mlp_ratio=4,
558
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
559
+ **kwargs,
560
+ )
561
+ return model
562
+
563
+
564
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
565
+ model = MaskedAutoencoderViT(
566
+ patch_size=16,
567
+ embed_dim=768,
568
+ depth=12,
569
+ num_heads=12,
570
+ decoder_embed_dim=512,
571
+ decoder_num_heads=16,
572
+ mlp_ratio=4,
573
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
574
+ **kwargs,
575
+ )
576
+ return model
577
+
578
+
579
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
580
+ model = MaskedAutoencoderViT(
581
+ patch_size=16,
582
+ embed_dim=1024,
583
+ depth=24,
584
+ num_heads=16,
585
+ decoder_embed_dim=512,
586
+ decoder_num_heads=16,
587
+ mlp_ratio=4,
588
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
589
+ **kwargs,
590
+ )
591
+ return model
592
+
593
+
594
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
595
+ model = MaskedAutoencoderViT(
596
+ patch_size=14,
597
+ embed_dim=1280,
598
+ depth=32,
599
+ num_heads=16,
600
+ decoder_embed_dim=512,
601
+ decoder_num_heads=16,
602
+ mlp_ratio=4,
603
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
604
+ **kwargs,
605
+ )
606
+ return model
607
+
608
+
609
+ # set recommended archs
610
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
611
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
612
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
613
+ mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
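A quick arithmetic check on the encoder variants defined above: the per-head width embed_dim // num_heads is 64 for small, base, and large, and 80 for huge.

variants = {"small": (384, 6), "base": (768, 12), "large": (1024, 16), "huge": (1280, 16)}
for name, (embed_dim, num_heads) in variants.items():
    print(name, embed_dim // num_heads)            # 64, 64, 64, 80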
audiosr/latent_diffusion/modules/audiomae/models_vit.py CHANGED
@@ -1,243 +1,243 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # --------------------------------------------------------
7
- # References:
8
- # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
- # DeiT: https://github.com/facebookresearch/deit
10
- # --------------------------------------------------------
11
-
12
- from functools import partial
13
-
14
- import torch
15
- import torch.nn as nn
16
- import timm.models.vision_transformer
17
-
18
-
19
- class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
20
- """Vision Transformer with support for global average pooling"""
21
-
22
- def __init__(
23
- self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
24
- ):
25
- super(VisionTransformer, self).__init__(**kwargs)
26
-
27
- self.global_pool = global_pool
28
- if self.global_pool:
29
- norm_layer = kwargs["norm_layer"]
30
- embed_dim = kwargs["embed_dim"]
31
- self.fc_norm = norm_layer(embed_dim)
32
- del self.norm # remove the original norm
33
- self.mask_2d = mask_2d
34
- self.use_custom_patch = use_custom_patch
35
-
36
- def forward_features(self, x):
37
- B = x.shape[0]
38
- x = self.patch_embed(x)
39
- x = x + self.pos_embed[:, 1:, :]
40
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
41
- cls_tokens = cls_token.expand(
42
- B, -1, -1
43
- ) # stole cls_tokens impl from Phil Wang, thanks
44
- x = torch.cat((cls_tokens, x), dim=1)
45
- x = self.pos_drop(x)
46
-
47
- for blk in self.blocks:
48
- x = blk(x)
49
-
50
- if self.global_pool:
51
- x = x[:, 1:, :].mean(dim=1) # global pool without cls token
52
- outcome = self.fc_norm(x)
53
- else:
54
- x = self.norm(x)
55
- outcome = x[:, 0]
56
-
57
- return outcome
58
-
59
- def random_masking(self, x, mask_ratio):
60
- """
61
- Perform per-sample random masking by per-sample shuffling.
62
- Per-sample shuffling is done by argsort random noise.
63
- x: [N, L, D], sequence
64
- """
65
- N, L, D = x.shape # batch, length, dim
66
- len_keep = int(L * (1 - mask_ratio))
67
-
68
- noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
69
-
70
- # sort noise for each sample
71
- ids_shuffle = torch.argsort(
72
- noise, dim=1
73
- ) # ascend: small is keep, large is remove
74
- ids_restore = torch.argsort(ids_shuffle, dim=1)
75
-
76
- # keep the first subset
77
- ids_keep = ids_shuffle[:, :len_keep]
78
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
79
-
80
- # generate the binary mask: 0 is keep, 1 is remove
81
- mask = torch.ones([N, L], device=x.device)
82
- mask[:, :len_keep] = 0
83
- # unshuffle to get the binary mask
84
- mask = torch.gather(mask, dim=1, index=ids_restore)
85
-
86
- return x_masked, mask, ids_restore
87
-
88
- def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
89
- """
90
- 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)
91
- Perform per-sample random masking by per-sample shuffling.
92
- Per-sample shuffling is done by argsort random noise.
93
- x: [N, L, D], sequence
94
- """
95
-
96
- N, L, D = x.shape # batch, length, dim
97
- if self.use_custom_patch:
98
- # # for AS
99
- T = 101 # 64,101
100
- F = 12 # 8,12
101
- # # for ESC
102
- # T=50
103
- # F=12
104
- # for SPC
105
- # T=12
106
- # F=12
107
- else:
108
- # ## for AS
109
- T = 64
110
- F = 8
111
- # ## for ESC
112
- # T=32
113
- # F=8
114
- ## for SPC
115
- # T=8
116
- # F=8
117
-
118
- # mask T
119
- x = x.reshape(N, T, F, D)
120
- len_keep_T = int(T * (1 - mask_t_prob))
121
- noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
122
- # sort noise for each sample
123
- ids_shuffle = torch.argsort(
124
- noise, dim=1
125
- ) # ascend: small is keep, large is remove
126
- ids_keep = ids_shuffle[:, :len_keep_T]
127
- index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
128
- # x_masked = torch.gather(x, dim=1, index=index)
129
- # x_masked = x_masked.reshape(N,len_keep_T*F,D)
130
- x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
131
-
132
- # mask F
133
- # x = x.reshape(N, T, F, D)
134
- x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
135
- len_keep_F = int(F * (1 - mask_f_prob))
136
- noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
137
- # sort noise for each sample
138
- ids_shuffle = torch.argsort(
139
- noise, dim=1
140
- ) # ascend: small is keep, large is remove
141
- ids_keep = ids_shuffle[:, :len_keep_F]
142
- # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
143
- index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
144
- x_masked = torch.gather(x, dim=1, index=index)
145
- x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
146
- # x_masked = x_masked.reshape(N,len_keep*T,D)
147
- x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
148
-
149
- return x_masked, None, None
150
-
151
- def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
152
- B = x.shape[0] # 4,1,1024,128
153
- x = self.patch_embed(x) # 4, 512, 768
154
-
155
- x = x + self.pos_embed[:, 1:, :]
156
- if self.random_masking_2d:
157
- x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
158
- else:
159
- x, mask, ids_restore = self.random_masking(x, mask_t_prob)
160
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
161
- cls_tokens = cls_token.expand(B, -1, -1)
162
- x = torch.cat((cls_tokens, x), dim=1)
163
- x = self.pos_drop(x)
164
-
165
- # apply Transformer blocks
166
- for blk in self.blocks:
167
- x = blk(x)
168
-
169
- if self.global_pool:
170
- x = x[:, 1:, :].mean(dim=1) # global pool without cls token
171
- outcome = self.fc_norm(x)
172
- else:
173
- x = self.norm(x)
174
- outcome = x[:, 0]
175
-
176
- return outcome
177
-
178
- # overwrite original timm
179
- def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
180
- if mask_t_prob > 0.0 or mask_f_prob > 0.0:
181
- x = self.forward_features_mask(
182
- x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
183
- )
184
- else:
185
- x = self.forward_features(x)
186
- x = self.head(x)
187
- return x
188
-
189
-
190
- def vit_small_patch16(**kwargs):
191
- model = VisionTransformer(
192
- patch_size=16,
193
- embed_dim=384,
194
- depth=12,
195
- num_heads=6,
196
- mlp_ratio=4,
197
- qkv_bias=True,
198
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
199
- **kwargs
200
- )
201
- return model
202
-
203
-
204
- def vit_base_patch16(**kwargs):
205
- model = VisionTransformer(
206
- patch_size=16,
207
- embed_dim=768,
208
- depth=12,
209
- num_heads=12,
210
- mlp_ratio=4,
211
- qkv_bias=True,
212
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
213
- **kwargs
214
- )
215
- return model
216
-
217
-
218
- def vit_large_patch16(**kwargs):
219
- model = VisionTransformer(
220
- patch_size=16,
221
- embed_dim=1024,
222
- depth=24,
223
- num_heads=16,
224
- mlp_ratio=4,
225
- qkv_bias=True,
226
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
227
- **kwargs
228
- )
229
- return model
230
-
231
-
232
- def vit_huge_patch14(**kwargs):
233
- model = VisionTransformer(
234
- patch_size=14,
235
- embed_dim=1280,
236
- depth=32,
237
- num_heads=16,
238
- mlp_ratio=4,
239
- qkv_bias=True,
240
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
241
- **kwargs
242
- )
243
- return model
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import timm.models.vision_transformer
17
+
18
+
19
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
20
+ """Vision Transformer with support for global average pooling"""
21
+
22
+ def __init__(
23
+ self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
24
+ ):
25
+ super(VisionTransformer, self).__init__(**kwargs)
26
+
27
+ self.global_pool = global_pool
28
+ if self.global_pool:
29
+ norm_layer = kwargs["norm_layer"]
30
+ embed_dim = kwargs["embed_dim"]
31
+ self.fc_norm = norm_layer(embed_dim)
32
+ del self.norm # remove the original norm
33
+ self.mask_2d = mask_2d
34
+ self.use_custom_patch = use_custom_patch
35
+
36
+ def forward_features(self, x):
37
+ B = x.shape[0]
38
+ x = self.patch_embed(x)
39
+ x = x + self.pos_embed[:, 1:, :]
40
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
41
+ cls_tokens = cls_token.expand(
42
+ B, -1, -1
43
+ ) # stole cls_tokens impl from Phil Wang, thanks
44
+ x = torch.cat((cls_tokens, x), dim=1)
45
+ x = self.pos_drop(x)
46
+
47
+ for blk in self.blocks:
48
+ x = blk(x)
49
+
50
+ if self.global_pool:
51
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
52
+ outcome = self.fc_norm(x)
53
+ else:
54
+ x = self.norm(x)
55
+ outcome = x[:, 0]
56
+
57
+ return outcome
58
+
59
+ def random_masking(self, x, mask_ratio):
60
+ """
61
+ Perform per-sample random masking by per-sample shuffling.
62
+ Per-sample shuffling is done by argsort random noise.
63
+ x: [N, L, D], sequence
64
+ """
65
+ N, L, D = x.shape # batch, length, dim
66
+ len_keep = int(L * (1 - mask_ratio))
67
+
68
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
69
+
70
+ # sort noise for each sample
71
+ ids_shuffle = torch.argsort(
72
+ noise, dim=1
73
+ ) # ascend: small is keep, large is remove
74
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
75
+
76
+ # keep the first subset
77
+ ids_keep = ids_shuffle[:, :len_keep]
78
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
79
+
80
+ # generate the binary mask: 0 is keep, 1 is remove
81
+ mask = torch.ones([N, L], device=x.device)
82
+ mask[:, :len_keep] = 0
83
+ # unshuffle to get the binary mask
84
+ mask = torch.gather(mask, dim=1, index=ids_restore)
85
+
86
+ return x_masked, mask, ids_restore
87
+
88
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
89
+ """
90
+ 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)
91
+ Perform per-sample random masking by per-sample shuffling.
92
+ Per-sample shuffling is done by argsort random noise.
93
+ x: [N, L, D], sequence
94
+ """
95
+
96
+ N, L, D = x.shape # batch, length, dim
97
+ if self.use_custom_patch:
98
+ # # for AS
99
+ T = 101 # 64,101
100
+ F = 12 # 8,12
101
+ # # for ESC
102
+ # T=50
103
+ # F=12
104
+ # for SPC
105
+ # T=12
106
+ # F=12
107
+ else:
108
+ # ## for AS
109
+ T = 64
110
+ F = 8
111
+ # ## for ESC
112
+ # T=32
113
+ # F=8
114
+ ## for SPC
115
+ # T=8
116
+ # F=8
117
+
118
+ # mask T
119
+ x = x.reshape(N, T, F, D)
120
+ len_keep_T = int(T * (1 - mask_t_prob))
121
+ noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
122
+ # sort noise for each sample
123
+ ids_shuffle = torch.argsort(
124
+ noise, dim=1
125
+ ) # ascend: small is keep, large is remove
126
+ ids_keep = ids_shuffle[:, :len_keep_T]
127
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
128
+ # x_masked = torch.gather(x, dim=1, index=index)
129
+ # x_masked = x_masked.reshape(N,len_keep_T*F,D)
130
+ x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
131
+
132
+ # mask F
133
+ # x = x.reshape(N, T, F, D)
134
+ x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
135
+ len_keep_F = int(F * (1 - mask_f_prob))
136
+ noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
137
+ # sort noise for each sample
138
+ ids_shuffle = torch.argsort(
139
+ noise, dim=1
140
+ ) # ascend: small is keep, large is remove
141
+ ids_keep = ids_shuffle[:, :len_keep_F]
142
+ # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
143
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
144
+ x_masked = torch.gather(x, dim=1, index=index)
145
+ x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
146
+ # x_masked = x_masked.reshape(N,len_keep*T,D)
147
+ x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
148
+
149
+ return x_masked, None, None
150
+
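Unlike the flat-sequence masking earlier in this file, this variant keeps the (T, F) grid and gathers whole time rows first, then whole frequency columns. A toy sketch of the two gather steps (grid sizes are assumptions):

import torch

N, T, F, D = 1, 4, 3, 2
x = torch.arange(N * T * F * D, dtype=torch.float32).reshape(N, T, F, D)

keep_t = torch.tensor([[0, 2]])                                         # keep 2 of 4 time bins
x = torch.gather(x, 1, keep_t[..., None, None].repeat(1, 1, F, D))      # N, T', F, D

x = x.permute(0, 2, 1, 3)                                               # N, F, T', D
keep_f = torch.tensor([[1, 2]])                                         # keep 2 of 3 freq bins
x = torch.gather(x, 1, keep_f[..., None, None].repeat(1, 1, x.shape[2], D))
x = x.permute(0, 2, 1, 3).reshape(N, -1, D)                             # flatten back to [N, L', D]
print(x.shape)                                                          # torch.Size([1, 4, 2])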
151
+ def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
152
+ B = x.shape[0] # 4,1,1024,128
153
+ x = self.patch_embed(x) # 4, 512, 768
154
+
155
+ x = x + self.pos_embed[:, 1:, :]
156
+ if self.random_masking_2d:
157
+ x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
158
+ else:
159
+ x, mask, ids_restore = self.random_masking(x, mask_t_prob)
160
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
161
+ cls_tokens = cls_token.expand(B, -1, -1)
162
+ x = torch.cat((cls_tokens, x), dim=1)
163
+ x = self.pos_drop(x)
164
+
165
+ # apply Transformer blocks
166
+ for blk in self.blocks:
167
+ x = blk(x)
168
+
169
+ if self.global_pool:
170
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
171
+ outcome = self.fc_norm(x)
172
+ else:
173
+ x = self.norm(x)
174
+ outcome = x[:, 0]
175
+
176
+ return outcome
177
+
178
+ # overwrite original timm
179
+ def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
180
+ if mask_t_prob > 0.0 or mask_f_prob > 0.0:
181
+ x = self.forward_features_mask(
182
+ x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
183
+ )
184
+ else:
185
+ x = self.forward_features(x)
186
+ x = self.head(x)
187
+ return x
188
+
189
+
190
+ def vit_small_patch16(**kwargs):
191
+ model = VisionTransformer(
192
+ patch_size=16,
193
+ embed_dim=384,
194
+ depth=12,
195
+ num_heads=6,
196
+ mlp_ratio=4,
197
+ qkv_bias=True,
198
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
199
+ **kwargs
200
+ )
201
+ return model
202
+
203
+
204
+ def vit_base_patch16(**kwargs):
205
+ model = VisionTransformer(
206
+ patch_size=16,
207
+ embed_dim=768,
208
+ depth=12,
209
+ num_heads=12,
210
+ mlp_ratio=4,
211
+ qkv_bias=True,
212
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
213
+ **kwargs
214
+ )
215
+ return model
216
+
217
+
218
+ def vit_large_patch16(**kwargs):
219
+ model = VisionTransformer(
220
+ patch_size=16,
221
+ embed_dim=1024,
222
+ depth=24,
223
+ num_heads=16,
224
+ mlp_ratio=4,
225
+ qkv_bias=True,
226
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
227
+ **kwargs
228
+ )
229
+ return model
230
+
231
+
232
+ def vit_huge_patch14(**kwargs):
233
+ model = VisionTransformer(
234
+ patch_size=14,
235
+ embed_dim=1280,
236
+ depth=32,
237
+ num_heads=16,
238
+ mlp_ratio=4,
239
+ qkv_bias=True,
240
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
241
+ **kwargs
242
+ )
243
+ return model
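For the default AudioSet grid above (T=64, F=8, i.e. 512 patch tokens), the number of tokens that survive structured masking during fine-tuning is int(T * (1 - mask_t_prob)) * int(F * (1 - mask_f_prob)). A quick runnable check of that arithmetic:

T, F = 64, 8                                       # default AS grid: 512 patch tokens
for p_t, p_f in [(0.0, 0.0), (0.2, 0.2), (0.5, 0.5)]:
    kept = int(T * (1 - p_t)) * int(F * (1 - p_f))
    print(p_t, p_f, "->", kept, "of", T * F, "tokens kept")
# 0.0/0.0 -> 512; 0.2/0.2 -> 306 (51 x 6); 0.5/0.5 -> 128 (32 x 4)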
audiosr/latent_diffusion/modules/audiomae/util/crop.py CHANGED
@@ -1,43 +1,43 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import math
8
-
9
- import torch
10
-
11
- from torchvision import transforms
12
- from torchvision.transforms import functional as F
13
-
14
-
15
- class RandomResizedCrop(transforms.RandomResizedCrop):
16
- """
17
- RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18
- This may lead to results different with torchvision's version.
19
- Following BYOL's TF code:
20
- https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21
- """
22
-
23
- @staticmethod
24
- def get_params(img, scale, ratio):
25
- width, height = F._get_image_size(img)
26
- area = height * width
27
-
28
- target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
29
- log_ratio = torch.log(torch.tensor(ratio))
30
- aspect_ratio = torch.exp(
31
- torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
32
- ).item()
33
-
34
- w = int(round(math.sqrt(target_area * aspect_ratio)))
35
- h = int(round(math.sqrt(target_area / aspect_ratio)))
36
-
37
- w = min(w, width)
38
- h = min(h, height)
39
-
40
- i = torch.randint(0, height - h + 1, size=(1,)).item()
41
- j = torch.randint(0, width - w + 1, size=(1,)).item()
42
-
43
- return i, j, h, w
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
+ class RandomResizedCrop(transforms.RandomResizedCrop):
16
+ """
17
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18
+ This may lead to results different from torchvision's version.
19
+ Following BYOL's TF code:
20
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21
+ """
22
+
23
+ @staticmethod
24
+ def get_params(img, scale, ratio):
25
+ width, height = F._get_image_size(img)
26
+ area = height * width
27
+
28
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
29
+ log_ratio = torch.log(torch.tensor(ratio))
30
+ aspect_ratio = torch.exp(
31
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
32
+ ).item()
33
+
34
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
35
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
36
+
37
+ w = min(w, width)
38
+ h = min(h, height)
39
+
40
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
41
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
42
+
43
+ return i, j, h, w
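A hedged usage sketch for the class above (the spectrogram-as-image size is an assumption). Note that get_params relies on the private helper F._get_image_size, which newer torchvision releases expose publicly as F.get_image_size, so the call may need a one-line update depending on the installed version:

from PIL import Image
from torchvision.transforms import functional as F

img = Image.new("L", (128, 1024))                   # stand-in spectrogram image: width 128, height 1024
crop = RandomResizedCrop(size=(1024, 128), scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3))
i, j, h, w = crop.get_params(img, crop.scale, crop.ratio)
out = F.resized_crop(img, i, j, h, w, size=(1024, 128))
print(out.size)                                     # (128, 1024) in PIL's (width, height) order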